diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e3d06dd..ef1e908 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,3 +29,5 @@ jobs: go mod download go mod verify make test + env: + CI: true diff --git a/Makefile b/Makefile index f9928f5..a184c02 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,10 @@ # Variables ## General Variables + +# CI +CI ?= false + # Branch Variables PROTECTED_BRANCH := master CURRENT_BRANCH := $(shell git rev-parse --abbrev-ref HEAD) @@ -290,6 +294,7 @@ go-test: ## to run tests $(AT)$(DOCKER) run ${DOCKER_OPTS} \ -v $(PWD):/app -w /app \ -e GOCACHE="/tmp" \ + -e CI=${CI} \ $(DOCKER_IMAGE_GO) \ /bin/sh -c \ "cd /app && \ diff --git a/config/config.sample.toml b/config/config.sample.toml index 900da7e..c736b36 100644 --- a/config/config.sample.toml +++ b/config/config.sample.toml @@ -39,6 +39,10 @@ ice_address_tcp = "" # The TCP port used to route media (audio/screen/video tracks). This is used to # generate TCP candidates. ice_port_tcp = 8443 +# Enables experimental IPv6 support. When this setting is true the RTC service +# will work in dual-stack mode, listening for IPv6 connections and generating +# candidates in addition to IPv4 ones. +enable_ipv6 = false # An optional hostname used to override the default value. By default, the # service will try to guess its own public IP through STUN (if configured). # diff --git a/service/rtc/config.go b/service/rtc/config.go index 4ebb6f7..ca3be09 100644 --- a/service/rtc/config.go +++ b/service/rtc/config.go @@ -25,6 +25,8 @@ type ServerConfig struct { // A list of ICE server (STUN/TURN) configurations to use. ICEServers ICEServers `toml:"ice_servers"` TURNConfig TURNConfig `toml:"turn"` + // EnableIPv6 specifies whether or not IPv6 should be used. + EnableIPv6 bool `toml:"enable_ipv6"` } func (c ServerConfig) IsValid() error { diff --git a/service/rtc/net.go b/service/rtc/net.go index 1b41d6c..0a7c110 100644 --- a/service/rtc/net.go +++ b/service/rtc/net.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "net" + "net/netip" "runtime" "syscall" "time" @@ -22,9 +23,9 @@ const ( tcpSocketWriteBufferSize = 1024 * 1024 * 4 // 4MB ) -// getSystemIPs returns a list of all the available IPv4 addresses. -func getSystemIPs(log mlog.LoggerIFace) ([]string, error) { - var ips []string +// getSystemIPs returns a list of all the available local addresses. +func getSystemIPs(log mlog.LoggerIFace, dualStack bool) ([]netip.Addr, error) { + var ips []netip.Addr interfaces, err := net.Interfaces() if err != nil { @@ -34,36 +35,43 @@ func getSystemIPs(log mlog.LoggerIFace) ([]string, error) { for _, iface := range interfaces { // filter out inactive interfaces if iface.Flags&net.FlagUp == 0 { - log.Debug("skipping inactive interface", mlog.String("interface", iface.Name)) + log.Info("skipping inactive interface", mlog.String("interface", iface.Name)) continue } addrs, err := iface.Addrs() if err != nil { - log.Debug("failed to get addresses for interface", mlog.String("interface", iface.Name)) + log.Warn("failed to get addresses for interface", mlog.String("interface", iface.Name)) continue } for _, addr := range addrs { - ip, _, err := net.ParseCIDR(addr.String()) + prefix, err := netip.ParsePrefix(addr.String()) if err != nil { - log.Debug("failed to parse address", mlog.Err(err), mlog.String("addr", addr.String())) + log.Warn("failed to parse prefix", mlog.Err(err), mlog.String("prefix", prefix.String())) continue } - // IPv4 only (for the time being at least, see MM-50294) - if ip.To4() == nil { + ip := prefix.Addr() + + if !dualStack && ip.Is6() { + log.Debug("ignoring IPv6 address: dual stack support is disabled by config", mlog.String("addr", ip.String())) + continue + } + + if ip.Is6() && !ip.IsGlobalUnicast() { + log.Debug("ignoring non global IPv6 address", mlog.String("addr", ip.String())) continue } - ips = append(ips, ip.String()) + ips = append(ips, ip) } } return ips, nil } -func createUDPConnsForAddr(log mlog.LoggerIFace, listenAddress string) ([]net.PacketConn, error) { +func createUDPConnsForAddr(log mlog.LoggerIFace, network, listenAddress string) ([]net.PacketConn, error) { var conns []net.PacketConn for i := 0; i < runtime.NumCPU(); i++ { @@ -84,7 +92,7 @@ func createUDPConnsForAddr(log mlog.LoggerIFace, listenAddress string) ([]net.Pa }, } - udpConn, err := listenConfig.ListenPacket(context.Background(), "udp4", listenAddress) + udpConn, err := listenConfig.ListenPacket(context.Background(), network, listenAddress) if err != nil { return nil, fmt.Errorf("failed to listen on udp: %w", err) } @@ -132,12 +140,12 @@ func createUDPConnsForAddr(log mlog.LoggerIFace, listenAddress string) ([]net.Pa return conns, nil } -func resolveHost(host string, timeout time.Duration) (string, error) { +func resolveHost(host, network string, timeout time.Duration) (string, error) { var ip string r := net.Resolver{} ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - addrs, err := r.LookupIP(ctx, "ip4", host) + addrs, err := r.LookupIP(ctx, network, host) if err != nil { return ip, fmt.Errorf("failed to resolve host %q: %w", host, err) } @@ -146,3 +154,7 @@ func resolveHost(host string, timeout time.Duration) (string, error) { } return ip, err } + +func areAddressesSameStack(addrA, addrB netip.Addr) bool { + return (addrA.Is4() && addrB.Is4()) || (addrA.Is6() && addrB.Is6()) +} diff --git a/service/rtc/net_test.go b/service/rtc/net_test.go index 3ce3ff3..8ec1d98 100644 --- a/service/rtc/net_test.go +++ b/service/rtc/net_test.go @@ -4,6 +4,8 @@ package rtc import ( + "net/netip" + "os" "runtime" "testing" @@ -20,9 +22,40 @@ func TestGetSystemIPs(t *testing.T) { require.NoError(t, err) }() - ips, err := getSystemIPs(log) - require.NoError(t, err) - require.NotEmpty(t, ips) + t.Run("ipv4", func(t *testing.T) { + ips, err := getSystemIPs(log, false) + require.NoError(t, err) + require.NotEmpty(t, ips) + + for _, ip := range ips { + require.True(t, ip.Is4()) + } + }) + + t.Run("dual stack", func(t *testing.T) { + // Skipping this test in CI since IPv6 is not yet supported by Github actions. + if os.Getenv("CI") != "" { + t.Skip() + } + + ips, err := getSystemIPs(log, true) + require.NoError(t, err) + require.NotEmpty(t, ips) + + var hasIPv4 bool + var hasIPv6 bool + for _, ip := range ips { + if ip.Is4() { + hasIPv4 = true + } + if ip.Is6() { + hasIPv6 = true + } + } + + require.True(t, hasIPv4) + require.True(t, hasIPv6) + }) } func TestCreateUDPConnsForAddr(t *testing.T) { @@ -33,16 +66,38 @@ func TestCreateUDPConnsForAddr(t *testing.T) { require.NoError(t, err) }() - ips, err := getSystemIPs(log) - require.NoError(t, err) - require.NotEmpty(t, ips) + t.Run("IPv4", func(t *testing.T) { + ips, err := getSystemIPs(log, false) + require.NoError(t, err) + require.NotEmpty(t, ips) - for _, ip := range ips { - conns, err := createUDPConnsForAddr(log, ip+":30443") + for _, ip := range ips { + conns, err := createUDPConnsForAddr(log, "udp4", netip.AddrPortFrom(ip, 30443).String()) + require.NoError(t, err) + require.Len(t, conns, runtime.NumCPU()) + for _, conn := range conns { + require.NoError(t, conn.Close()) + } + } + }) + + t.Run("dual stack", func(t *testing.T) { + // Skipping this test in CI since IPv6 is not yet supported by Github actions. + if os.Getenv("CI") != "" { + t.Skip() + } + + ips, err := getSystemIPs(log, false) require.NoError(t, err) - require.Len(t, conns, runtime.NumCPU()) - for _, conn := range conns { - require.NoError(t, conn.Close()) + require.NotEmpty(t, ips) + + for _, ip := range ips { + conns, err := createUDPConnsForAddr(log, "udp", netip.AddrPortFrom(ip, 30443).String()) + require.NoError(t, err) + require.Len(t, conns, runtime.NumCPU()) + for _, conn := range conns { + require.NoError(t, conn.Close()) + } } - } + }) } diff --git a/service/rtc/server.go b/service/rtc/server.go index 4abd89c..d297727 100644 --- a/service/rtc/server.go +++ b/service/rtc/server.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "net" + "net/netip" "sync" "time" @@ -32,8 +33,8 @@ type Server struct { udpMux ice.UDPMux tcpMux ice.TCPMux - publicAddrsMap map[string]string - localIPs []string + publicAddrsMap map[netip.Addr]string + localIPs []netip.Addr sendCh chan Message receiveCh chan Message @@ -63,7 +64,7 @@ func NewServer(cfg ServerConfig, log mlog.LoggerIFace, metrics Metrics) (*Server sendCh: make(chan Message, msgChSize), receiveCh: make(chan Message, msgChSize), bufPool: &sync.Pool{New: func() interface{} { return make([]byte, receiveMTU) }}, - publicAddrsMap: make(map[string]string), + publicAddrsMap: make(map[netip.Addr]string), } return s, nil @@ -83,9 +84,16 @@ func (s *Server) ReceiveCh() <-chan Message { } func (s *Server) Start() error { - var err error + udpNetwork := "udp4" + tcpNetwork := "tcp4" - localIPs, err := getSystemIPs(s.log) + if s.cfg.EnableIPv6 { + s.log.Info("rtc: experimental IPv6 support enabled") + udpNetwork = "udp" + tcpNetwork = "tcp" + } + + localIPs, err := getSystemIPs(s.log, s.cfg.EnableIPv6) if err != nil { return fmt.Errorf("failed to get system IPs: %w", err) } @@ -95,11 +103,13 @@ func (s *Server) Start() error { s.localIPs = localIPs + s.log.Debug("rtc: found local IPs", mlog.Any("ips", s.localIPs)) + // Populate public IP addresses map if override is not set and STUN is provided. if s.cfg.ICEHostOverride == "" && len(s.cfg.ICEServers) > 0 { for _, ip := range localIPs { - udpListenAddr := fmt.Sprintf("%s:%d", ip, s.cfg.ICEPortUDP) - udpAddr, err := net.ResolveUDPAddr("udp4", udpListenAddr) + udpListenAddr := netip.AddrPortFrom(ip, uint16(s.cfg.ICEPortUDP)).String() + udpAddr, err := net.ResolveUDPAddr(udpNetwork, udpListenAddr) if err != nil { s.log.Error("failed to resolve UDP address", mlog.Err(err)) continue @@ -107,22 +117,22 @@ func (s *Server) Start() error { // TODO: consider making this logic concurrent to lower total time taken // in case of multiple interfaces. - addr, err := getPublicIP(udpAddr, s.cfg.ICEServers.getSTUN()) + addr, err := getPublicIP(udpAddr, udpNetwork, s.cfg.ICEServers.getSTUN()) if err != nil { - s.log.Warn("failed to get public IP address for local interface", mlog.String("localAddr", ip), mlog.Err(err)) + s.log.Warn("failed to get public IP address for local interface", mlog.String("localAddr", ip.String()), mlog.Err(err)) } else { - s.log.Info("got public IP address for local interface", mlog.String("localAddr", ip), mlog.String("remoteAddr", addr)) + s.log.Info("got public IP address for local interface", mlog.String("localAddr", ip.String()), mlog.String("remoteAddr", addr)) } s.publicAddrsMap[ip] = addr } } - if err := s.initUDP(localIPs); err != nil { + if err := s.initUDP(localIPs, udpNetwork); err != nil { return err } - if err := s.initTCP(); err != nil { + if err := s.initTCP(tcpNetwork); err != nil { return err } @@ -291,11 +301,11 @@ func (s *Server) msgReader() { } } -func (s *Server) initUDP(localIPs []string) error { +func (s *Server) initUDP(localIPs []netip.Addr, network string) error { var udpMuxes []ice.UDPMux initUDPMux := func(addr string) error { - conns, err := createUDPConnsForAddr(s.log, addr) + conns, err := createUDPConnsForAddr(s.log, network, addr) if err != nil { return fmt.Errorf("failed to create UDP connections: %w", err) } @@ -315,7 +325,7 @@ func (s *Server) initUDP(localIPs []string) error { // If an address is specified we create a single udp mux. if s.cfg.ICEAddressUDP != "" { - if err := initUDPMux(fmt.Sprintf("%s:%d", s.cfg.ICEAddressUDP, s.cfg.ICEPortUDP)); err != nil { + if err := initUDPMux(net.JoinHostPort(s.cfg.ICEAddressUDP, fmt.Sprintf("%d", s.cfg.ICEPortUDP))); err != nil { return err } s.udpMux = udpMuxes[0] @@ -324,7 +334,7 @@ func (s *Server) initUDP(localIPs []string) error { // If no address is specified we create a mux for each interface we find. for _, ip := range localIPs { - if err := initUDPMux(fmt.Sprintf("%s:%d", ip, s.cfg.ICEPortUDP)); err != nil { + if err := initUDPMux(netip.AddrPortFrom(ip, uint16(s.cfg.ICEPortUDP)).String()); err != nil { return err } } @@ -334,8 +344,8 @@ func (s *Server) initUDP(localIPs []string) error { return nil } -func (s *Server) initTCP() error { - tcpListener, err := net.Listen("tcp4", fmt.Sprintf("%s:%d", s.cfg.ICEAddressTCP, s.cfg.ICEPortTCP)) +func (s *Server) initTCP(network string) error { + tcpListener, err := net.Listen(network, net.JoinHostPort(s.cfg.ICEAddressTCP, fmt.Sprintf("%d", s.cfg.ICEPortTCP))) if err != nil { return fmt.Errorf("failed to create TCP listener: %w", err) } diff --git a/service/rtc/server_test.go b/service/rtc/server_test.go index f621cdd..fc1d64b 100644 --- a/service/rtc/server_test.go +++ b/service/rtc/server_test.go @@ -121,7 +121,7 @@ func TestStartServer(t *testing.T) { require.NoError(t, err) defer udpConn.Close() - ips, err := getSystemIPs(log) + ips, err := getSystemIPs(log, false) require.NoError(t, err) require.NotEmpty(t, ips) diff --git a/service/rtc/sfu.go b/service/rtc/sfu.go index 29af569..010cea0 100644 --- a/service/rtc/sfu.go +++ b/service/rtc/sfu.go @@ -187,15 +187,19 @@ func (s *Server) InitSession(cfg SessionConfig, closeCb func() error) error { sEngine := webrtc.SettingEngine{} sEngine.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) - sEngine.SetNetworkTypes([]webrtc.NetworkType{ + networkTypes := []webrtc.NetworkType{ webrtc.NetworkTypeUDP4, webrtc.NetworkTypeTCP4, - }) + } + if s.cfg.EnableIPv6 { + networkTypes = append(networkTypes, webrtc.NetworkTypeUDP6, webrtc.NetworkTypeTCP6) + } + sEngine.SetNetworkTypes(networkTypes) sEngine.SetICEUDPMux(s.udpMux) sEngine.SetICETCPMux(s.tcpMux) sEngine.SetIncludeLoopbackCandidate(true) - pairs, err := generateAddrsPairs(s.localIPs, s.publicAddrsMap, s.cfg.ICEHostOverride) + pairs, err := generateAddrsPairs(s.localIPs, s.publicAddrsMap, s.cfg.ICEHostOverride, s.cfg.EnableIPv6) if err != nil { return fmt.Errorf("failed to generate addresses pairs: %w", err) } else if len(pairs) > 0 { diff --git a/service/rtc/stun.go b/service/rtc/stun.go index ffc0558..635f9a3 100644 --- a/service/rtc/stun.go +++ b/service/rtc/stun.go @@ -12,19 +12,19 @@ import ( "github.com/pion/stun" ) -func getPublicIP(addr *net.UDPAddr, stunURL string) (string, error) { +func getPublicIP(addr *net.UDPAddr, network, stunURL string) (string, error) { if stunURL == "" { return "", fmt.Errorf("no STUN server URL was provided") } - conn, err := net.ListenUDP("udp4", addr) + conn, err := net.ListenUDP(network, addr) if err != nil { return "", err } defer conn.Close() serverURL := stunURL[strings.Index(stunURL, ":")+1:] - serverAddr, err := net.ResolveUDPAddr("udp", serverURL) + serverAddr, err := net.ResolveUDPAddr(network, serverURL) if err != nil { return "", fmt.Errorf("failed to resolve stun host: %w", err) } diff --git a/service/rtc/utils.go b/service/rtc/utils.go index 87eaa9d..e868555 100644 --- a/service/rtc/utils.go +++ b/service/rtc/utils.go @@ -5,6 +5,7 @@ package rtc import ( "fmt" + "net/netip" "strings" "time" @@ -29,7 +30,7 @@ func getTrackType(kind webrtc.RTPCodecType) string { return "unknown" } -func generateAddrsPairs(localIPs []string, publicAddrsMap map[string]string, hostOverride string) ([]string, error) { +func generateAddrsPairs(localIPs []netip.Addr, publicAddrsMap map[netip.Addr]string, hostOverride string, dualStack bool) ([]string, error) { var err error var pairs []string var hostOverrideIP string @@ -40,9 +41,14 @@ func generateAddrsPairs(localIPs []string, publicAddrsMap map[string]string, hos return strings.Split(hostOverride, ","), nil } + ipNetwork := "ip4" + if dualStack { + ipNetwork = "ip" + } + // If the override is set we resolve it in case it's a hostname. if hostOverride != "" { - hostOverrideIP, err = resolveHost(hostOverride, time.Second) + hostOverrideIP, err = resolveHost(hostOverride, ipNetwork, time.Second) if err != nil { return pairs, fmt.Errorf("failed to resolve host: %w", err) } @@ -56,21 +62,25 @@ func generateAddrsPairs(localIPs []string, publicAddrsMap map[string]string, hos // If the override is set but no explicit mapping is given, we try to // generate one. if hostOverrideIP != "" { + hostOverrideAddr, err := netip.ParseAddr(hostOverrideIP) + if err != nil { + return nil, fmt.Errorf("failed to parse hostOverrideIP: %w", err) + } + // If only one local interface is found, we map that to the given public ip // override. - if len(localIPs) == 1 { + if len(localIPs) == 1 && areAddressesSameStack(hostOverrideAddr, localIPs[0]) { return []string{ - fmt.Sprintf("%s/%s", hostOverrideIP, localIPs[0]), + fmt.Sprintf("%s/%s", hostOverrideAddr.String(), localIPs[0].String()), }, nil } // Otherwise we map the override to any non-loopback IP. for _, localAddr := range localIPs { - // TODO: consider a better check to figure out if it's loopback. - if localAddr == "127.0.0.1" { - pairs = append(pairs, fmt.Sprintf("%s/%s", localAddr, localAddr)) - } else { - pairs = append(pairs, fmt.Sprintf("%s/%s", hostOverrideIP, localAddr)) + if localAddr.IsLoopback() { + pairs = append(pairs, fmt.Sprintf("%s/%s", localAddr.String(), localAddr.String())) + } else if areAddressesSameStack(hostOverrideAddr, localAddr) { + pairs = append(pairs, fmt.Sprintf("%s/%s", hostOverrideAddr.String(), localAddr.String())) } } @@ -87,9 +97,9 @@ func generateAddrsPairs(localIPs []string, publicAddrsMap map[string]string, hos for _, localAddr := range localIPs { publicAddr := publicAddrsMap[localAddr] if publicAddr == "" { - publicAddr = localAddr + publicAddr = localAddr.String() } - pairs = append(pairs, fmt.Sprintf("%s/%s", publicAddr, localAddr)) + pairs = append(pairs, fmt.Sprintf("%s/%s", publicAddr, localAddr.String())) } return pairs, nil diff --git a/service/rtc/utils_test.go b/service/rtc/utils_test.go index 8757cdf..d4fe045 100644 --- a/service/rtc/utils_test.go +++ b/service/rtc/utils_test.go @@ -4,6 +4,7 @@ package rtc import ( + "net/netip" "testing" "github.com/stretchr/testify/require" @@ -11,62 +12,80 @@ import ( func TestGenerateAddrsPairs(t *testing.T) { t.Run("nil/empty inputs", func(t *testing.T) { - pairs, err := generateAddrsPairs(nil, nil, "") + pairs, err := generateAddrsPairs(nil, nil, "", false) require.NoError(t, err) require.Empty(t, pairs) - pairs, err = generateAddrsPairs([]string{}, map[string]string{}, "") + pairs, err = generateAddrsPairs([]netip.Addr{}, map[netip.Addr]string{}, "", false) require.NoError(t, err) require.Empty(t, pairs) }) t.Run("no public addresses", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "", - "10.1.1.1": "", - }, "") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "", + netip.MustParseAddr("10.1.1.1"): "", + }, "", false) require.NoError(t, err) require.Equal(t, []string{"127.0.0.1/127.0.0.1", "10.1.1.1/10.1.1.1"}, pairs) }) t.Run("full NAT mapping", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{}, "1.1.1.1/127.0.0.1,1.1.1.1/10.1.1.1") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{}, "1.1.1.1/127.0.0.1,1.1.1.1/10.1.1.1", false) require.NoError(t, err) require.Equal(t, []string{"1.1.1.1/127.0.0.1", "1.1.1.1/10.1.1.1"}, pairs) }) t.Run("no public addresses with override", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "", - "10.1.1.1": "", - }, "1.1.1.1") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "", + netip.MustParseAddr("10.1.1.1"): "", + }, "1.1.1.1", false) require.NoError(t, err) require.Equal(t, []string{"127.0.0.1/127.0.0.1", "1.1.1.1/10.1.1.1"}, pairs) }) t.Run("single public address for multiple local addrs, no override", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "", - "10.1.1.1": "1.1.1.1", - }, "") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "", + netip.MustParseAddr("10.1.1.1"): "1.1.1.1", + }, "", false) require.NoError(t, err) require.Equal(t, []string{"127.0.0.1/127.0.0.1", "1.1.1.1/10.1.1.1"}, pairs) }) t.Run("single local/public address map, no override", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "", - "10.1.1.1": "1.1.1.1", - }, "") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "", + netip.MustParseAddr("10.1.1.1"): "1.1.1.1", + }, "", false) require.NoError(t, err) require.Equal(t, []string{"127.0.0.1/127.0.0.1", "1.1.1.1/10.1.1.1"}, pairs) }) t.Run("multiple public addresses, no override", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "1.1.1.1", - "10.1.1.1": "1.1.1.2", - }, "") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "1.1.1.1", + netip.MustParseAddr("10.1.1.1"): "1.1.1.2", + }, "", false) require.NoError(t, err) require.Equal(t, []string{"1.1.1.1/127.0.0.1", "1.1.1.2/10.1.1.1"}, pairs) }) @@ -74,10 +93,13 @@ func TestGenerateAddrsPairs(t *testing.T) { // This is not a case that would happen in the application because the // override would prevent us from finding public IPs. t.Run("multiple public addresses, with override", func(t *testing.T) { - pairs, err := generateAddrsPairs([]string{"127.0.0.1", "10.1.1.1"}, map[string]string{ - "127.0.0.1": "1.1.1.1", - "10.1.1.1": "1.1.1.2", - }, "8.8.8.8") + pairs, err := generateAddrsPairs([]netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("10.1.1.1"), + }, map[netip.Addr]string{ + netip.MustParseAddr("127.0.0.1"): "1.1.1.1", + netip.MustParseAddr("10.1.1.1"): "1.1.1.2", + }, "8.8.8.8", false) require.NoError(t, err) require.Equal(t, []string{"127.0.0.1/127.0.0.1", "8.8.8.8/10.1.1.1"}, pairs) }) diff --git a/service/service.go b/service/service.go index 70584a6..022e195 100644 --- a/service/service.go +++ b/service/service.go @@ -122,6 +122,8 @@ func New(cfg Config) (*Service, error) { } func (s *Service) Start() error { + defer s.log.Flush() + if err := s.apiServer.Start(); err != nil { return fmt.Errorf("failed to start api server: %w", err) } @@ -185,6 +187,7 @@ func (s *Service) Start() error { } func (s *Service) Stop() error { + defer s.log.Flush() s.log.Info("rtcd: shutting down") if err := s.rtcServer.Stop(); err != nil {