Skip to content

Commit

Permalink
move validateDomains to management.domain pkg to reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
hakansa committed Jan 30, 2025
1 parent d1025b2 commit 89b1e5a
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 143 deletions.
25 changes: 21 additions & 4 deletions client/cmd/up.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/util"
)

Expand Down Expand Up @@ -105,13 +106,19 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err
}

dnsLabelsConverted, err := validateDnsLabels(cmd.Flag(dnsLabelsFlag).Changed)
if err != nil {
return err
}

ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsConverted,
}

if cmd.Flag(enableRosenpassFlag).Changed {
Expand Down Expand Up @@ -179,10 +186,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.BlockLANAccess = &blockLANAccess
}

if cmd.Flag(dnsLabelsFlag).Changed {
ic.DNSLabels = dnsLabels
}

providedSetupKey, err := getSetupKey()
if err != nil {
return err
Expand Down Expand Up @@ -452,6 +455,20 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
return parsed, nil
}

func validateDnsLabels(modified bool) (domain.List, error) {
var (
domains domain.List
err error
)
if modified {
domains, err = domain.ValidateDomains(dnsLabels)
if err != nil {
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
}
}
return domains, nil
}

func isValidAddrPort(input string) bool {
if input == "" {
return true
Expand Down
9 changes: 5 additions & 4 deletions client/internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/util"
)

Expand Down Expand Up @@ -69,7 +70,7 @@ type ConfigInput struct {

BlockLANAccess *bool

DNSLabels []string
DNSLabels domain.List
}

// Config Configuration type
Expand All @@ -95,7 +96,7 @@ type Config struct {

BlockLANAccess bool

DNSLabels []string
DNSLabels domain.List

// SSHKey is a private SSH key in a PEM format
SSHKey string
Expand Down Expand Up @@ -495,8 +496,8 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {

if input.DNSLabels != nil && !reflect.DeepEqual(config.DNSLabels, input.DNSLabels) {
log.Infof("updating DNS labels [ %s ] (old value: [ %s ])",
strings.Join(input.DNSLabels, " "),
strings.Join(config.DNSLabels, " "))
input.DNSLabels.SafeString(),
config.DNSLabels.SafeString())
config.DNSLabels = input.DNSLabels
updated = true
}
Expand Down
7 changes: 5 additions & 2 deletions client/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"

"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
Expand Down Expand Up @@ -405,8 +406,10 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}

if msg.DnsLabels != nil {
inputConfig.DNSLabels = msg.DnsLabels
s.latestConfigInput.DNSLabels = msg.DnsLabels
// TODO(hakan): discuss with @Viktor
dnsLabels, _ := domain.ValidateDomains(msg.DnsLabels)
inputConfig.DNSLabels = dnsLabels
s.latestConfigInput.DNSLabels = dnsLabels
}

s.mutex.Unlock()
Expand Down
3 changes: 2 additions & 1 deletion management/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
)

Expand All @@ -15,7 +16,7 @@ type Client interface {
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels []string) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
Expand Down
5 changes: 3 additions & 2 deletions management/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
nbgrpc "github.com/netbirdio/netbird/util/grpc"
)
Expand Down Expand Up @@ -373,12 +374,12 @@ func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken s
}

// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels []string) (*proto.LoginResponse, error) {
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels})
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
}

// GetDeviceAuthorizationFlow returns a device authorization flow information.
Expand Down
5 changes: 3 additions & 2 deletions management/client/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
)

Expand All @@ -14,7 +15,7 @@ type MockClient struct {
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels []string) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error
Expand Down Expand Up @@ -52,7 +53,7 @@ func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken s
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey)
}

