Skip to content

Commit

Permalink
Merge pull request #304 from xiaods/dev
Browse files Browse the repository at this point in the history
Dev for servicelb
  • Loading branch information
xiaods authored Feb 23, 2023
2 parents 98ed533 + 96548ec commit ed12134
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 36 deletions.
45 changes: 9 additions & 36 deletions pkg/cloudprovider/servicelb.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cloudprovider

import (
"context"
"errors"
"fmt"
"sort"
"strconv"
Expand Down Expand Up @@ -375,11 +374,11 @@ func (k *k8e) podIPs(pods []*core.Pod, svc *core.Service, readyNodes map[string]
return ips, nil
}

// filterByIPFamily filters ips based on dual-stack parameters of the service
// filterByIPFamily filters node IPs based on dual-stack parameters of the service
func filterByIPFamily(ips []string, svc *core.Service) ([]string, error) {
var ipFamilyPolicy core.IPFamilyPolicyType
var ipv4Addresses []string
var ipv6Addresses []string
var allAddresses []string

for _, ip := range ips {
if utilsnet.IsIPv4String(ip) {
Expand All @@ -390,42 +389,16 @@ func filterByIPFamily(ips []string, svc *core.Service) ([]string, error) {
}
}

if svc.Spec.IPFamilyPolicy != nil {
ipFamilyPolicy = *svc.Spec.IPFamilyPolicy
}

switch ipFamilyPolicy {
case core.IPFamilyPolicySingleStack:
if svc.Spec.IPFamilies[0] == core.IPv4Protocol {
return ipv4Addresses, nil
}
if svc.Spec.IPFamilies[0] == core.IPv6Protocol {
return ipv6Addresses, nil
}
case core.IPFamilyPolicyPreferDualStack:
if svc.Spec.IPFamilies[0] == core.IPv4Protocol {
ipAddresses := append(ipv4Addresses, ipv6Addresses...)
return ipAddresses, nil
}
if svc.Spec.IPFamilies[0] == core.IPv6Protocol {
ipAddresses := append(ipv6Addresses, ipv4Addresses...)
return ipAddresses, nil
}
case core.IPFamilyPolicyRequireDualStack:
if (len(ipv4Addresses) == 0) || (len(ipv6Addresses) == 0) {
return nil, errors.New("one or more IP families did not have addresses available for service with ipFamilyPolicy=RequireDualStack")
}
if svc.Spec.IPFamilies[0] == core.IPv4Protocol {
ipAddresses := append(ipv4Addresses, ipv6Addresses...)
return ipAddresses, nil
}
if svc.Spec.IPFamilies[0] == core.IPv6Protocol {
ipAddresses := append(ipv6Addresses, ipv4Addresses...)
return ipAddresses, nil
for _, ipFamily := range svc.Spec.IPFamilies {
switch ipFamily {
case core.IPv4Protocol:
allAddresses = append(allAddresses, ipv4Addresses...)
case core.IPv6Protocol:
allAddresses = append(allAddresses, ipv6Addresses...)
}
}

return nil, errors.New("unhandled ipFamilyPolicy")
return allAddresses, nil
}

// deployDaemonSet ensures that there is a DaemonSet for the service.
Expand Down
91 changes: 91 additions & 0 deletions pkg/cloudprovider/servicelb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package cloudprovider

import (
"reflect"
"testing"

core "k8s.io/api/core/v1"
)

const (
addrv4 = "1.2.3.4"
addrv6 = "2001:db8::1"
)

func Test_UnitFilterByIPFamily(t *testing.T) {
type args struct {
ips []string
svc *core.Service
}
tests := []struct {
name string
args args
want []string
wantErr bool
}{
{
name: "No IPFamily",
args: args{
ips: []string{addrv4, addrv6},
svc: &core.Service{
Spec: core.ServiceSpec{
IPFamilies: []core.IPFamily{},
},
},
},
want: nil,
wantErr: false,
},
{
name: "IPv4 Only",
args: args{
ips: []string{addrv4, addrv6},
svc: &core.Service{
Spec: core.ServiceSpec{
IPFamilies: []core.IPFamily{core.IPv4Protocol},
},
},
},
want: []string{addrv4},
wantErr: false,
},
{
name: "IPv6 Only",
args: args{
ips: []string{addrv4, addrv6},
svc: &core.Service{
Spec: core.ServiceSpec{
IPFamilies: []core.IPFamily{core.IPv6Protocol},
},
},
},
want: []string{addrv6},
wantErr: false,
},
{
name: "Dual-Stack",
args: args{
ips: []string{addrv4, addrv6},
svc: &core.Service{
Spec: core.ServiceSpec{
IPFamilies: []core.IPFamily{core.IPv4Protocol, core.IPv6Protocol},
},
},
},
want: []string{addrv4, addrv6},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := filterByIPFamily(tt.args.ips, tt.args.svc)
if (err != nil) != tt.wantErr {
t.Errorf("filterByIPFamily() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("filterByIPFamily() = %+v\nWant = %+v", got, tt.want)
}
})
}
}

0 comments on commit ed12134

Please sign in to comment.