From 89b1e5a2d423c5bd056ae5000fbe6a60ded159b3 Mon Sep 17 00:00:00 2001 From: Hakan Sariman Date: Thu, 30 Jan 2025 15:06:51 +0300 Subject: [PATCH] move validateDomains to management.domain pkg to reuse --- client/cmd/up.go | 25 ++++- client/internal/config.go | 9 +- client/server/server.go | 7 +- management/client/client.go | 3 +- management/client/grpc.go | 5 +- management/client/mock.go | 5 +- management/domain/validate.go | 40 ++++++++ management/domain/validate_test.go | 97 +++++++++++++++++++ .../http/handlers/routes/routes_handler.go | 39 +------- .../handlers/routes/routes_handler_test.go | 93 +----------------- 10 files changed, 180 insertions(+), 143 deletions(-) create mode 100644 management/domain/validate.go create mode 100644 management/domain/validate_test.go diff --git a/client/cmd/up.go b/client/cmd/up.go index 81a0eb986bf..39137b033df 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/util" ) @@ -105,6 +106,11 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { return err } + dnsLabelsConverted, err := validateDnsLabels(cmd.Flag(dnsLabelsFlag).Changed) + if err != nil { + return err + } + ic := internal.ConfigInput{ ManagementURL: managementURL, AdminURL: adminURL, @@ -112,6 +118,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { NATExternalIPs: natExternalIPs, CustomDNSAddress: customDNSAddressConverted, ExtraIFaceBlackList: extraIFaceBlackList, + DNSLabels: dnsLabelsConverted, } if cmd.Flag(enableRosenpassFlag).Changed { @@ -179,10 +186,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { ic.BlockLANAccess = &blockLANAccess } - if cmd.Flag(dnsLabelsFlag).Changed { - ic.DNSLabels = dnsLabels - } - providedSetupKey, err := getSetupKey() if err != nil { return err @@ -452,6 +455,20 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) { return parsed, nil } +func validateDnsLabels(modified bool) (domain.List, error) { + var ( + domains domain.List + err error + ) + if modified { + domains, err = domain.ValidateDomains(dnsLabels) + if err != nil { + return nil, fmt.Errorf("failed to validate dns labels: %v", err) + } + } + return domains, nil +} + func isValidAddrPort(input string) bool { if input == "" { return true diff --git a/client/internal/config.go b/client/internal/config.go index d490eaf3e32..d381ac87328 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/ssh" mgm "github.com/netbirdio/netbird/management/client" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/util" ) @@ -69,7 +70,7 @@ type ConfigInput struct { BlockLANAccess *bool - DNSLabels []string + DNSLabels domain.List } // Config Configuration type @@ -95,7 +96,7 @@ type Config struct { BlockLANAccess bool - DNSLabels []string + DNSLabels domain.List // SSHKey is a private SSH key in a PEM format SSHKey string @@ -495,8 +496,8 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { if input.DNSLabels != nil && !reflect.DeepEqual(config.DNSLabels, input.DNSLabels) { log.Infof("updating DNS labels [ %s ] (old value: [ %s ])", - strings.Join(input.DNSLabels, " "), - strings.Join(config.DNSLabels, " ")) + input.DNSLabels.SafeString(), + config.DNSLabels.SafeString()) config.DNSLabels = input.DNSLabels updated = true } diff --git a/client/server/server.go b/client/server/server.go index b0149989508..2979c721b70 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" @@ -405,8 +406,10 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro } if msg.DnsLabels != nil { - inputConfig.DNSLabels = msg.DnsLabels - s.latestConfigInput.DNSLabels = msg.DnsLabels + // TODO(hakan): discuss with @Viktor + dnsLabels, _ := domain.ValidateDomains(msg.DnsLabels) + inputConfig.DNSLabels = dnsLabels + s.latestConfigInput.DNSLabels = dnsLabels } s.mutex.Unlock() diff --git a/management/client/client.go b/management/client/client.go index 15ae54a17a0..e9eeaccc144 100644 --- a/management/client/client.go +++ b/management/client/client.go @@ -7,6 +7,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" ) @@ -15,7 +16,7 @@ type Client interface { Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error GetServerPublicKey() (*wgtypes.Key, error) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error) - Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels []string) (*proto.LoginResponse, error) + Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) diff --git a/management/client/grpc.go b/management/client/grpc.go index 51a33a4231c..e1409973176 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" nbgrpc "github.com/netbirdio/netbird/util/grpc" ) @@ -373,12 +374,12 @@ func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken s } // Login attempts login to Management Server. Takes care of encrypting and decrypting messages. -func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels []string) (*proto.LoginResponse, error) { +func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { keys := &proto.PeerKeys{ SshPubKey: pubSSHKey, WgPubKey: []byte(c.key.PublicKey().String()), } - return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels}) + return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) } // GetDeviceAuthorizationFlow returns a device authorization flow information. diff --git a/management/client/mock.go b/management/client/mock.go index 1d139b4d6d9..11564093a26 100644 --- a/management/client/mock.go +++ b/management/client/mock.go @@ -6,6 +6,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" ) @@ -14,7 +15,7 @@ type MockClient struct { SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error GetServerPublicKeyFunc func() (*wgtypes.Key, error) RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error) - LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels []string) (*proto.LoginResponse, error) + LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) SyncMetaFunc func(sysInfo *system.Info) error @@ -52,7 +53,7 @@ func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken s return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey) } -func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels []string) (*proto.LoginResponse, error) { +func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { if m.LoginFunc == nil { return nil, nil } diff --git a/management/domain/validate.go b/management/domain/validate.go new file mode 100644 index 00000000000..46e97f0451a --- /dev/null +++ b/management/domain/validate.go @@ -0,0 +1,40 @@ +package domain + +import ( + "fmt" + "regexp" + "strings" +) + +const maxDomains = 32 + +// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. +func ValidateDomains(domains []string) (List, error) { + if len(domains) == 0 { + return nil, fmt.Errorf("domains list is empty") + } + if len(domains) > maxDomains { + return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) + } + + domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) + + var domainList List + + for _, d := range domains { + d := strings.ToLower(d) + + // handles length and idna conversion + punycode, err := FromString(d) + if err != nil { + return domainList, fmt.Errorf("failed to convert domain to punycode: %s: %v", d, err) + } + + if !domainRegex.MatchString(string(punycode)) { + return domainList, fmt.Errorf("invalid domain format: %s", d) + } + + domainList = append(domainList, punycode) + } + return domainList, nil +} diff --git a/management/domain/validate_test.go b/management/domain/validate_test.go new file mode 100644 index 00000000000..c3d8119c028 --- /dev/null +++ b/management/domain/validate_test.go @@ -0,0 +1,97 @@ +package domain + +import ( + "testing" + + "github.com/magiconair/properties/assert" +) + +func TestValidateDomains(t *testing.T) { + tests := []struct { + name string + domains []string + expected List + wantErr bool + }{ + { + name: "Empty list", + domains: nil, + expected: nil, + wantErr: true, + }, + { + name: "Valid ASCII domain", + domains: []string{"sub.ex-ample.com"}, + expected: List{"sub.ex-ample.com"}, + wantErr: false, + }, + { + name: "Valid Unicode domain", + domains: []string{"münchen.de"}, + expected: List{"xn--mnchen-3ya.de"}, + wantErr: false, + }, + { + name: "Valid Unicode, all labels", + domains: []string{"中国.中国.中国"}, + expected: List{"xn--fiqs8s.xn--fiqs8s.xn--fiqs8s"}, + wantErr: false, + }, + { + name: "With underscores", + domains: []string{"_jabber._tcp.gmail.com"}, + expected: List{"_jabber._tcp.gmail.com"}, + wantErr: false, + }, + { + name: "Invalid domain format", + domains: []string{"-example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid domain format 2", + domains: []string{"example.com-"}, + expected: nil, + wantErr: true, + }, + { + name: "Multiple domains valid and invalid", + domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"}, + expected: List{"google.com"}, + wantErr: true, + }, + { + name: "Valid wildcard domain", + domains: []string{"*.example.com"}, + expected: List{"*.example.com"}, + wantErr: false, + }, + { + name: "Wildcard with dot domain", + domains: []string{".*.example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Wildcard with dot domain", + domains: []string{".*.example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid wildcard domain", + domains: []string{"a.*.example.com"}, + expected: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ValidateDomains(tt.domains) + assert.Equal(t, tt.wantErr, err != nil) + assert.Equal(t, got, tt.expected) + }) + } +} diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index a29ba45629d..6b6c3791000 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -2,11 +2,8 @@ package routes import ( "encoding/json" - "fmt" "net/http" "net/netip" - "regexp" - "strings" "unicode/utf8" "github.com/gorilla/mux" @@ -21,7 +18,6 @@ import ( "github.com/netbirdio/netbird/route" ) -const maxDomains = 32 const failedToConvertRoute = "failed to convert route to response: %v" // handler is the routes handler of the account @@ -102,7 +98,7 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { var networkType route.NetworkType var newPrefix netip.Prefix if req.Domains != nil { - d, err := validateDomains(*req.Domains) + d, err := domain.ValidateDomains(*req.Domains) if err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) return @@ -225,7 +221,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { } if req.Domains != nil { - d, err := validateDomains(*req.Domains) + d, err := domain.ValidateDomains(*req.Domains) if err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) return @@ -350,34 +346,3 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) { } return route, nil } - -// validateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. -func validateDomains(domains []string) (domain.List, error) { - if len(domains) == 0 { - return nil, fmt.Errorf("domains list is empty") - } - if len(domains) > maxDomains { - return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) - } - - domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) - - var domainList domain.List - - for _, d := range domains { - d := strings.ToLower(d) - - // handles length and idna conversion - punycode, err := domain.FromString(d) - if err != nil { - return domainList, fmt.Errorf("failed to convert domain to punycode: %s: %v", d, err) - } - - if !domainRegex.MatchString(string(punycode)) { - return domainList, fmt.Errorf("invalid domain format: %s", d) - } - - domainList = append(domainList, punycode) - } - return domainList, nil -} diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index 45c4655871a..51afbfde3bc 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -11,9 +11,10 @@ import ( "net/netip" "testing" - "github.com/netbirdio/netbird/management/server/util" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -563,96 +564,6 @@ func TestRoutesHandlers(t *testing.T) { } } -func TestValidateDomains(t *testing.T) { - tests := []struct { - name string - domains []string - expected domain.List - wantErr bool - }{ - { - name: "Empty list", - domains: nil, - expected: nil, - wantErr: true, - }, - { - name: "Valid ASCII domain", - domains: []string{"sub.ex-ample.com"}, - expected: domain.List{"sub.ex-ample.com"}, - wantErr: false, - }, - { - name: "Valid Unicode domain", - domains: []string{"münchen.de"}, - expected: domain.List{"xn--mnchen-3ya.de"}, - wantErr: false, - }, - { - name: "Valid Unicode, all labels", - domains: []string{"中国.中国.中国"}, - expected: domain.List{"xn--fiqs8s.xn--fiqs8s.xn--fiqs8s"}, - wantErr: false, - }, - { - name: "With underscores", - domains: []string{"_jabber._tcp.gmail.com"}, - expected: domain.List{"_jabber._tcp.gmail.com"}, - wantErr: false, - }, - { - name: "Invalid domain format", - domains: []string{"-example.com"}, - expected: nil, - wantErr: true, - }, - { - name: "Invalid domain format 2", - domains: []string{"example.com-"}, - expected: nil, - wantErr: true, - }, - { - name: "Multiple domains valid and invalid", - domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"}, - expected: domain.List{"google.com"}, - wantErr: true, - }, - { - name: "Valid wildcard domain", - domains: []string{"*.example.com"}, - expected: domain.List{"*.example.com"}, - wantErr: false, - }, - { - name: "Wildcard with dot domain", - domains: []string{".*.example.com"}, - expected: nil, - wantErr: true, - }, - { - name: "Wildcard with dot domain", - domains: []string{".*.example.com"}, - expected: nil, - wantErr: true, - }, - { - name: "Invalid wildcard domain", - domains: []string{"a.*.example.com"}, - expected: nil, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := validateDomains(tt.domains) - assert.Equal(t, tt.wantErr, err != nil) - assert.Equal(t, got, tt.expected) - }) - } -} - func toApiRoute(t *testing.T, r *route.Route) *api.Route { t.Helper()