func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels []string) (*proto.LoginResponse, error) {
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.LoginFunc == nil {
return nil, nil
}
Expand Down
40 changes: 40 additions & 0 deletions management/domain/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package domain

import (
"fmt"
"regexp"
"strings"
)

const maxDomains = 32

// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
func ValidateDomains(domains []string) (List, error) {
if len(domains) == 0 {
return nil, fmt.Errorf("domains list is empty")
}
if len(domains) > maxDomains {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
}

domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)

var domainList List

for _, d := range domains {
d := strings.ToLower(d)

// handles length and idna conversion
punycode, err := FromString(d)
if err != nil {
return domainList, fmt.Errorf("failed to convert domain to punycode: %s: %v", d, err)
}

if !domainRegex.MatchString(string(punycode)) {
return domainList, fmt.Errorf("invalid domain format: %s", d)
}

domainList = append(domainList, punycode)
}
return domainList, nil
}
97 changes: 97 additions & 0 deletions management/domain/validate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package domain

import (
"testing"

"github.com/magiconair/properties/assert"
)

func TestValidateDomains(t *testing.T) {
tests := []struct {
name string
domains []string
expected List
wantErr bool
}{
{
name: "Empty list",
domains: nil,
expected: nil,
wantErr: true,
},
{
name: "Valid ASCII domain",
domains: []string{"sub.ex-ample.com"},
expected: List{"sub.ex-ample.com"},
wantErr: false,
},
{
name: "Valid Unicode domain",
domains: []string{"münchen.de"},
expected: List{"xn--mnchen-3ya.de"},
wantErr: false,
},
{
name: "Valid Unicode, all labels",
domains: []string{"中国.中国.中国"},
expected: List{"xn--fiqs8s.xn--fiqs8s.xn--fiqs8s"},
wantErr: false,
},
{
name: "With underscores",
domains: []string{"_jabber._tcp.gmail.com"},
expected: List{"_jabber._tcp.gmail.com"},
wantErr: false,
},
{
name: "Invalid domain format",
domains: []string{"-example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid domain format 2",
domains: []string{"example.com-"},
expected: nil,
wantErr: true,
},
{
name: "Multiple domains valid and invalid",
domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"},
expected: List{"google.com"},
wantErr: true,
},
{
name: "Valid wildcard domain",
domains: []string{"*.example.com"},
expected: List{"*.example.com"},
wantErr: false,
},
{
name: "Wildcard with dot domain",
domains: []string{".*.example.com"},
expected: nil,
wantErr: true,
},
{
name: "Wildcard with dot domain",
domains: []string{".*.example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid wildcard domain",
domains: []string{"a.*.example.com"},
expected: nil,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ValidateDomains(tt.domains)
assert.Equal(t, tt.wantErr, err != nil)
assert.Equal(t, got, tt.expected)
})
}
}
39 changes: 2 additions & 37 deletions management/server/http/handlers/routes/routes_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ package routes

import (
"encoding/json"
"fmt"
"net/http"
"net/netip"
"regexp"
"strings"
"unicode/utf8"

"github.com/gorilla/mux"
Expand All @@ -21,7 +18,6 @@ import (
"github.com/netbirdio/netbird/route"
)

const maxDomains = 32
const failedToConvertRoute = "failed to convert route to response: %v"

// handler is the routes handler of the account
Expand Down Expand Up @@ -102,7 +98,7 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
var networkType route.NetworkType
var newPrefix netip.Prefix
if req.Domains != nil {
d, err := validateDomains(*req.Domains)
d, err := domain.ValidateDomains(*req.Domains)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return
Expand Down Expand Up @@ -225,7 +221,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
}

if req.Domains != nil {
d, err := validateDomains(*req.Domains)
d, err := domain.ValidateDomains(*req.Domains)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return
Expand Down Expand Up @@ -350,34 +346,3 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {
}
return route, nil
}

// validateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
func validateDomains(domains []string) (domain.List, error) {
if len(domains) == 0 {
return nil, fmt.Errorf("domains list is empty")
}
if len(domains) > maxDomains {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
}

domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)

var domainList domain.List

for _, d := range domains {
d := strings.ToLower(d)

// handles length and idna conversion
punycode, err := domain.FromString(d)
if err != nil {
return domainList, fmt.Errorf("failed to convert domain to punycode: %s: %v", d, err)
}

if !domainRegex.MatchString(string(punycode)) {
return domainList, fmt.Errorf("invalid domain format: %s", d)
}

domainList = append(domainList, punycode)
}
return domainList, nil
}
Loading

0 comments on commit 89b1e5a

Please sign in to comment.