diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index da028761..68867481 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -22,7 +22,6 @@ services: volumes: - .:/usr/local/bin/wag:ro - ./docker-test-config.json:/opt/config.json - - ./devices.db:/opt/devices.db networks: custom_network: ipv4_address: 172.20.0.3 diff --git a/docker-test-config.json b/docker-test-config.json index b7a840ca..d52b8755 100644 --- a/docker-test-config.json +++ b/docker-test-config.json @@ -57,7 +57,6 @@ ], "TLSManagerListenURL": "https://container2:3434" }, - "DatabaseLocation": "devices.db", "Acls": { "Groups": { "group:administrators": [ @@ -114,4 +113,4 @@ } } } -} \ No newline at end of file +} diff --git a/internal/config/testing_config.json b/internal/config/testing_config.json index 65be4ea0..d106c863 100644 --- a/internal/config/testing_config.json +++ b/internal/config/testing_config.json @@ -24,7 +24,7 @@ }, "Wireguard": { "DevName": "loopbackTun1", - "ListenPort": 53230, + "ListenPort": 53231, "PrivateKey": "cFYv9YROACD78hFBxQ29mkXol974NMLMt4hFOe+oXl4=", "Address": "192.168.1.1/24", "MTU": 1420 diff --git a/internal/config/testing_config2.json b/internal/config/testing_config2.json index 4a413106..7a6d9c0a 100644 --- a/internal/config/testing_config2.json +++ b/internal/config/testing_config2.json @@ -27,7 +27,7 @@ }, "Wireguard": { "DevName": "loopbackTun1", - "ListenPort": 53230, + "ListenPort": 53232, "PrivateKey": "cFYv9YROACD78hFBxQ29mkXol974NMLMt4hFOe+oXl4=", "Address": "192.168.1.1/24", "MTU": 1420 diff --git a/internal/data/acls.go b/internal/data/acls.go index 41916c4d..18eac3c4 100644 --- a/internal/data/acls.go +++ b/internal/data/acls.go @@ -88,7 +88,7 @@ func insertMap(m map[string]bool, values ...string) { func hostIPWithMask(ip net.IP) string { mask := "/32" - if ip.To16() == nil { + if ip.To4() == nil && ip.To16() != nil { mask = "/128" } diff --git a/internal/data/acls_test.go b/internal/data/acls_test.go new file mode 100644 index 00000000..7d170257 --- /dev/null +++ b/internal/data/acls_test.go @@ -0,0 +1,52 @@ +package data + +import ( + "net" + "testing" +) + +func TestHostIPWithMask(t *testing.T) { + tests := []struct { + name string + ip net.IP + expected string + }{ + { + name: "IPv4 address", + ip: net.ParseIP("192.168.1.1"), + expected: "192.168.1.1/32", + }, + { + name: "IPv6 address", + ip: net.ParseIP("2001:db8::1"), + expected: "2001:db8::1/128", + }, + { + name: "IPv4 loopback", + ip: net.ParseIP("127.0.0.1"), + expected: "127.0.0.1/32", + }, + { + name: "IPv6 loopback", + ip: net.ParseIP("::1"), + expected: "::1/128", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := hostIPWithMask(tt.ip); got != tt.expected { + t.Errorf("hostIPWithMask() = %v, want %v", got, tt.expected) + } + }) + } + + // Test with nil IP + t.Run("nil IP", func(t *testing.T) { + var nilIP net.IP + got := hostIPWithMask(nilIP) + if got != "/32" { + t.Errorf("hostIPWithMask() with nil IP = %v, want /32", got) + } + }) +} diff --git a/internal/router/firewall_test.go b/internal/router/firewall_test.go index 4fdf7cad..2323e100 100644 --- a/internal/router/firewall_test.go +++ b/internal/router/firewall_test.go @@ -16,6 +16,7 @@ import ( "github.com/NHAS/wag/internal/data" "github.com/NHAS/wag/internal/routetypes" "golang.org/x/net/ipv4" + "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/tuntest" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -52,6 +53,33 @@ var ( mockTun *tuntest.ChannelTUN ) +func TestSetupRealWireguardDevice(t *testing.T) { + + const dummyIPv4Device = "dev-ipv4" + tdev4, err := tun.CreateTUN(dummyIPv4Device, 1500) + if err != nil { + t.Fatal(err) + } + defer tdev4.Close() + + const dummyIPv6Device = "dev-ipv6" + tdev6, err := tun.CreateTUN(dummyIPv6Device, 1500) + if err != nil { + t.Fatal(err) + } + defer tdev6.Close() + + err = testFw.bringUpInterface(dummyIPv4Device, "192.168.0.1/24") + if err != nil { + t.Fatal(err) + } + + err = testFw.bringUpInterface(dummyIPv6Device, "2001:db8::1/6") + if err != nil { + t.Fatal(err) + } +} + func TestBlankPacket(t *testing.T) { buff := make([]byte, 15) diff --git a/internal/router/init.go b/internal/router/init.go index 417f897e..238ad45b 100644 --- a/internal/router/init.go +++ b/internal/router/init.go @@ -59,7 +59,7 @@ func newFw(testing, iptables bool, testDev tun.Device) (*Firewall, error) { if testing { err = fw.setupWireguardDebug(testDev) } else { - err = fw.setupWireguard() + err = fw.setupWireguard(config.Values.Wireguard.Address, config.Values.Wireguard.DevName, config.Values.Wireguard.MTU) } if err != nil { diff --git a/internal/router/wireguard.go b/internal/router/wireguard.go index dd4322bc..9160c164 100644 --- a/internal/router/wireguard.go +++ b/internal/router/wireguard.go @@ -209,43 +209,45 @@ func (f *Firewall) endpointChange(e device.Event) { } } -func (f *Firewall) setupWireguard() error { - // open TUN device +func (f *Firewall) bringUpInterface(devName, network string) error { - tdev, err := tun.CreateTUN(config.Values.Wireguard.DevName, config.Values.Wireguard.MTU) + conn, err := netlink.Dial(unix.NETLINK_ROUTE, nil) if err != nil { - return fmt.Errorf("failed to create TUN device: path: %q mtu: %d, err %v", config.Values.Wireguard.DevName, config.Values.Wireguard.MTU, err) + return err } + defer conn.Close() - err = f.openWireguard(tdev) + ip, ipNet, err := net.ParseCIDR(network) if err != nil { - return fmt.Errorf("failed to open wireguard device: %s", err) + return err } - err = func(network string) error { + ipNet.IP = ip - conn, err := netlink.Dial(unix.NETLINK_ROUTE, nil) - if err != nil { - return err - } - defer conn.Close() + err = f.setIp(conn, devName, *ipNet) + if err != nil { + return err + } - ip, ipNet, err := net.ParseCIDR(network) - if err != nil { - return err - } + return f.setUp(conn, devName) +} - ipNet.IP = ip +func (f *Firewall) setupWireguard(devName, address string, MTU int) error { + // open TUN device - err = f.setIp(conn, config.Values.Wireguard.DevName, *ipNet) - if err != nil { - return err - } + tdev, err := tun.CreateTUN(devName, MTU) + if err != nil { + return fmt.Errorf("failed to create TUN device: path: %q mtu: %d, err %v", devName, MTU, err) + } + + err = f.openWireguard(tdev) + if err != nil { + return fmt.Errorf("failed to open wireguard device: %s", err) + } - return f.setUp(conn, config.Values.Wireguard.DevName) - }(config.Values.Wireguard.Address) + err = f.bringUpInterface(devName, address) if err != nil { - return fmt.Errorf("unable to set wireguard tunnel ip: %s", err) + return fmt.Errorf("failed to set ip and bring wireguard tun device up: %s", err) } return err @@ -480,10 +482,10 @@ func (f *Firewall) setIp(c *netlink.Conn, name string, address net.IPNet) error preflen, _ := address.Mask.Size() - if address.IP.To4() == nil { + if address.IP.To4() != nil { IP = address.IP.To4() addrMsg.Family = unix.AF_INET - } else if address.IP.To16() == nil { + } else if address.IP.To16() != nil { IP = address.IP.To16() addrMsg.Family = unix.AF_INET6 } else {