From d032bbe9178608566afaad1abf554712675bb7c9 Mon Sep 17 00:00:00 2001 From: Frederic Mereu Date: Fri, 9 Feb 2024 08:35:58 +0100 Subject: [PATCH 1/2] feat: add support for additionalControlPlaneIngressRule on AWSManagedControlPlane --- pkg/cloud/scope/managedcontrolplane.go | 2 +- .../services/securitygroup/securitygroups.go | 8 +- .../securitygroup/securitygroups_test.go | 145 +++++++++++++++--- 3 files changed, 131 insertions(+), 24 deletions(-) diff --git a/pkg/cloud/scope/managedcontrolplane.go b/pkg/cloud/scope/managedcontrolplane.go index df64e87f44..14ded92263 100644 --- a/pkg/cloud/scope/managedcontrolplane.go +++ b/pkg/cloud/scope/managedcontrolplane.go @@ -472,7 +472,7 @@ func (s *ManagedControlPlaneScope) Partition() string { // AdditionalControlPlaneIngressRules returns the additional ingress rules for the control plane security group. func (s *ManagedControlPlaneScope) AdditionalControlPlaneIngressRules() []infrav1.IngressRule { - return nil + return s.ControlPlane.Spec.NetworkSpec.DeepCopy().AdditionalControlPlaneIngressRules } // UnstructuredControlPlane returns the unstructured object for the control plane, if any. diff --git a/pkg/cloud/services/securitygroup/securitygroups.go b/pkg/cloud/services/securitygroup/securitygroups.go index bf30f60824..484f04350b 100644 --- a/pkg/cloud/services/securitygroup/securitygroups.go +++ b/pkg/cloud/services/securitygroup/securitygroups.go @@ -682,12 +682,12 @@ func (s *Service) getSecurityGroupIngressRules(role infrav1.SecurityGroupRole) ( } return append(cniRules, rules...), nil case infrav1.SecurityGroupEKSNodeAdditional: + rules := infrav1.IngressRules{} if s.scope.Bastion().Enabled { - return infrav1.IngressRules{ - s.defaultSSHIngressRule(s.scope.SecurityGroups()[infrav1.SecurityGroupBastion].ID), - }, nil + rules = append(rules, s.defaultSSHIngressRule(s.scope.SecurityGroups()[infrav1.SecurityGroupBastion].ID)) } - return infrav1.IngressRules{}, nil + ingressRules := s.scope.AdditionalControlPlaneIngressRules() + return append(rules, ingressRules...), nil case infrav1.SecurityGroupAPIServerLB: kubeletRules := s.getIngressRulesToAllowKubeletToAccessTheControlPlaneLB() customIngressRules, err := s.processIngressRulesSGs(s.getControlPlaneLBIngressRules()) diff --git a/pkg/cloud/services/securitygroup/securitygroups_test.go b/pkg/cloud/services/securitygroup/securitygroups_test.go index f522104210..4f39ac2624 100644 --- a/pkg/cloud/services/securitygroup/securitygroups_test.go +++ b/pkg/cloud/services/securitygroup/securitygroups_test.go @@ -18,6 +18,7 @@ package securitygroup import ( "context" + "reflect" "strings" "testing" @@ -34,6 +35,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" infrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/api/v1beta2" + ekscontrolplanev1 "sigs.k8s.io/cluster-api-provider-aws/v2/controlplane/eks/api/v1beta2" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/awserrors" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/filter" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope" @@ -1192,11 +1194,11 @@ func TestAdditionalControlPlaneSecurityGroup(t *testing.T) { _ = infrav1.AddToScheme(scheme) testCases := []struct { - name string - networkSpec infrav1.NetworkSpec - networkStatus infrav1.NetworkStatus - expectedAdditionalIngresRule infrav1.IngressRule - wantErr bool + name string + networkSpec infrav1.NetworkSpec + networkStatus infrav1.NetworkStatus + expectedAdditionalIngressRule infrav1.IngressRule + wantErr bool }{ { name: "default control plane security group is used", @@ -1220,7 +1222,7 @@ func TestAdditionalControlPlaneSecurityGroup(t *testing.T) { }, }, }, - expectedAdditionalIngresRule: infrav1.IngressRule{ + expectedAdditionalIngressRule: infrav1.IngressRule{ Description: "test", Protocol: infrav1.SecurityGroupProtocolTCP, FromPort: 9345, @@ -1251,7 +1253,7 @@ func TestAdditionalControlPlaneSecurityGroup(t *testing.T) { }, }, }, - expectedAdditionalIngresRule: infrav1.IngressRule{ + expectedAdditionalIngressRule: infrav1.IngressRule{ Description: "test", Protocol: infrav1.SecurityGroupProtocolTCP, FromPort: 9345, @@ -1282,7 +1284,7 @@ func TestAdditionalControlPlaneSecurityGroup(t *testing.T) { }, }, }, - expectedAdditionalIngresRule: infrav1.IngressRule{ + expectedAdditionalIngressRule: infrav1.IngressRule{ Description: "test", Protocol: infrav1.SecurityGroupProtocolTCP, FromPort: 9345, @@ -1314,7 +1316,7 @@ func TestAdditionalControlPlaneSecurityGroup(t *testing.T) { }, }, }, - expectedAdditionalIngresRule: infrav1.IngressRule{ + expectedAdditionalIngressRule: infrav1.IngressRule{ Description: "test", Protocol: infrav1.SecurityGroupProtocolTCP, FromPort: 9345, @@ -1345,7 +1347,7 @@ func TestAdditionalControlPlaneSecurityGroup(t *testing.T) { }, }, }, - expectedAdditionalIngresRule: infrav1.IngressRule{ + expectedAdditionalIngressRule: infrav1.IngressRule{ Description: "test", Protocol: infrav1.SecurityGroupProtocolTCP, FromPort: 9345, @@ -1376,7 +1378,7 @@ func TestAdditionalControlPlaneSecurityGroup(t *testing.T) { }, NatGatewaysIPs: []string{"test-ip"}, }, - expectedAdditionalIngresRule: infrav1.IngressRule{ + expectedAdditionalIngressRule: infrav1.IngressRule{ Description: "test", Protocol: infrav1.SecurityGroupProtocolTCP, CidrBlocks: []string{"test-ip/32"}, @@ -1437,20 +1439,125 @@ func TestAdditionalControlPlaneSecurityGroup(t *testing.T) { } found = true - if r.Protocol != tc.expectedAdditionalIngresRule.Protocol { - t.Fatalf("Expected protocol %s, got %s", tc.expectedAdditionalIngresRule.Protocol, r.Protocol) + if r.Protocol != tc.expectedAdditionalIngressRule.Protocol { + t.Fatalf("Expected protocol %s, got %s", tc.expectedAdditionalIngressRule.Protocol, r.Protocol) } - if r.FromPort != tc.expectedAdditionalIngresRule.FromPort { - t.Fatalf("Expected from port %d, got %d", tc.expectedAdditionalIngresRule.FromPort, r.FromPort) + if r.FromPort != tc.expectedAdditionalIngressRule.FromPort { + t.Fatalf("Expected from port %d, got %d", tc.expectedAdditionalIngressRule.FromPort, r.FromPort) } - if r.ToPort != tc.expectedAdditionalIngresRule.ToPort { - t.Fatalf("Expected to port %d, got %d", tc.expectedAdditionalIngresRule.ToPort, r.ToPort) + if r.ToPort != tc.expectedAdditionalIngressRule.ToPort { + t.Fatalf("Expected to port %d, got %d", tc.expectedAdditionalIngressRule.ToPort, r.ToPort) } - if !sets.New(tc.expectedAdditionalIngresRule.SourceSecurityGroupIDs...).Equal(sets.New(r.SourceSecurityGroupIDs...)) { - t.Fatalf("Expected source security group IDs %v, got %v", tc.expectedAdditionalIngresRule.SourceSecurityGroupIDs, r.SourceSecurityGroupIDs) + if !sets.New[string](tc.expectedAdditionalIngressRule.SourceSecurityGroupIDs...).Equal(sets.New[string](tc.expectedAdditionalIngressRule.SourceSecurityGroupIDs...)) { + t.Fatalf("Expected source security group IDs %v, got %v", tc.expectedAdditionalIngressRule.SourceSecurityGroupIDs, r.SourceSecurityGroupIDs) + } + } + + if !found { + t.Fatal("Additional ingress rule was not found") + } + }) + } +} + +func TestAdditionalManagedControlPlaneSecurityGroup(t *testing.T) { + scheme := runtime.NewScheme() + _ = ekscontrolplanev1.AddToScheme(scheme) + + testCases := []struct { + name string + networkSpec infrav1.NetworkSpec + expectedAdditionalIngressRule infrav1.IngressRule + }{ + { + name: "default control plane security group is used", + networkSpec: infrav1.NetworkSpec{ + AdditionalControlPlaneIngressRules: []infrav1.IngressRule{ + { + Description: "test", + Protocol: infrav1.SecurityGroupProtocolTCP, + FromPort: 9345, + ToPort: 9345, + }, + }, + }, + expectedAdditionalIngressRule: infrav1.IngressRule{ + Description: "test", + Protocol: infrav1.SecurityGroupProtocolTCP, + FromPort: 9345, + ToPort: 9345, + SourceSecurityGroupIDs: []string{"cp-sg-id"}, + }, + }, + { + name: "don't set source security groups if cidr blocks are set", + networkSpec: infrav1.NetworkSpec{ + AdditionalControlPlaneIngressRules: []infrav1.IngressRule{ + { + Description: "test", + Protocol: infrav1.SecurityGroupProtocolTCP, + FromPort: 9345, + ToPort: 9345, + CidrBlocks: []string{"test-cidr-block"}, + }, + }, + }, + expectedAdditionalIngressRule: infrav1.IngressRule{ + Description: "test", + Protocol: infrav1.SecurityGroupProtocolTCP, + FromPort: 9345, + ToPort: 9345, + CidrBlocks: []string{"test-cidr-block"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cs, err := scope.NewManagedControlPlaneScope(scope.ManagedControlPlaneScopeParams{ + Client: fake.NewClientBuilder().WithScheme(scheme).Build(), + Cluster: &clusterv1.Cluster{ + ObjectMeta: metav1.ObjectMeta{Name: "test-cluster"}, + }, + ControlPlane: &ekscontrolplanev1.AWSManagedControlPlane{ + Spec: ekscontrolplanev1.AWSManagedControlPlaneSpec{ + NetworkSpec: tc.networkSpec, + }, + Status: ekscontrolplanev1.AWSManagedControlPlaneStatus{ + Network: infrav1.NetworkStatus{ + SecurityGroups: map[infrav1.SecurityGroupRole]infrav1.SecurityGroup{ + infrav1.SecurityGroupControlPlane: { + ID: "cp-sg-id", + }, + infrav1.SecurityGroupNode: { + ID: "node-sg-id", + }, + }, + }, + }, + }, + }) + if err != nil { + t.Fatalf("Failed to create test context: %v", err) + } + + s := NewService(cs, testSecurityGroupRoles) + rules, err := s.getSecurityGroupIngressRules(infrav1.SecurityGroupControlPlane) + if err != nil { + t.Fatalf("Failed to lookup controlplane security group ingress rules: %v", err) + } + + found := false + for _, r := range rules { + if r.Description == "test" { + found = true + + if !reflect.DeepEqual(r, tc.expectedAdditionalIngressRule) { + t.Fatalf("Expected ingress rule %#v, got %#v", tc.expectedAdditionalIngressRule, r) + } } } From 16b98c361bc6e1c11f4b0def9276275679bbdaf9 Mon Sep 17 00:00:00 2001 From: Frederic Mereu Date: Thu, 24 Oct 2024 13:59:18 +0200 Subject: [PATCH 2/2] refactor: simplify ingressrules --- pkg/cloud/services/securitygroup/securitygroups.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pkg/cloud/services/securitygroup/securitygroups.go b/pkg/cloud/services/securitygroup/securitygroups.go index 484f04350b..d203f8b9ec 100644 --- a/pkg/cloud/services/securitygroup/securitygroups.go +++ b/pkg/cloud/services/securitygroup/securitygroups.go @@ -682,12 +682,11 @@ func (s *Service) getSecurityGroupIngressRules(role infrav1.SecurityGroupRole) ( } return append(cniRules, rules...), nil case infrav1.SecurityGroupEKSNodeAdditional: - rules := infrav1.IngressRules{} + ingressRules := s.scope.AdditionalControlPlaneIngressRules() if s.scope.Bastion().Enabled { - rules = append(rules, s.defaultSSHIngressRule(s.scope.SecurityGroups()[infrav1.SecurityGroupBastion].ID)) + ingressRules = append(ingressRules, s.defaultSSHIngressRule(s.scope.SecurityGroups()[infrav1.SecurityGroupBastion].ID)) } - ingressRules := s.scope.AdditionalControlPlaneIngressRules() - return append(rules, ingressRules...), nil + return ingressRules, nil case infrav1.SecurityGroupAPIServerLB: kubeletRules := s.getIngressRulesToAllowKubeletToAccessTheControlPlaneLB() customIngressRules, err := s.processIngressRulesSGs(s.getControlPlaneLBIngressRules())