diff --git a/caddy/shadowsocks_handler.go b/caddy/shadowsocks_handler.go index a0c48747..51f71057 100644 --- a/caddy/shadowsocks_handler.go +++ b/caddy/shadowsocks_handler.go @@ -19,16 +19,21 @@ import ( "fmt" "log/slog" "net" + "time" "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" - outline "github.com/Jigsaw-Code/outline-ss-server/service" "github.com/caddyserver/caddy/v2" "github.com/mholt/caddy-l4/layer4" + + outline "github.com/Jigsaw-Code/outline-ss-server/service" ) const ssModuleName = "layer4.handlers.shadowsocks" +// A UDP NAT timeout of at least 5 minutes is recommended in RFC 4787 Section 4.3. +const defaultNatTimeout time.Duration = 5 * time.Minute + func init() { caddy.RegisterModule(ModuleRegistration{ ID: ssModuleName, @@ -45,8 +50,10 @@ type KeyConfig struct { type ShadowsocksHandler struct { Keys []KeyConfig `json:"keys,omitempty"` - service outline.Service - logger *slog.Logger + streamHandler outline.StreamHandler + associationHandler outline.AssociationHandler + metrics outline.ServiceMetrics + logger *slog.Logger } var ( @@ -70,6 +77,7 @@ func (h *ShadowsocksHandler) Provision(ctx caddy.Context) error { if !ok { return fmt.Errorf("module `%s` is of type `%T`, expected `OutlineApp`", outlineModuleName, app) } + h.metrics = app.Metrics if len(h.Keys) == 0 { h.logger.Warn("no keys configured") @@ -97,16 +105,12 @@ func (h *ShadowsocksHandler) Provision(ctx caddy.Context) error { ciphers := outline.NewCipherList() ciphers.Update(cipherList) - service, err := outline.NewShadowsocksService( + h.streamHandler, h.associationHandler = outline.NewShadowsocksHandlers( outline.WithLogger(h.logger), outline.WithCiphers(ciphers), - outline.WithMetrics(app.Metrics), + outline.WithMetrics(h.metrics), outline.WithReplayCache(&app.ReplayCache), ) - if err != nil { - return err - } - h.service = service return nil } @@ -114,9 +118,9 @@ func (h *ShadowsocksHandler) Provision(ctx caddy.Context) error { func (h *ShadowsocksHandler) Handle(cx *layer4.Connection, _ layer4.Handler) error { switch conn := cx.Conn.(type) { case transport.StreamConn: - h.service.HandleStream(cx.Context, conn) - case net.PacketConn: - h.service.HandlePacket(conn) + h.streamHandler.HandleStream(cx.Context, conn, h.metrics.AddOpenTCPConnection(conn)) + case net.Conn: + h.associationHandler.HandleAssociation(cx.Context, conn, h.metrics.AddOpenUDPAssociation(conn)) default: return fmt.Errorf("failed to handle unknown connection type: %t", conn) } diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 28a32425..e8203d42 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -16,6 +16,7 @@ package main import ( "container/list" + "context" "flag" "fmt" "log/slog" @@ -27,6 +28,7 @@ import ( "syscall" "time" + "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/lmittmann/tint" "github.com/prometheus/client_golang/prometheus" @@ -223,11 +225,11 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { ciphers := service.NewCipherList() ciphers.Update(cipherList) - ssService, err := service.NewShadowsocksService( + streamHandler, associationHandler := service.NewShadowsocksHandlers( service.WithCiphers(ciphers), - service.WithNatTimeout(s.natTimeout), service.WithMetrics(s.serviceMetrics), service.WithReplayCache(&s.replayCache), + service.WithPacketListener(service.MakeTargetUDPListener(s.natTimeout, 0)), service.WithLogger(slog.Default()), ) ln, err := lnSet.ListenStream(addr) @@ -235,14 +237,18 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { return err } slog.Info("TCP service started.", "address", ln.Addr().String()) - go service.StreamServe(ln.AcceptStream, ssService.HandleStream) + go service.StreamServe(ln.AcceptStream, func(ctx context.Context, conn transport.StreamConn) { + streamHandler.HandleStream(ctx, conn, s.serviceMetrics.AddOpenTCPConnection(conn)) + }) pc, err := lnSet.ListenPacket(addr) if err != nil { return err } slog.Info("UDP service started.", "address", pc.LocalAddr().String()) - go ssService.HandlePacket(pc) + go service.PacketServe(pc, func(ctx context.Context, conn net.Conn) { + associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn)) + }, s.serverMetrics) } for _, serviceConfig := range config.Services { @@ -250,13 +256,12 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { if err != nil { return fmt.Errorf("failed to create cipher list from config: %v", err) } - ssService, err := service.NewShadowsocksService( + streamHandler, associationHandler := service.NewShadowsocksHandlers( service.WithCiphers(ciphers), - service.WithNatTimeout(s.natTimeout), service.WithMetrics(s.serviceMetrics), service.WithReplayCache(&s.replayCache), service.WithStreamDialer(service.MakeValidatingTCPStreamDialer(onet.RequirePublicIP, serviceConfig.Dialer.Fwmark)), - service.WithPacketListener(service.MakeTargetUDPListener(serviceConfig.Dialer.Fwmark)), + service.WithPacketListener(service.MakeTargetUDPListener(s.natTimeout, serviceConfig.Dialer.Fwmark)), service.WithLogger(slog.Default()), ) if err != nil { @@ -275,7 +280,9 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { } return serviceConfig.Dialer.Fwmark }()) - go service.StreamServe(ln.AcceptStream, ssService.HandleStream) + go service.StreamServe(ln.AcceptStream, func(ctx context.Context, conn transport.StreamConn) { + streamHandler.HandleStream(ctx, conn, s.serviceMetrics.AddOpenTCPConnection(conn)) + }) case listenerTypeUDP: pc, err := lnSet.ListenPacket(lnConfig.Address) if err != nil { @@ -287,7 +294,9 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { } return serviceConfig.Dialer.Fwmark }()) - go ssService.HandlePacket(pc) + go service.PacketServe(pc, func(ctx context.Context, conn net.Conn) { + associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn)) + }, s.serverMetrics) } } totalCipherCount += len(serviceConfig.Keys) diff --git a/cmd/outline-ss-server/metrics.go b/cmd/outline-ss-server/metrics.go index 32c9b0aa..35ead166 100644 --- a/cmd/outline-ss-server/metrics.go +++ b/cmd/outline-ss-server/metrics.go @@ -17,6 +17,7 @@ package main import ( "time" + "github.com/Jigsaw-Code/outline-ss-server/service" "github.com/prometheus/client_golang/prometheus" ) @@ -25,12 +26,15 @@ var now = time.Now type serverMetrics struct { // NOTE: New metrics need to be added to `newPrometheusServerMetrics()`, `Describe()` and `Collect()`. - buildInfo *prometheus.GaugeVec - accessKeys prometheus.Gauge - ports prometheus.Gauge + buildInfo *prometheus.GaugeVec + accessKeys prometheus.Gauge + ports prometheus.Gauge + addedNatEntries prometheus.Counter + removedNatEntries prometheus.Counter } var _ prometheus.Collector = (*serverMetrics)(nil) +var _ service.NATMetrics = (*serverMetrics)(nil) // newPrometheusServerMetrics constructs a Prometheus metrics collector for server // related metrics. @@ -48,6 +52,16 @@ func newPrometheusServerMetrics() *serverMetrics { Name: "ports", Help: "Count of open ports", }), + addedNatEntries: prometheus.NewCounter(prometheus.CounterOpts{ + Subsystem: "udp", + Name: "nat_entries_added", + Help: "Entries added to the UDP NAT table", + }), + removedNatEntries: prometheus.NewCounter(prometheus.CounterOpts{ + Subsystem: "udp", + Name: "nat_entries_removed", + Help: "Entries removed from the UDP NAT table", + }), } } @@ -55,12 +69,16 @@ func (m *serverMetrics) Describe(ch chan<- *prometheus.Desc) { m.buildInfo.Describe(ch) m.accessKeys.Describe(ch) m.ports.Describe(ch) + m.addedNatEntries.Describe(ch) + m.removedNatEntries.Describe(ch) } func (m *serverMetrics) Collect(ch chan<- prometheus.Metric) { m.buildInfo.Collect(ch) m.accessKeys.Collect(ch) m.ports.Collect(ch) + m.addedNatEntries.Collect(ch) + m.removedNatEntries.Collect(ch) } func (m *serverMetrics) SetVersion(version string) { @@ -71,3 +89,11 @@ func (m *serverMetrics) SetNumAccessKeys(numKeys int, ports int) { m.accessKeys.Set(float64(numKeys)) m.ports.Set(float64(ports)) } + +func (m *serverMetrics) AddNATEntry() { + m.addedNatEntries.Inc() +} + +func (m *serverMetrics) RemoveNATEntry() { + m.removedNatEntries.Inc() +} diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 0994b90f..7fb08718 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -17,6 +17,7 @@ package integration_test import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -104,6 +105,9 @@ func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) { for { n, clientAddr, err := conn.ReadFromUDP(buf) if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } t.Logf("Failed to read from UDP conn: %v", err) return } @@ -138,7 +142,7 @@ func TestTCPEcho(t *testing.T) { go func() { service.StreamServe( func() (transport.StreamConn, error) { return proxyListener.AcceptTCP() }, - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() @@ -197,7 +201,7 @@ type statusMetrics struct { statuses []string } -func (m *statusMetrics) AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration) { +func (m *statusMetrics) AddClose(status string, data metrics.ProxyMetrics, duration time.Duration) { m.Lock() m.statuses = append(m.statuses, status) m.Unlock() @@ -217,7 +221,7 @@ func TestRestrictedAddresses(t *testing.T) { go func() { service.StreamServe( service.WrapStreamAcceptFunc(proxyListener.AcceptTCP), - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() @@ -257,47 +261,52 @@ func TestRestrictedAddresses(t *testing.T) { assert.ElementsMatch(t, testMetrics.statuses, expectedStatus) } +// Stub metrics implementation for testing NAT behaviors. +type natTestMetrics struct { + natEntriesAdded int +} + +var _ service.NATMetrics = (*natTestMetrics)(nil) + +func (m *natTestMetrics) AddNATEntry() { + m.natEntriesAdded++ +} +func (m *natTestMetrics) RemoveNATEntry() {} + // Metrics about one UDP packet. type udpRecord struct { - clientAddr net.Addr accessKey, status string in, out int64 } -type fakeUDPConnMetrics struct { - clientAddr net.Addr - accessKey string - up, down []udpRecord +type fakeUDPAssociationMetrics struct { + accessKey string + up, down []udpRecord + mu sync.Mutex } -var _ service.UDPConnMetrics = (*fakeUDPConnMetrics)(nil) +var _ service.UDPAssociationMetrics = (*fakeUDPAssociationMetrics)(nil) -func (m *fakeUDPConnMetrics) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) { - m.up = append(m.up, udpRecord{m.clientAddr, m.accessKey, status, clientProxyBytes, proxyTargetBytes}) -} -func (m *fakeUDPConnMetrics) AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64) { - m.down = append(m.down, udpRecord{m.clientAddr, m.accessKey, status, targetProxyBytes, proxyClientBytes}) -} -func (m *fakeUDPConnMetrics) RemoveNatEntry() { - // Not tested because it requires waiting for a long timeout. +func (m *fakeUDPAssociationMetrics) AddAuthentication(key string) { + m.mu.Lock() + defer m.mu.Unlock() + m.accessKey = key } -// Fake metrics implementation for UDP -type fakeUDPMetrics struct { - connMetrics []fakeUDPConnMetrics +func (m *fakeUDPAssociationMetrics) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.up = append(m.up, udpRecord{m.accessKey, status, clientProxyBytes, proxyTargetBytes}) } -var _ service.UDPMetrics = (*fakeUDPMetrics)(nil) - -func (m *fakeUDPMetrics) AddUDPNatEntry(clientAddr net.Addr, accessKey string) service.UDPConnMetrics { - cm := fakeUDPConnMetrics{ - clientAddr: clientAddr, - accessKey: accessKey, - } - m.connMetrics = append(m.connMetrics, cm) - return &m.connMetrics[len(m.connMetrics)-1] +func (m *fakeUDPAssociationMetrics) AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.down = append(m.down, udpRecord{m.accessKey, status, targetProxyBytes, proxyClientBytes}) } +func (m *fakeUDPAssociationMetrics) AddClose() {} + func TestUDPEcho(t *testing.T) { echoConn, echoRunning := startUDPEchoServer(t) @@ -310,14 +319,14 @@ func TestUDPEcho(t *testing.T) { if err != nil { t.Fatal(err) } - testMetrics := &fakeUDPMetrics{} - proxy := service.NewPacketHandler(time.Hour, cipherList, testMetrics, &fakeShadowsocksMetrics{}) + proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{}) + proxy.SetTargetIPValidator(allowAll) - done := make(chan struct{}) - go func() { - proxy.Handle(proxyConn) - done <- struct{}{} - }() + natMetrics := &natTestMetrics{} + associationMetrics := &fakeUDPAssociationMetrics{} + go service.PacketServe(proxyConn, func(ctx context.Context, conn net.Conn) { + proxy.HandleAssociation(ctx, conn, associationMetrics) + }, natMetrics) cryptoKey, err := shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, secrets[0]) require.NoError(t, err) @@ -356,34 +365,28 @@ func TestUDPEcho(t *testing.T) { echoConn.Close() echoRunning.Wait() proxyConn.Close() - <-done // Verify that the expected metrics were reported. snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) keyID := snapshot[0].Value.(*service.CipherEntry).ID - if len(testMetrics.connMetrics) != 1 { - t.Errorf("Wrong NAT count: %d", len(testMetrics.connMetrics)) - } - if len(testMetrics.connMetrics[0].up) != 1 { - t.Errorf("Wrong number of packets sent: %v", testMetrics.connMetrics[0].up) - } else { - record := testMetrics.connMetrics[0].up[0] - require.Equal(t, conn.LocalAddr(), record.clientAddr, "Bad upstream metrics") - require.Equal(t, keyID, record.accessKey, "Bad upstream metrics") - require.Equal(t, "OK", record.status, "Bad upstream metrics") - require.Greater(t, record.in, record.out, "Bad upstream metrics") - require.Equal(t, int64(N), record.out, "Bad upstream metrics") - } - if len(testMetrics.connMetrics[0].down) != 1 { - t.Errorf("Wrong number of packets received: %v", testMetrics.connMetrics[0].down) - } else { - record := testMetrics.connMetrics[0].down[0] - require.Equal(t, conn.LocalAddr(), record.clientAddr, "Bad downstream metrics") - require.Equal(t, keyID, record.accessKey, "Bad downstream metrics") - require.Equal(t, "OK", record.status, "Bad downstream metrics") - require.Greater(t, record.out, record.in, "Bad downstream metrics") - require.Equal(t, int64(N), record.in, "Bad downstream metrics") - } + require.Equal(t, natMetrics.natEntriesAdded, 1, "Wrong NAT count") + + associationMetrics.mu.Lock() + defer associationMetrics.mu.Unlock() + + require.Lenf(t, associationMetrics.up, 1, "Wrong number of packets sent") + record := associationMetrics.up[0] + require.Equal(t, keyID, record.accessKey, "Bad upstream metrics") + require.Equal(t, "OK", record.status, "Bad upstream metrics") + require.Greater(t, record.in, record.out, "Bad upstream metrics") + require.Equal(t, int64(N), record.out, "Bad upstream metrics") + + require.Lenf(t, associationMetrics.down, 1, "Wrong number of packets received") + record = associationMetrics.down[0] + require.Equal(t, keyID, record.accessKey, "Bad downstream metrics") + require.Equal(t, "OK", record.status, "Bad downstream metrics") + require.Greater(t, record.out, record.in, "Bad downstream metrics") + require.Equal(t, int64(N), record.in, "Bad downstream metrics") } func BenchmarkTCPThroughput(b *testing.B) { @@ -407,7 +410,7 @@ func BenchmarkTCPThroughput(b *testing.B) { go func() { service.StreamServe( service.WrapStreamAcceptFunc(proxyListener.AcceptTCP), - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() @@ -474,7 +477,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { go func() { service.StreamServe( service.WrapStreamAcceptFunc(proxyListener.AcceptTCP), - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() @@ -544,11 +547,13 @@ func BenchmarkUDPEcho(b *testing.B) { if err != nil { b.Fatal(err) } - proxy := service.NewPacketHandler(time.Hour, cipherList, &service.NoOpUDPMetrics{}, &fakeShadowsocksMetrics{}) + proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{}) proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(server) + service.PacketServe(server, func(ctx context.Context, conn net.Conn) { + proxy.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) done <- struct{}{} }() @@ -588,11 +593,13 @@ func BenchmarkUDPManyKeys(b *testing.B) { if err != nil { b.Fatal(err) } - proxy := service.NewPacketHandler(time.Hour, cipherList, &service.NoOpUDPMetrics{}, &fakeShadowsocksMetrics{}) + proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{}) proxy.SetTargetIPValidator(allowAll) done := make(chan struct{}) go func() { - proxy.Handle(proxyConn) + service.PacketServe(proxyConn, func(ctx context.Context, conn net.Conn) { + proxy.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) done <- struct{}{} }() diff --git a/prometheus/metrics.go b/prometheus/metrics.go index c87277eb..ffc2a7de 100644 --- a/prometheus/metrics.go +++ b/prometheus/metrics.go @@ -117,7 +117,7 @@ func newTCPConnMetrics(tcpServiceMetrics *tcpServiceMetrics, tunnelTimeMetrics * } } -func (cm *tcpConnMetrics) AddAuthenticated(accessKey string) { +func (cm *tcpConnMetrics) AddAuthentication(accessKey string) { cm.accessKey = accessKey ipKey, err := toIPKey(cm.clientAddr, accessKey) if err == nil { @@ -125,7 +125,7 @@ func (cm *tcpConnMetrics) AddAuthenticated(accessKey string) { } } -func (cm *tcpConnMetrics) AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration) { +func (cm *tcpConnMetrics) AddClose(status string, data metrics.ProxyMetrics, duration time.Duration) { cm.tcpServiceMetrics.proxyCollector.addClientTarget(data.ClientProxy, data.ProxyTarget, cm.accessKey, cm.clientInfo) cm.tcpServiceMetrics.proxyCollector.addTargetClient(data.TargetProxy, data.ProxyClient, cm.accessKey, cm.clientInfo) cm.tcpServiceMetrics.closeConnection(status, duration, cm.accessKey, cm.clientInfo) @@ -250,23 +250,25 @@ type udpConnMetrics struct { accessKey string } -var _ service.UDPConnMetrics = (*udpConnMetrics)(nil) +var _ service.UDPAssociationMetrics = (*udpConnMetrics)(nil) -func newUDPConnMetrics(udpServiceMetrics *udpServiceMetrics, tunnelTimeMetrics *tunnelTimeMetrics, accessKey string, clientAddr net.Addr, clientInfo ipinfo.IPInfo) *udpConnMetrics { - udpServiceMetrics.addNatEntry() - ipKey, err := toIPKey(clientAddr, accessKey) - if err == nil { - tunnelTimeMetrics.startConnection(*ipKey) - } +func newUDPAssociationMetrics(udpServiceMetrics *udpServiceMetrics, tunnelTimeMetrics *tunnelTimeMetrics, clientAddr net.Addr, clientInfo ipinfo.IPInfo) *udpConnMetrics { return &udpConnMetrics{ udpServiceMetrics: udpServiceMetrics, tunnelTimeMetrics: tunnelTimeMetrics, - accessKey: accessKey, clientAddr: clientAddr, clientInfo: clientInfo, } } +func (cm *udpConnMetrics) AddAuthentication(accessKey string) { + cm.accessKey = accessKey + ipKey, err := toIPKey(cm.clientAddr, accessKey) + if err == nil { + cm.tunnelTimeMetrics.startConnection(*ipKey) + } +} + func (cm *udpConnMetrics) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) { cm.udpServiceMetrics.addPacketFromClient(status, clientProxyBytes, proxyTargetBytes, cm.accessKey, cm.clientInfo) } @@ -275,12 +277,14 @@ func (cm *udpConnMetrics) AddPacketFromTarget(status string, targetProxyBytes, p cm.udpServiceMetrics.addPacketFromTarget(status, targetProxyBytes, proxyClientBytes, cm.accessKey, cm.clientInfo) } -func (cm *udpConnMetrics) RemoveNatEntry() { - cm.udpServiceMetrics.removeNatEntry() - - ipKey, err := toIPKey(cm.clientAddr, cm.accessKey) - if err == nil { - cm.tunnelTimeMetrics.stopConnection(*ipKey) +func (cm *udpConnMetrics) AddClose() { + // We only track authenticated connections, so ignore unauthenticated closed connections + // when calculating tunneltime. + if cm.accessKey != "" { + ipKey, err := toIPKey(cm.clientAddr, cm.accessKey) + if err == nil { + cm.tunnelTimeMetrics.stopConnection(*ipKey) + } } } @@ -288,8 +292,6 @@ type udpServiceMetrics struct { proxyCollector *proxyCollector // NOTE: New metrics need to be added to `newUDPCollector()`, `Describe()` and `Collect()`. packetsFromClientPerLocation *prometheus.CounterVec - addedNatEntries prometheus.Counter - removedNatEntries prometheus.Counter timeToCipherMs prometheus.ObserverVec } @@ -315,18 +317,6 @@ func newUDPCollector() (*udpServiceMetrics, error) { Name: "packets_from_client_per_location", Help: "Packets received from the client, per location and status", }, []string{"location", "asn", "asorg", "status"}), - addedNatEntries: prometheus.NewCounter( - prometheus.CounterOpts{ - Namespace: namespace, - Name: "nat_entries_added", - Help: "Entries added to the UDP NAT table", - }), - removedNatEntries: prometheus.NewCounter( - prometheus.CounterOpts{ - Namespace: namespace, - Name: "nat_entries_removed", - Help: "Entries removed from the UDP NAT table", - }), }, nil } @@ -334,24 +324,12 @@ func (c *udpServiceMetrics) Describe(ch chan<- *prometheus.Desc) { c.proxyCollector.Describe(ch) c.timeToCipherMs.Describe(ch) c.packetsFromClientPerLocation.Describe(ch) - c.addedNatEntries.Describe(ch) - c.removedNatEntries.Describe(ch) } func (c *udpServiceMetrics) Collect(ch chan<- prometheus.Metric) { c.proxyCollector.Collect(ch) c.timeToCipherMs.Collect(ch) c.packetsFromClientPerLocation.Collect(ch) - c.addedNatEntries.Collect(ch) - c.removedNatEntries.Collect(ch) -} - -func (c *udpServiceMetrics) addNatEntry() { - c.addedNatEntries.Inc() -} - -func (c *udpServiceMetrics) removeNatEntry() { - c.removedNatEntries.Inc() } func (c *udpServiceMetrics) addPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64, accessKey string, clientInfo ipinfo.IPInfo) { @@ -531,23 +509,26 @@ func (m *serviceMetrics) getIPInfoFromAddr(addr net.Addr) ipinfo.IPInfo { return ipInfo } +// TODO: Split TCP and UDP metrics. + func (m *serviceMetrics) AddOpenTCPConnection(clientConn net.Conn) service.TCPConnMetrics { clientAddr := clientConn.RemoteAddr() clientInfo := m.getIPInfoFromAddr(clientAddr) return newTCPConnMetrics(m.tcpServiceMetrics, m.tunnelTimeMetrics, clientConn, clientInfo) } -func (m *serviceMetrics) AddUDPNatEntry(clientAddr net.Addr, accessKey string) service.UDPConnMetrics { +func (m *serviceMetrics) AddOpenUDPAssociation(clientConn net.Conn) service.UDPAssociationMetrics { + clientAddr := clientConn.RemoteAddr() clientInfo := m.getIPInfoFromAddr(clientAddr) - return newUDPConnMetrics(m.udpServiceMetrics, m.tunnelTimeMetrics, accessKey, clientAddr, clientInfo) + return newUDPAssociationMetrics(m.udpServiceMetrics, m.tunnelTimeMetrics, clientAddr, clientInfo) } -func (m *serviceMetrics) AddCipherSearch(proto string, accessKeyFound bool, timeToCipher time.Duration) { - if proto == "tcp" { - m.tcpServiceMetrics.AddCipherSearch(accessKeyFound, timeToCipher) - } else if proto == "udp" { - m.udpServiceMetrics.AddCipherSearch(accessKeyFound, timeToCipher) - } +func (m *serviceMetrics) AddTCPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) { + m.tcpServiceMetrics.AddCipherSearch(accessKeyFound, timeToCipher) +} + +func (m *serviceMetrics) AddUDPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) { + m.udpServiceMetrics.AddCipherSearch(accessKeyFound, timeToCipher) } // addIfNonZero helps avoid the creation of series that are always zero. diff --git a/prometheus/metrics_test.go b/prometheus/metrics_test.go index 5dfcf05a..735b3dd4 100644 --- a/prometheus/metrics_test.go +++ b/prometheus/metrics_test.go @@ -70,17 +70,17 @@ func TestMethodsDontPanic(t *testing.T) { TargetProxy: 3, ProxyClient: 4, } - addr := fakeAddr("127.0.0.1:9") tcpMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) - tcpMetrics.AddAuthenticated("0") - tcpMetrics.AddClosed("OK", proxyMetrics, 10*time.Millisecond) + tcpMetrics.AddAuthentication("0") + tcpMetrics.AddClose("OK", proxyMetrics, 10*time.Millisecond) tcpMetrics.AddProbe("ERR_CIPHER", "eof", proxyMetrics.ClientProxy) - udpMetrics := ssMetrics.AddUDPNatEntry(addr, "key-1") + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) + udpMetrics.AddAuthentication("0") udpMetrics.AddPacketFromClient("OK", 10, 20) udpMetrics.AddPacketFromTarget("OK", 10, 20) - udpMetrics.RemoveNatEntry() + udpMetrics.AddClose() ssMetrics.tcpServiceMetrics.AddCipherSearch(true, 10*time.Millisecond) ssMetrics.udpServiceMetrics.AddCipherSearch(true, 10*time.Millisecond) @@ -99,7 +99,7 @@ func TestTunnelTime(t *testing.T) { reg.MustRegister(ssMetrics) connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) - connMetrics.AddAuthenticated("key-1") + connMetrics.AddAuthentication("key-1") setNow(time.Date(2010, 1, 2, 3, 4, 20, .0, time.Local)) expected := strings.NewReader(` @@ -122,7 +122,7 @@ func TestTunnelTime(t *testing.T) { reg.MustRegister(ssMetrics) connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) - connMetrics.AddAuthenticated("key-1") + connMetrics.AddAuthentication("key-1") setNow(time.Date(2010, 1, 2, 3, 4, 10, .0, time.Local)) expected := strings.NewReader(` @@ -144,7 +144,7 @@ func TestTunnelTimePerKeyDoesNotPanicOnUnknownClosedConnection(t *testing.T) { ssMetrics, _ := NewServiceMetrics(nil) connMetrics := ssMetrics.AddOpenTCPConnection(&fakeConn{}) - connMetrics.AddClosed("OK", metrics.ProxyMetrics{}, time.Minute) + connMetrics.AddClose("OK", metrics.ProxyMetrics{}, time.Minute) err := promtest.GatherAndCompare( reg, @@ -172,8 +172,8 @@ func BenchmarkCloseTCP(b *testing.B) { duration := time.Minute b.ResetTimer() for i := 0; i < b.N; i++ { - connMetrics.AddAuthenticated(accessKey) - connMetrics.AddClosed(status, data, duration) + connMetrics.AddAuthentication(accessKey) + connMetrics.AddClose(status, data, duration) } } @@ -191,9 +191,7 @@ func BenchmarkProbe(b *testing.B) { func BenchmarkClientUDP(b *testing.B) { ssMetrics, _ := NewServiceMetrics(nil) - addr := fakeAddr("127.0.0.1:9") - accessKey := "key 1" - udpMetrics := ssMetrics.AddUDPNatEntry(addr, accessKey) + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) status := "OK" size := int64(1000) b.ResetTimer() @@ -204,9 +202,7 @@ func BenchmarkClientUDP(b *testing.B) { func BenchmarkTargetUDP(b *testing.B) { ssMetrics, _ := NewServiceMetrics(nil) - addr := fakeAddr("127.0.0.1:9") - accessKey := "key 1" - udpMetrics := ssMetrics.AddUDPNatEntry(addr, accessKey) + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) status := "OK" size := int64(1000) b.ResetTimer() @@ -215,12 +211,11 @@ func BenchmarkTargetUDP(b *testing.B) { } } -func BenchmarkNAT(b *testing.B) { +func BenchmarkClose(b *testing.B) { ssMetrics, _ := NewServiceMetrics(nil) - addr := fakeAddr("127.0.0.1:9") b.ResetTimer() for i := 0; i < b.N; i++ { - udpMetrics := ssMetrics.AddUDPNatEntry(addr, "key-0") - udpMetrics.RemoveNatEntry() + udpMetrics := ssMetrics.AddOpenUDPAssociation(&fakeConn{}) + udpMetrics.AddClose() } } diff --git a/service/shadowsocks.go b/service/shadowsocks.go index c194957c..75bfe245 100644 --- a/service/shadowsocks.go +++ b/service/shadowsocks.go @@ -15,7 +15,6 @@ package service import ( - "context" "log/slog" "net" "time" @@ -25,13 +24,8 @@ import ( onet "github.com/Jigsaw-Code/outline-ss-server/net" ) -const ( - // 59 seconds is most common timeout for servers that do not respond to invalid requests - tcpReadTimeout time.Duration = 59 * time.Second - - // A UDP NAT timeout of at least 5 minutes is recommended in RFC 4787 Section 4.3. - defaultNatTimeout time.Duration = 5 * time.Minute -) +// 59 seconds is most common timeout for servers that do not respond to invalid requests +const tcpReadTimeout time.Duration = 59 * time.Second // ShadowsocksConnMetrics is used to report Shadowsocks related metrics on connections. type ShadowsocksConnMetrics interface { @@ -39,14 +33,10 @@ type ShadowsocksConnMetrics interface { } type ServiceMetrics interface { - UDPMetrics + AddOpenUDPAssociation(conn net.Conn) UDPAssociationMetrics AddOpenTCPConnection(conn net.Conn) TCPConnMetrics - AddCipherSearch(proto string, accessKeyFound bool, timeToCipher time.Duration) -} - -type Service interface { - HandleStream(ctx context.Context, conn transport.StreamConn) - HandlePacket(conn net.PacketConn) + AddTCPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) + AddUDPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) } // Option is a Shadowsocks service constructor option. @@ -54,52 +44,42 @@ type Option func(s *ssService) type ssService struct { logger *slog.Logger - metrics ServiceMetrics ciphers CipherList - natTimeout time.Duration + metrics ServiceMetrics targetIPValidator onet.TargetIPValidator replayCache *ReplayCache streamDialer transport.StreamDialer - sh StreamHandler packetListener transport.PacketListener - ph PacketHandler } -// NewShadowsocksService creates a new Shadowsocks service. -func NewShadowsocksService(opts ...Option) (Service, error) { - s := &ssService{} +// NewShadowsocksHandlers creates new Shadowsocks stream and packet handlers. +func NewShadowsocksHandlers(opts ...Option) (StreamHandler, AssociationHandler) { + s := &ssService{ + logger: noopLogger(), + } for _, opt := range opts { opt(s) } - // If no NAT timeout is provided via options, use the recommended default. - if s.natTimeout == 0 { - s.natTimeout = defaultNatTimeout - } - // If no logger is provided via options, use a noop logger. - if s.logger == nil { - s.logger = noopLogger() - } - // TODO: Register initial data metrics at zero. - s.sh = NewStreamHandler( - NewShadowsocksStreamAuthenticator(s.ciphers, s.replayCache, &ssConnMetrics{ServiceMetrics: s.metrics, proto: "tcp"}, s.logger), + sh := NewStreamHandler( + NewShadowsocksStreamAuthenticator(s.ciphers, s.replayCache, &ssConnMetrics{s.metrics.AddTCPCipherSearch}, s.logger), tcpReadTimeout, ) if s.streamDialer != nil { - s.sh.SetTargetDialer(s.streamDialer) + sh.SetTargetDialer(s.streamDialer) } - s.sh.SetLogger(s.logger) + sh.SetLogger(s.logger) - s.ph = NewPacketHandler(s.natTimeout, s.ciphers, s.metrics, &ssConnMetrics{ServiceMetrics: s.metrics, proto: "udp"}) + ah := NewAssociationHandler(s.ciphers, &ssConnMetrics{s.metrics.AddUDPCipherSearch}) if s.packetListener != nil { - s.ph.SetTargetPacketListener(s.packetListener) + ah.SetTargetPacketListener(s.packetListener) } - s.ph.SetLogger(s.logger) + ah.SetLogger(s.logger) - return s, nil + return sh, ah } // WithLogger can be used to provide a custom log target. If not provided, @@ -117,7 +97,6 @@ func WithCiphers(ciphers CipherList) Option { } } -// WithMetrics option function. func WithMetrics(metrics ServiceMetrics) Option { return func(s *ssService) { s.metrics = metrics @@ -131,13 +110,6 @@ func WithReplayCache(replayCache *ReplayCache) Option { } } -// WithNatTimeout option function. -func WithNatTimeout(natTimeout time.Duration) Option { - return func(s *ssService) { - s.natTimeout = natTimeout - } -} - // WithStreamDialer option function. func WithStreamDialer(dialer transport.StreamDialer) Option { return func(s *ssService) { @@ -152,30 +124,15 @@ func WithPacketListener(listener transport.PacketListener) Option { } } -// HandleStream handles a Shadowsocks stream-based connection. -func (s *ssService) HandleStream(ctx context.Context, conn transport.StreamConn) { - var connMetrics TCPConnMetrics - if s.metrics != nil { - connMetrics = s.metrics.AddOpenTCPConnection(conn) - } - s.sh.Handle(ctx, conn, connMetrics) -} - -// HandlePacket handles a Shadowsocks packet connection. -func (s *ssService) HandlePacket(conn net.PacketConn) { - s.ph.Handle(conn) -} - type ssConnMetrics struct { - ServiceMetrics - proto string + addCipherSearch func(accessKeyFound bool, timeToCipher time.Duration) } var _ ShadowsocksConnMetrics = (*ssConnMetrics)(nil) func (cm *ssConnMetrics) AddCipherSearch(accessKeyFound bool, timeToCipher time.Duration) { - if cm.ServiceMetrics != nil { - cm.ServiceMetrics.AddCipherSearch(cm.proto, accessKeyFound, timeToCipher) + if cm.addCipherSearch != nil { + cm.addCipherSearch(accessKeyFound, timeToCipher) } } diff --git a/service/tcp.go b/service/tcp.go index 23db7cd6..50b6b0d7 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -37,8 +37,8 @@ import ( // TCPConnMetrics is used to report metrics on TCP connections. type TCPConnMetrics interface { - AddAuthenticated(accessKey string) - AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration) + AddAuthentication(accessKey string) + AddClose(status string, data metrics.ProxyMetrics, duration time.Duration) AddProbe(status, drainResult string, clientProxyBytes int64) } @@ -164,6 +164,8 @@ type streamHandler struct { dialer transport.StreamDialer } +var _ StreamHandler = (*streamHandler)(nil) + // NewStreamHandler creates a StreamHandler func NewStreamHandler(authenticate StreamAuthenticateFunc, timeout time.Duration) StreamHandler { return &streamHandler{ @@ -176,7 +178,7 @@ func NewStreamHandler(authenticate StreamAuthenticateFunc, timeout time.Duration // StreamHandler is a handler that handles stream connections. type StreamHandler interface { - Handle(ctx context.Context, conn transport.StreamConn, connMetrics TCPConnMetrics) + HandleStream(ctx context.Context, conn transport.StreamConn, connMetrics TCPConnMetrics) // SetLogger sets the logger used to log messages. Uses a no-op logger if nil. SetLogger(l *slog.Logger) // SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses. @@ -219,7 +221,7 @@ type StreamHandleFunc func(ctx context.Context, conn transport.StreamConn) // StreamServe repeatedly calls `accept` to obtain connections and `handle` to handle them until // accept() returns [ErrClosed]. When that happens, all connection handlers will be notified // via their [context.Context]. StreamServe will return after all pending handlers return. -func StreamServe(accept StreamAcceptFunc, handle StreamHandleFunc) { +func StreamServe(accept StreamAcceptFunc, streamHandle StreamHandleFunc) { var running sync.WaitGroup defer running.Wait() ctx, contextCancel := context.WithCancel(context.Background()) @@ -243,12 +245,12 @@ func StreamServe(accept StreamAcceptFunc, handle StreamHandleFunc) { slog.Warn("Panic in TCP handler. Continuing to listen.", "err", r) } }() - handle(ctx, clientConn) + streamHandle(ctx, clientConn) }() } } -func (h *streamHandler) Handle(ctx context.Context, clientConn transport.StreamConn, connMetrics TCPConnMetrics) { +func (h *streamHandler) HandleStream(ctx context.Context, clientConn transport.StreamConn, connMetrics TCPConnMetrics) { if connMetrics == nil { connMetrics = &NoOpTCPConnMetrics{} } @@ -264,7 +266,7 @@ func (h *streamHandler) Handle(ctx context.Context, clientConn transport.StreamC status = connError.Status h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) } - connMetrics.AddClosed(status, proxyMetrics, connDuration) + connMetrics.AddClose(status, proxyMetrics, connDuration) measuredClientConn.Close() // Closing after the metrics are added aids integration testing. h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration)) } @@ -336,7 +338,7 @@ func (h *streamHandler) handleConnection(ctx context.Context, outerConn transpor h.absorbProbe(outerConn, connMetrics, authErr.Status, proxyMetrics) return authErr } - connMetrics.AddAuthenticated(id) + connMetrics.AddAuthentication(id) // Read target address and dial it. tgtAddr, err := getProxyRequest(innerConn) @@ -387,9 +389,9 @@ type NoOpTCPConnMetrics struct{} var _ TCPConnMetrics = (*NoOpTCPConnMetrics)(nil) -func (m *NoOpTCPConnMetrics) AddAuthenticated(accessKey string) {} +func (m *NoOpTCPConnMetrics) AddAuthentication(accessKey string) {} -func (m *NoOpTCPConnMetrics) AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration) { +func (m *NoOpTCPConnMetrics) AddClose(status string, data metrics.ProxyMetrics, duration time.Duration) { } func (m *NoOpTCPConnMetrics) AddProbe(status, drainResult string, clientProxyBytes int64) {} diff --git a/service/tcp_test.go b/service/tcp_test.go index ab497f91..9f5ecb60 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -235,13 +235,13 @@ var _ TCPConnMetrics = (*probeTestMetrics)(nil) var _ ShadowsocksConnMetrics = (*fakeShadowsocksMetrics)(nil) -func (m *probeTestMetrics) AddClosed(status string, data metrics.ProxyMetrics, duration time.Duration) { +func (m *probeTestMetrics) AddClose(status string, data metrics.ProxyMetrics, duration time.Duration) { m.mu.Lock() m.closeStatus = append(m.closeStatus, status) m.mu.Unlock() } -func (m *probeTestMetrics) AddAuthenticated(accessKey string) { +func (m *probeTestMetrics) AddAuthentication(accessKey string) { } func (m *probeTestMetrics) AddProbe(status, drainResult string, clientProxyBytes int64) { @@ -292,7 +292,7 @@ func TestProbeRandom(t *testing.T) { go func() { StreamServe( WrapStreamAcceptFunc(listener.AcceptTCP), - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() @@ -373,7 +373,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { go func() { StreamServe( WrapStreamAcceptFunc(listener.AcceptTCP), - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() @@ -411,7 +411,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { go func() { StreamServe( WrapStreamAcceptFunc(listener.AcceptTCP), - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() @@ -450,7 +450,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { go func() { StreamServe( WrapStreamAcceptFunc(listener.AcceptTCP), - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() @@ -495,7 +495,7 @@ func TestProbeServerBytesModified(t *testing.T) { go func() { StreamServe( WrapStreamAcceptFunc(listener.AcceptTCP), - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() @@ -551,7 +551,7 @@ func TestReplayDefense(t *testing.T) { go func() { StreamServe( WrapStreamAcceptFunc(listener.AcceptTCP), - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() @@ -624,7 +624,7 @@ func TestReverseReplayDefense(t *testing.T) { go func() { StreamServe( WrapStreamAcceptFunc(listener.AcceptTCP), - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() @@ -686,7 +686,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { go func() { StreamServe( WrapStreamAcceptFunc(listener.AcceptTCP), - func(ctx context.Context, conn transport.StreamConn) { handler.Handle(ctx, conn, testMetrics) }, + func(ctx context.Context, conn transport.StreamConn) { handler.HandleStream(ctx, conn, testMetrics) }, ) done <- struct{}{} }() diff --git a/service/udp.go b/service/udp.go index 52af2dc3..775d7b27 100644 --- a/service/udp.go +++ b/service/udp.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "io" "log/slog" "net" "net/netip" @@ -29,35 +30,41 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" "github.com/shadowsocks/go-shadowsocks2/socks" + "github.com/Jigsaw-Code/outline-ss-server/internal/slicepool" onet "github.com/Jigsaw-Code/outline-ss-server/net" ) -// UDPConnMetrics is used to report metrics on UDP connections. -type UDPConnMetrics interface { +// NATMetrics is used to report NAT related metrics. +type NATMetrics interface { + AddNATEntry() + RemoveNATEntry() +} + +// UDPAssociationMetrics is used to report metrics on UDP associations. +type UDPAssociationMetrics interface { + AddAuthentication(accessKey string) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64) - RemoveNatEntry() + AddClose() } -type UDPMetrics interface { - AddUDPNatEntry(clientAddr net.Addr, accessKey string) UDPConnMetrics -} +const ( + // Max UDP buffer size for the server code. + serverUDPBufferSize = 64 * 1024 -// Max UDP buffer size for the server code. -const serverUDPBufferSize = 64 * 1024 + // A UDP NAT timeout of at least 5 minutes is recommended in RFC 4787 Section 4.3. + defaultNatTimeout time.Duration = 5 * time.Minute +) + +// Buffer pool used for reading UDP packets. +var readBufPool = slicepool.MakePool(serverUDPBufferSize) // Wrapper for slog.Debug during UDP proxying. -func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) { +func debugUDP(l *slog.Logger, msg string, attrs ...slog.Attr) { // This is an optimization to reduce unnecessary allocations due to an interaction // between Go's inlining/escape analysis and varargs functions like slog.Debug. if l.Enabled(nil, slog.LevelDebug) { - l.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("UDP: %s", template), slog.String("ID", cipherID), attr) - } -} - -func debugUDPAddr(l *slog.Logger, template string, addr net.Addr, attr slog.Attr) { - if l.Enabled(nil, slog.LevelDebug) { - l.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("UDP: %s", template), slog.String("address", addr.String()), attr) + l.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("UDP: %s", msg), attrs...) } } @@ -71,10 +78,10 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis id, cryptoKey := entry.Value.(*CipherEntry).ID, entry.Value.(*CipherEntry).CryptoKey buf, err := shadowsocks.Unpack(dst, src, cryptoKey) if err != nil { - debugUDP(l, "Failed to unpack.", id, slog.Any("err", err)) + debugUDP(l, "Failed to unpack.", slog.String("ID", id), slog.Any("err", err)) continue } - debugUDP(l, "Found cipher.", id, slog.Int("index", ci)) + debugUDP(l, "Found cipher.", slog.String("ID", id), slog.Int("index", ci)) // Move the active cipher to the front, so that the search is quicker next time. cipherList.MarkUsedByClientIP(entry, clientIP) return buf, id, cryptoKey, nil @@ -82,125 +89,120 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis return nil, "", nil, errors.New("could not find valid UDP cipher") } -type packetHandler struct { +type associationHandler struct { logger *slog.Logger - natTimeout time.Duration ciphers CipherList - m UDPMetrics ssm ShadowsocksConnMetrics targetIPValidator onet.TargetIPValidator targetListener transport.PacketListener } -// NewPacketHandler creates a PacketHandler -func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetrics, ssMetrics ShadowsocksConnMetrics) PacketHandler { - if m == nil { - m = &NoOpUDPMetrics{} - } +var _ AssociationHandler = (*associationHandler)(nil) + +// NewAssociationHandler creates a Shadowsocks proxy AssociationHandler. +func NewAssociationHandler(cipherList CipherList, ssMetrics ShadowsocksConnMetrics) AssociationHandler { if ssMetrics == nil { ssMetrics = &NoOpShadowsocksConnMetrics{} } - return &packetHandler{ + return &associationHandler{ logger: noopLogger(), - natTimeout: natTimeout, ciphers: cipherList, - m: m, ssm: ssMetrics, targetIPValidator: onet.RequirePublicIP, - targetListener: MakeTargetUDPListener(0), + targetListener: MakeTargetUDPListener(defaultNatTimeout, 0), } } -// PacketHandler is a running UDP shadowsocks proxy that can be stopped. -type PacketHandler interface { +// AssociationHandler is a handler that handles UDP assocations. +type AssociationHandler interface { + HandleAssociation(ctx context.Context, conn net.Conn, assocMetrics UDPAssociationMetrics) // SetLogger sets the logger used to log messages. Uses a no-op logger if nil. SetLogger(l *slog.Logger) // SetTargetIPValidator sets the function to be used to validate the target IP addresses. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) // SetTargetPacketListener sets the packet listener to use for target connections. SetTargetPacketListener(targetListener transport.PacketListener) - // Handle returns after clientConn closes and all the sub goroutines return. - Handle(clientConn net.PacketConn) } -func (h *packetHandler) SetLogger(l *slog.Logger) { +func (h *associationHandler) SetLogger(l *slog.Logger) { if l == nil { l = noopLogger() } h.logger = l } -func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { +func (h *associationHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { h.targetIPValidator = targetIPValidator } -func (h *packetHandler) SetTargetPacketListener(targetListener transport.PacketListener) { +func (h *associationHandler) SetTargetPacketListener(targetListener transport.PacketListener) { h.targetListener = targetListener } -// Listen on addr for encrypted packets and basically do UDP NAT. -// We take the ciphers as a pointer because it gets replaced on config updates. -func (h *packetHandler) Handle(clientConn net.PacketConn) { - nm := newNATmap(h.natTimeout, h.m, h.logger) - defer nm.Close() - cipherBuf := make([]byte, serverUDPBufferSize) - textBuf := make([]byte, serverUDPBufferSize) +func (h *associationHandler) HandleAssociation(ctx context.Context, clientConn net.Conn, assocMetrics UDPAssociationMetrics) { + l := h.logger.With(slog.Any("client", clientConn.RemoteAddr())) + + defer func() { + debugUDP(l, "Done") + assocMetrics.AddClose() + }() + var targetConn net.PacketConn + var cryptoKey *shadowsocks.EncryptionKey + + readBufLazySlice := readBufPool.LazySlice() + readBuf := readBufLazySlice.Acquire() + defer readBufLazySlice.Release() for { - clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) - if errors.Is(err, net.ErrClosed) { + select { + case <-ctx.Done(): + break + default: + } + clientProxyBytes, err := clientConn.Read(readBuf) + if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { + debugUDP(l, "Client closed connection") break } + pkt := readBuf[:clientProxyBytes] + debugUDP(l, "Outbound packet.", slog.Int("bytes", clientProxyBytes)) var proxyTargetBytes int - var targetConn *natconn - - connError := func() (connError *onet.ConnectionError) { - defer func() { - if r := recover(); r != nil { - slog.Error("Panic in UDP loop: %v. Continuing to listen.", r) - debug.PrintStack() - } - }() - - // Error from ReadFrom - if err != nil { - return onet.NewConnectionError("ERR_READ", "Failed to read from client", err) - } - defer slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAddr.String())) - debugUDPAddr(h.logger, "Outbound packet.", clientAddr, slog.Int("bytes", clientProxyBytes)) - cipherData := cipherBuf[:clientProxyBytes] + connError := func() *onet.ConnectionError { var payload []byte var tgtUDPAddr *net.UDPAddr - targetConn = nm.Get(clientAddr.String()) if targetConn == nil { - ip := clientAddr.(*net.UDPAddr).AddrPort().Addr() + ip := clientConn.RemoteAddr().(*net.UDPAddr).AddrPort().Addr() var textData []byte - var cryptoKey *shadowsocks.EncryptionKey + var keyID string + textLazySlice := readBufPool.LazySlice() unpackStart := time.Now() - textData, keyID, cryptoKey, err := findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.logger) + textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textLazySlice.Acquire(), pkt, h.ciphers, h.logger) timeToCipher := time.Since(unpackStart) + textLazySlice.Release() h.ssm.AddCipherSearch(err == nil, timeToCipher) if err != nil { return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err) } + assocMetrics.AddAuthentication(keyID) var onetErr *onet.ConnectionError if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { return onetErr } - udpConn, err := h.targetListener.ListenPacket(context.Background()) + // Create the target connection. + targetConn, err = h.targetListener.ListenPacket(ctx) if err != nil { return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create a `PacketConn`", err) } - - targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, keyID) + l = l.With(slog.Any("tgtListener", targetConn.LocalAddr())) + go relayTargetToClient(targetConn, clientConn, cryptoKey, assocMetrics, l) } else { unpackStart := time.Now() - textData, err := shadowsocks.Unpack(nil, cipherData, targetConn.cryptoKey) + textData, err := shadowsocks.Unpack(nil, pkt, cryptoKey) timeToCipher := time.Since(unpackStart) h.ssm.AddCipherSearch(err == nil, timeToCipher) @@ -214,7 +216,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { } } - debugUDPAddr(h.logger, "Proxy exit.", clientAddr, slog.Any("target", targetConn.LocalAddr())) + debugUDP(l, "Proxy exit.") proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature if err != nil { return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) @@ -224,19 +226,17 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { status := "OK" if connError != nil { - slog.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) + debugUDP(l, "Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) status = connError.Status } - if targetConn != nil { - targetConn.metrics.AddPacketFromClient(status, int64(clientProxyBytes), int64(proxyTargetBytes)) - } + assocMetrics.AddPacketFromClient(status, int64(clientProxyBytes), int64(proxyTargetBytes)) } } // Given the decrypted contents of a UDP packet, return // the payload and the destination address, or an error if // this packet cannot or should not be forwarded. -func (h *packetHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) { +func (h *associationHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) { tgtAddr := socks.SplitAddr(textData) if tgtAddr == nil { return nil, nil, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", nil) @@ -254,16 +254,150 @@ func (h *packetHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, * return payload, tgtUDPAddr, nil } +type AssociationHandleFunc func(ctx context.Context, conn net.Conn) + +// PacketServe listens for UDP packets on the provided [net.PacketConn] and creates +// and manages NAT associations. It uses a NAT map to track active associations and +// handles their lifecycle. +func PacketServe(clientConn net.PacketConn, assocHandle AssociationHandleFunc, metrics NATMetrics) { + nm := newNATmap() + ctx, contextCancel := context.WithCancel(context.Background()) + defer contextCancel() + + for { + lazySlice := readBufPool.LazySlice() + buffer := lazySlice.Acquire() + + expired := false + func() { + defer func() { + if r := recover(); r != nil { + slog.Error("Panic in UDP loop. Continuing to listen.", "err", r) + debug.PrintStack() + lazySlice.Release() + } + }() + n, clientAddr, err := clientConn.ReadFrom(buffer) + if err != nil { + lazySlice.Release() + if errors.Is(err, net.ErrClosed) { + expired = true + return + } + slog.Warn("Failed to read from client. Continuing to listen.", "err", err) + return + } + pkt := &packet{payload: buffer[:n], done: lazySlice.Release} + + // TODO(#19): Include server address in the NAT key as well. + assoc := nm.Get(clientAddr.String()) + if assoc == nil { + assoc = &association{ + pc: clientConn, + clientAddr: clientAddr, + readCh: make(chan *packet, 5), + doneCh: make(chan struct{}), + } + if err != nil { + slog.Error("Failed to handle association", slog.Any("err", err)) + return + } + + var existing bool + assoc, existing = nm.Add(clientAddr.String(), assoc) + if !existing { + metrics.AddNATEntry() + go func() { + assocHandle(ctx, assoc) + metrics.RemoveNATEntry() + close(assoc.doneCh) + }() + } + } + select { + case <-assoc.doneCh: + nm.Del(clientAddr.String()) + case assoc.readCh <- pkt: + default: + slog.Debug("Dropping packet due to full read queue") + // TODO: Add a metric to track number of dropped packets. + } + }() + if expired { + break + } + } +} + +type packet struct { + // The contents of the packet. + payload []byte + + // A function to call as soon as the payload has been consumed. This can be + // used to release resources. + done func() +} + +// association wraps a [net.PacketConn] with an address into a [net.Conn]. +type association struct { + pc net.PacketConn + clientAddr net.Addr + readCh chan *packet + doneCh chan struct{} +} + +var _ net.Conn = (*association)(nil) + +func (c *association) Read(p []byte) (int, error) { + pkt, ok := <-c.readCh + if !ok { + return 0, net.ErrClosed + } + n := copy(p, pkt.payload) + pkt.done() + if n < len(pkt.payload) { + return n, io.ErrShortBuffer + } + return n, nil +} + +func (c *association) Write(b []byte) (n int, err error) { + return c.pc.WriteTo(b, c.clientAddr) +} + +func (c *association) Close() error { + close(c.readCh) + return c.pc.Close() +} + +func (c *association) LocalAddr() net.Addr { + return c.pc.LocalAddr() +} + +func (c *association) RemoteAddr() net.Addr { + return c.clientAddr +} + +func (c *association) SetDeadline(t time.Time) error { + return errors.ErrUnsupported +} + +func (c *association) SetReadDeadline(t time.Time) error { + return errors.ErrUnsupported +} + +func (c *association) SetWriteDeadline(t time.Time) error { + return errors.ErrUnsupported +} + func isDNS(addr net.Addr) bool { _, port, _ := net.SplitHostPort(addr.String()) return port == "53" } -type natconn struct { +type timedPacketConn struct { net.PacketConn - cryptoKey *shadowsocks.EncryptionKey - metrics UDPConnMetrics - // NAT timeout to apply for non-DNS packets. + // Connection timeout to apply for non-DNS packets. defaultTimeout time.Duration // Current read deadline of PacketConn. Used to avoid decreasing the // deadline. Initially zero. @@ -273,7 +407,7 @@ type natconn struct { fastClose sync.Once } -func (c *natconn) onWrite(addr net.Addr) { +func (c *timedPacketConn) onWrite(addr net.Addr) { // Fast close is only allowed if there has been exactly one write, // and it was a DNS query. isDNS := isDNS(addr) @@ -296,7 +430,7 @@ func (c *natconn) onWrite(addr net.Addr) { } } -func (c *natconn) onRead(addr net.Addr) { +func (c *timedPacketConn) onRead(addr net.Addr) { c.fastClose.Do(func() { if isDNS(addr) { // The next ReadFrom() should time out immediately. @@ -305,12 +439,12 @@ func (c *natconn) onRead(addr net.Addr) { }) } -func (c *natconn) WriteTo(buf []byte, dst net.Addr) (int, error) { +func (c *timedPacketConn) WriteTo(buf []byte, dst net.Addr) (int, error) { c.onWrite(dst) return c.PacketConn.WriteTo(buf, dst) } -func (c *natconn) ReadFrom(buf []byte) (int, net.Addr, error) { +func (c *timedPacketConn) ReadFrom(buf []byte) (int, net.Addr, error) { n, addr, err := c.PacketConn.ReadFrom(buf) if err == nil { c.onRead(addr) @@ -321,99 +455,65 @@ func (c *natconn) ReadFrom(buf []byte) (int, net.Addr, error) { // Packet NAT table type natmap struct { sync.RWMutex - keyConn map[string]*natconn - logger *slog.Logger - timeout time.Duration - metrics UDPMetrics + associations map[string]*association } -func newNATmap(timeout time.Duration, sm UDPMetrics, l *slog.Logger) *natmap { - m := &natmap{logger: l, metrics: sm} - m.keyConn = make(map[string]*natconn) - m.timeout = timeout - return m +func newNATmap() *natmap { + return &natmap{associations: make(map[string]*association)} } -func (m *natmap) Get(key string) *natconn { +// Get returns a UDP NAT entry from the natmap. +func (m *natmap) Get(clientAddr string) *association { m.RLock() defer m.RUnlock() - return m.keyConn[key] + return m.associations[clientAddr] } -func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, connMetrics UDPConnMetrics) *natconn { - entry := &natconn{ - PacketConn: pc, - cryptoKey: cryptoKey, - metrics: connMetrics, - defaultTimeout: m.timeout, - } - +// Del deletes a UDP NAT entry from the natmap. +func (m *natmap) Del(clientAddr string) { m.Lock() defer m.Unlock() - m.keyConn[key] = entry - return entry -} - -func (m *natmap) del(key string) net.PacketConn { - m.Lock() - defer m.Unlock() - - entry, ok := m.keyConn[key] - if ok { - delete(m.keyConn, key) - return entry + if _, ok := m.associations[clientAddr]; ok { + delete(m.associations, clientAddr) } - return nil } -func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, keyID string) *natconn { - connMetrics := m.metrics.AddUDPNatEntry(clientAddr, keyID) - entry := m.set(clientAddr.String(), targetConn, cryptoKey, connMetrics) - - go func() { - timedCopy(clientAddr, clientConn, entry, m.logger) - connMetrics.RemoveNatEntry() - if pc := m.del(clientAddr.String()); pc != nil { - pc.Close() - } - }() - return entry -} - -func (m *natmap) Close() error { +// Add adds a UDP NAT entry to the natmap and returns it. If it already existed, +// in the natmap, the existing entry is returned instead. +func (m *natmap) Add(clientAddr string, assoc *association) (*association, bool) { m.Lock() defer m.Unlock() - var err error - now := time.Now() - for _, pc := range m.keyConn { - if e := pc.SetReadDeadline(now); e != nil { - err = e - } + if existing, ok := m.associations[clientAddr]; ok { + return existing, true } - return err + + m.associations[clientAddr] = assoc + return assoc, false } // Get the maximum length of the shadowsocks address header by parsing // and serializing an IPv6 address from the example range. var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) -// copy from target to client until read timeout -func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, l *slog.Logger) { +// relayTargetToClient copies from target to client until read timeout. +func relayTargetToClient(targetConn net.PacketConn, clientConn io.Writer, cryptoKey *shadowsocks.EncryptionKey, m UDPAssociationMetrics, l *slog.Logger) { + defer targetConn.Close() + // pkt is used for in-place encryption of downstream UDP packets, with the layout // [padding?][salt][address][body][tag][extra] // Padding is only used if the address is IPv4. pkt := make([]byte, serverUDPBufferSize) - saltSize := targetConn.cryptoKey.SaltSize() + saltSize := cryptoKey.SaltSize() // Leave enough room at the beginning of the packet for a max-length header (i.e. IPv6). bodyStart := saltSize + maxAddrLen expired := false for { - var bodyLen, proxyClientBytes int - connError := func() (connError *onet.ConnectionError) { + var targetProxyBytes, proxyClientBytes int + connError := func() *onet.ConnectionError { var ( raddr net.Addr err error @@ -422,7 +522,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco // [padding?][salt][address][body][tag][unused] // |-- bodyStart --|[ readBuf ] readBuf := pkt[bodyStart:] - bodyLen, raddr, err = targetConn.ReadFrom(readBuf) + targetProxyBytes, raddr, err = targetConn.ReadFrom(readBuf) if err != nil { if netErr, ok := err.(net.Error); ok { if netErr.Timeout() { @@ -433,13 +533,13 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco return onet.NewConnectionError("ERR_READ", "Failed to read from target", err) } - debugUDPAddr(l, "Got response.", clientAddr, slog.Any("target", raddr)) + debugUDP(l, "Got response.", slog.Any("rtarget", raddr)) srcAddr := socks.ParseAddr(raddr.String()) addrStart := bodyStart - len(srcAddr) // `plainTextBuf` concatenates the SOCKS address and body: // [padding?][salt][address][body][tag][unused] // |-- addrStart -|[plaintextBuf ] - plaintextBuf := pkt[addrStart : bodyStart+bodyLen] + plaintextBuf := pkt[addrStart : bodyStart+targetProxyBytes] copy(plaintextBuf, srcAddr) // saltStart is 0 if raddr is IPv6. @@ -450,11 +550,11 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco // [ packBuf ] // [ buf ] packBuf := pkt[saltStart:] - buf, err := shadowsocks.Pack(packBuf, plaintextBuf, targetConn.cryptoKey) // Encrypt in-place + buf, err := shadowsocks.Pack(packBuf, plaintextBuf, cryptoKey) // Encrypt in-place if err != nil { return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) } - proxyClientBytes, err = clientConn.WriteTo(buf, clientAddr) + proxyClientBytes, err = clientConn.Write(buf) if err != nil { return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err) } @@ -462,36 +562,28 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco }() status := "OK" if connError != nil { - slog.LogAttrs(nil, slog.LevelDebug, "UDP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) + debugUDP(l, "Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) status = connError.Status } + if expired { break } - targetConn.metrics.AddPacketFromTarget(status, int64(bodyLen), int64(proxyClientBytes)) + m.AddPacketFromTarget(status, int64(targetProxyBytes), int64(proxyClientBytes)) } } -// NoOpUDPConnMetrics is a [UDPConnMetrics] that doesn't do anything. Useful in tests +// NoOpUDPAssociationMetrics is a [UDPAssociationMetrics] that doesn't do anything. Useful in tests // or if you don't want to track metrics. -type NoOpUDPConnMetrics struct{} +type NoOpUDPAssociationMetrics struct{} -var _ UDPConnMetrics = (*NoOpUDPConnMetrics)(nil) +var _ UDPAssociationMetrics = (*NoOpUDPAssociationMetrics)(nil) -func (m *NoOpUDPConnMetrics) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) { -} +func (m *NoOpUDPAssociationMetrics) AddAuthentication(accessKey string) {} -func (m *NoOpUDPConnMetrics) AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64) { +func (m *NoOpUDPAssociationMetrics) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) { } - -func (m *NoOpUDPConnMetrics) RemoveNatEntry() {} - -// NoOpUDPMetrics is a [UDPMetrics] that doesn't do anything. Useful in tests -// or if you don't want to track metrics. -type NoOpUDPMetrics struct{} - -var _ UDPMetrics = (*NoOpUDPMetrics)(nil) - -func (m *NoOpUDPMetrics) AddUDPNatEntry(clientAddr net.Addr, accessKey string) UDPConnMetrics { - return &NoOpUDPConnMetrics{} +func (m *NoOpUDPAssociationMetrics) AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64) { +} +func (m *NoOpUDPAssociationMetrics) AddClose() { } diff --git a/service/udp_linux.go b/service/udp_linux.go index 218727ad..10b3282f 100644 --- a/service/udp_linux.go +++ b/service/udp_linux.go @@ -20,11 +20,15 @@ import ( "context" "fmt" "net" + "time" "github.com/Jigsaw-Code/outline-sdk/transport" ) type udpListener struct { + // NAT mapping timeout is the default time a mapping will stay active + // without packets traversing the NAT, applied to non-DNS packets. + timeout time.Duration // fwmark can be used in conjunction with other Linux networking features like cgroups, network // namespaces, and TC (Traffic Control) for sophisticated network management. // Value of 0 disables fwmark (SO_MARK) (Linux only) @@ -33,8 +37,8 @@ type udpListener struct { // NewPacketListener creates a new PacketListener that listens on UDP // and optionally sets a firewall mark on the socket (Linux only). -func MakeTargetUDPListener(fwmark uint) transport.PacketListener { - return &udpListener{fwmark: fwmark} +func MakeTargetUDPListener(timeout time.Duration, fwmark uint) transport.PacketListener { + return &udpListener{timeout: timeout, fwmark: fwmark} } func (ln *udpListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { @@ -57,5 +61,5 @@ func (ln *udpListener) ListenPacket(ctx context.Context) (net.PacketConn, error) } } - return conn, nil + return &timedPacketConn{PacketConn: conn, defaultTimeout: ln.timeout}, nil } diff --git a/service/udp_other.go b/service/udp_other.go index 046cf910..0cfde76a 100644 --- a/service/udp_other.go +++ b/service/udp_other.go @@ -17,14 +17,34 @@ package service import ( + "context" + "net" + "time" + "github.com/Jigsaw-Code/outline-sdk/transport" ) +type udpListener struct { + *transport.UDPListener + + // NAT mapping timeout is the default time a mapping will stay active + // without packets traversing the NAT, applied to non-DNS packets. + timeout time.Duration +} + // fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management. // Value of 0 disables fwmark (SO_MARK) -func MakeTargetUDPListener(fwmark uint) transport.PacketListener { +func MakeTargetUDPListener(timeout time.Duration, fwmark uint) transport.PacketListener { if fwmark != 0 { panic("fwmark is linux-specific feature and should be 0") } - return &transport.UDPListener{Address: ""} + return &udpListener{UDPListener: &transport.UDPListener{Address: ""}} +} + +func (ln *udpListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { + conn, err := ln.UDPListener.ListenPacket(ctx) + if err != nil { + return nil, err + } + return &timedPacketConn{PacketConn: conn, defaultTimeout: ln.timeout}, nil } diff --git a/service/udp_test.go b/service/udp_test.go index 71aef184..964012fb 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -16,9 +16,11 @@ package service import ( "bytes" - "errors" + "context" + "io" "net" "net/netip" + "sync" "testing" "time" @@ -36,7 +38,7 @@ const timeout = 5 * time.Minute var clientAddr = net.UDPAddr{IP: []byte{192, 0, 2, 1}, Port: 12345} var targetAddr = net.UDPAddr{IP: []byte{192, 0, 2, 2}, Port: 54321} - +var localAddr = net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 9} var dnsAddr = net.UDPAddr{IP: []byte{192, 0, 2, 3}, Port: 53} var natCryptoKey *shadowsocks.EncryptionKey @@ -46,34 +48,61 @@ func init() { natCryptoKey, _ = shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, "test password") } -type packet struct { +type fakePacket struct { addr net.Addr payload []byte err error } +type packetListener struct { + conn net.PacketConn +} + +func (ln *packetListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { + return ln.conn, nil +} + type fakePacketConn struct { net.PacketConn - send chan packet - recv chan packet + send chan fakePacket + recv chan fakePacket deadline time.Time + mu sync.Mutex } func makePacketConn() *fakePacketConn { return &fakePacketConn{ - send: make(chan packet, 1), - recv: make(chan packet), + send: make(chan fakePacket, 1), + recv: make(chan fakePacket), } } +func (conn *fakePacketConn) getReadDeadline() time.Time { + conn.mu.Lock() + defer conn.mu.Unlock() + return conn.deadline +} + func (conn *fakePacketConn) SetReadDeadline(deadline time.Time) error { + conn.mu.Lock() + defer conn.mu.Unlock() conn.deadline = deadline return nil } func (conn *fakePacketConn) WriteTo(payload []byte, addr net.Addr) (int, error) { - conn.send <- packet{addr, payload, nil} - return len(payload), nil + conn.mu.Lock() + defer conn.mu.Unlock() + + var err error + defer func() { + if recover() != nil { + err = net.ErrClosed + } + }() + + conn.send <- fakePacket{addr, payload, nil} + return len(payload), err } func (conn *fakePacketConn) ReadFrom(buffer []byte) (int, net.Addr, error) { @@ -83,119 +112,167 @@ func (conn *fakePacketConn) ReadFrom(buffer []byte) (int, net.Addr, error) { } n := copy(buffer, pkt.payload) if n < len(pkt.payload) { - return n, pkt.addr, errors.New("buffer was too short") + return n, pkt.addr, io.ErrShortBuffer } return n, pkt.addr, pkt.err } func (conn *fakePacketConn) Close() error { + conn.mu.Lock() + defer conn.mu.Unlock() close(conn.send) close(conn.recv) return nil } +func (conn *fakePacketConn) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 9999} +} + +func (conn *fakePacketConn) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 8888} +} + type udpReport struct { - clientAddr net.Addr accessKey, status string clientProxyBytes, proxyTargetBytes int64 } // Stub metrics implementation for testing NAT behaviors. -type fakeUDPConnMetrics struct { - clientAddr net.Addr - accessKey string - upstreamPackets []udpReport +type natTestMetrics struct { + natEntriesAdded int } -var _ UDPConnMetrics = (*fakeUDPConnMetrics)(nil) +var _ NATMetrics = (*natTestMetrics)(nil) -func (m *fakeUDPConnMetrics) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) { - m.upstreamPackets = append(m.upstreamPackets, udpReport{m.clientAddr, m.accessKey, status, clientProxyBytes, proxyTargetBytes}) +func (m *natTestMetrics) AddNATEntry() { + m.natEntriesAdded++ +} +func (m *natTestMetrics) RemoveNATEntry() {} + +type fakeUDPAssociationMetrics struct { + accessKey string + upstreamPackets []udpReport + mu sync.Mutex } -func (m *fakeUDPConnMetrics) AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64) { +var _ UDPAssociationMetrics = (*fakeUDPAssociationMetrics)(nil) + +func (m *fakeUDPAssociationMetrics) AddAuthentication(key string) { + m.mu.Lock() + defer m.mu.Unlock() + m.accessKey = key } -func (m *fakeUDPConnMetrics) RemoveNatEntry() { +func (m *fakeUDPAssociationMetrics) AddPacketFromClient(status string, clientProxyBytes, proxyTargetBytes int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.upstreamPackets = append(m.upstreamPackets, udpReport{m.accessKey, status, clientProxyBytes, proxyTargetBytes}) } -type natTestMetrics struct { - connMetrics []fakeUDPConnMetrics +func (m *fakeUDPAssociationMetrics) AddPacketFromTarget(status string, targetProxyBytes, proxyClientBytes int64) { } -var _ UDPMetrics = (*natTestMetrics)(nil) +func (m *fakeUDPAssociationMetrics) AddClose() {} -func (m *natTestMetrics) AddUDPNatEntry(clientAddr net.Addr, accessKey string) UDPConnMetrics { - cm := fakeUDPConnMetrics{ - clientAddr: clientAddr, - accessKey: accessKey, +// sendSSPayload sends a single Shadowsocks packet to the provided connection. +// The packet is constructed with the given address, cipher, and payload. +func sendSSPayload(conn *fakePacketConn, addr net.Addr, cipher *shadowsocks.EncryptionKey, payload []byte) { + socksAddr := socks.ParseAddr(addr.String()) + plaintext := append(socksAddr, payload...) + ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize()) + shadowsocks.Pack(ciphertext, plaintext, cipher) + conn.recv <- fakePacket{ + addr: &clientAddr, + payload: ciphertext, } - m.connMetrics = append(m.connMetrics, cm) - return &m.connMetrics[len(m.connMetrics)-1] } -// Takes a validation policy, and returns the metrics it -// generates when localhost access is attempted -func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics { +// startTestHandler creates a new association handler with a fake +// client and target connection for testing purposes. It also starts a +// PacketServe goroutine to handle incoming packets on the client connection. +func startTestHandler() (AssociationHandler, func(target net.Addr, payload []byte), *fakePacketConn) { ciphers, _ := MakeTestCiphers([]string{"asdf"}) cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey + handler := NewAssociationHandler(ciphers, nil) clientConn := makePacketConn() - metrics := &natTestMetrics{} - handler := NewPacketHandler(timeout, ciphers, metrics, &fakeShadowsocksMetrics{}) - handler.SetTargetIPValidator(validator) - done := make(chan struct{}) + targetConn := makePacketConn() + handler.SetTargetPacketListener(&packetListener{targetConn}) + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) + return handler, func(target net.Addr, payload []byte) { + sendSSPayload(clientConn, target, cipher, payload) + }, targetConn +} + +func TestAssociationCloseWhileReading(t *testing.T) { + assoc := &association{ + pc: makePacketConn(), + clientAddr: &clientAddr, + readCh: make(chan *packet), + } go func() { - handler.Handle(clientConn) - done <- struct{}{} + buf := make([]byte, 1024) + assoc.Read(buf) }() - // Send one packet to the "discard" port on localhost - targetAddr := socks.ParseAddr("127.0.0.1:9") - for _, payload := range payloads { - plaintext := append(targetAddr, payload...) - ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize()) - shadowsocks.Pack(ciphertext, plaintext, cipher) - clientConn.recv <- packet{ - addr: &net.UDPAddr{ - IP: net.ParseIP("192.0.2.1"), - Port: 54321, - }, - payload: ciphertext, - } - } + err := assoc.Close() - clientConn.Close() - <-done - return metrics + assert.NoError(t, err, "Close should not panic or return an error") } -func TestIPFilter(t *testing.T) { - // Test both the first-packet and subsequent-packet cases. - payloads := [][]byte{[]byte("payload1"), []byte("payload2")} +func TestAssociationHandler_Handle_IPFilter(t *testing.T) { + t.Run("RequirePublicIP blocks localhost", func(t *testing.T) { + handler, sendPayload, targetConn := startTestHandler() + handler.SetTargetIPValidator(onet.RequirePublicIP) + + sendPayload(&localAddr, []byte{1, 2, 3}) - t.Run("Localhost allowed", func(t *testing.T) { - metrics := sendToDiscard(payloads, allowAll) - assert.Equal(t, len(metrics.connMetrics), 1, "Expected 1 NAT entry, not %d", len(metrics.connMetrics)) + select { + case <-targetConn.send: + t.Errorf("Expected no packets to be sent") + case <-time.After(100 * time.Millisecond): + return + } }) - t.Run("Localhost not allowed", func(t *testing.T) { - metrics := sendToDiscard(payloads, onet.RequirePublicIP) - assert.Equal(t, 0, len(metrics.connMetrics), "Unexpected NAT entry on rejected packet") + t.Run("allowAll allows localhost", func(t *testing.T) { + handler, sendPayload, targetConn := startTestHandler() + handler.SetTargetIPValidator(allowAll) + + sendPayload(&localAddr, []byte{1, 2, 3}) + + sent := <-targetConn.send + if !bytes.Equal([]byte{1, 2, 3}, sent.payload) { + t.Errorf("Expected %v, but got %v", []byte{1, 2, 3}, sent.payload) + } }) } func TestUpstreamMetrics(t *testing.T) { + ciphers, _ := MakeTestCiphers([]string{"asdf"}) + cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey + handler := NewAssociationHandler(ciphers, nil) + clientConn := makePacketConn() + targetConn := makePacketConn() + handler.SetTargetPacketListener(&packetListener{targetConn}) + metrics := &fakeUDPAssociationMetrics{} + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, metrics) + }, &natTestMetrics{}) + // Test both the first-packet and subsequent-packet cases. const N = 10 - payloads := make([][]byte, 0) for i := 1; i <= N; i++ { - payloads = append(payloads, make([]byte, i)) + sendSSPayload(clientConn, &targetAddr, cipher, make([]byte, i)) + <-targetConn.send } - metrics := sendToDiscard(payloads, allowAll) - - assert.Equal(t, N, len(metrics.connMetrics[0].upstreamPackets), "Expected %d reports, not %v", N, metrics.connMetrics[0].upstreamPackets) - for i, report := range metrics.connMetrics[0].upstreamPackets { + metrics.mu.Lock() + defer metrics.mu.Unlock() + assert.Equal(t, N, len(metrics.upstreamPackets), "Expected %d reports, not %d", N, len(metrics.upstreamPackets)) + for i, report := range metrics.upstreamPackets { assert.Equal(t, int64(i+1), report.proxyTargetBytes, "Expected %d payload bytes, not %d", i+1, report.proxyTargetBytes) assert.Greater(t, report.clientProxyBytes, report.proxyTargetBytes, "Expected nonzero input overhead (%d > %d)", report.clientProxyBytes, report.proxyTargetBytes) assert.Equal(t, "id-0", report.accessKey, "Unexpected access key name: %s", report.accessKey) @@ -211,196 +288,233 @@ func assertAlmostEqual(t *testing.T, a, b time.Time) { } } -func TestNATEmpty(t *testing.T) { - nat := newNATmap(timeout, &natTestMetrics{}, noopLogger()) - if nat.Get("foo") != nil { - t.Error("Expected nil value from empty NAT map") +func assertUDPAddrEqual(t *testing.T, a net.Addr, b *net.UDPAddr) { + addr, ok := a.(*net.UDPAddr) + if !ok || !addr.IP.Equal(b.IP) || addr.Port != b.Port || addr.Zone != b.Zone { + t.Errorf("Mismatched address: %v != %v", a, b) } } -func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) { - nat := newNATmap(timeout, &natTestMetrics{}, noopLogger()) - clientConn := makePacketConn() - targetConn := makePacketConn() - nat.Add(&clientAddr, clientConn, natCryptoKey, targetConn, "key id") - entry := nat.Get(clientAddr.String()) - return clientConn, targetConn, entry +// Implements net.Error +type fakeTimeoutError struct { + error } -func TestNATGet(t *testing.T) { - _, targetConn, entry := setupNAT() - if entry == nil { - t.Fatal("Failed to find target conn") - } - if entry.PacketConn != targetConn { - t.Error("Mismatched connection returned") - } +func (e *fakeTimeoutError) Timeout() bool { + return true } -func TestNATWrite(t *testing.T) { - _, targetConn, entry := setupNAT() - - // Simulate one generic packet being sent - buf := []byte{1} - entry.WriteTo([]byte{1}, &targetAddr) - assertAlmostEqual(t, targetConn.deadline, time.Now().Add(timeout)) - sent := <-targetConn.send - if !bytes.Equal(sent.payload, buf) { - t.Errorf("Mismatched payload: %v != %v", sent.payload, buf) - } - if sent.addr != &targetAddr { - t.Errorf("Mismatched address: %v != %v", sent.addr, &targetAddr) - } +func (e *fakeTimeoutError) Temporary() bool { + return false } -func TestNATWriteDNS(t *testing.T) { - _, targetConn, entry := setupNAT() +func TestTimedPacketConn(t *testing.T) { + t.Run("Write", func(t *testing.T) { + _, sendPayload, targetConn := startTestHandler() - // Simulate one DNS query being sent. - buf := []byte{1} - entry.WriteTo(buf, &dnsAddr) - // DNS-only connections have a fixed timeout of 17 seconds. - assertAlmostEqual(t, targetConn.deadline, time.Now().Add(17*time.Second)) - sent := <-targetConn.send - if !bytes.Equal(sent.payload, buf) { - t.Errorf("Mismatched payload: %v != %v", sent.payload, buf) - } - if sent.addr != &dnsAddr { - t.Errorf("Mismatched address: %v != %v", sent.addr, &targetAddr) - } -} + buf := []byte{1} + sendPayload(&targetAddr, buf) -func TestNATWriteDNSMultiple(t *testing.T) { - _, targetConn, entry := setupNAT() - - // Simulate three DNS queries being sent. - buf := []byte{1} - entry.WriteTo(buf, &dnsAddr) - <-targetConn.send - entry.WriteTo(buf, &dnsAddr) - <-targetConn.send - entry.WriteTo(buf, &dnsAddr) - <-targetConn.send - // DNS-only connections have a fixed timeout of 17 seconds. - assertAlmostEqual(t, targetConn.deadline, time.Now().Add(17*time.Second)) -} - -func TestNATWriteMixed(t *testing.T) { - _, targetConn, entry := setupNAT() - - // Simulate both non-DNS and DNS packets being sent. - buf := []byte{1} - entry.WriteTo(buf, &targetAddr) - <-targetConn.send - entry.WriteTo(buf, &dnsAddr) - <-targetConn.send - // Mixed DNS and non-DNS connections should have the user-specified timeout. - assertAlmostEqual(t, targetConn.deadline, time.Now().Add(timeout)) -} - -func TestNATFastClose(t *testing.T) { - clientConn, targetConn, entry := setupNAT() - - // Send one DNS query. - query := []byte{1} - entry.WriteTo(query, &dnsAddr) - sent := <-targetConn.send - require.Len(t, sent.payload, 1) - // Send the response. - response := []byte{1, 2, 3, 4, 5} - received := packet{addr: &dnsAddr, payload: response} - targetConn.recv <- received - sent, ok := <-clientConn.send - if !ok { - t.Error("clientConn was closed") - } - if len(sent.payload) <= len(response) { - t.Error("Packet is too short to be shadowsocks-AEAD") - } - if sent.addr != &clientAddr { - t.Errorf("Address mismatch: %v != %v", sent.addr, clientAddr) - } + assertAlmostEqual(t, targetConn.getReadDeadline(), time.Now().Add(timeout)) + sent := <-targetConn.send + if !bytes.Equal(sent.payload, buf) { + t.Errorf("Mismatched payload: %v != %v", sent.payload, buf) + } + assertUDPAddrEqual(t, sent.addr, &targetAddr) + }) - // targetConn should be scheduled to close immediately. - assertAlmostEqual(t, targetConn.deadline, time.Now()) -} + t.Run("WriteDNS", func(t *testing.T) { + _, sendPayload, targetConn := startTestHandler() -func TestNATNoFastClose_NotDNS(t *testing.T) { - clientConn, targetConn, entry := setupNAT() + // Simulate one DNS query being sent. + buf := []byte{1} + sendPayload(&dnsAddr, buf) - // Send one non-DNS packet. - query := []byte{1} - entry.WriteTo(query, &targetAddr) - sent := <-targetConn.send - require.Len(t, sent.payload, 1) - // Send the response. - response := []byte{1, 2, 3, 4, 5} - received := packet{addr: &targetAddr, payload: response} - targetConn.recv <- received - sent, ok := <-clientConn.send - if !ok { - t.Error("clientConn was closed") - } - if len(sent.payload) <= len(response) { - t.Error("Packet is too short to be shadowsocks-AEAD") - } - if sent.addr != &clientAddr { - t.Errorf("Address mismatch: %v != %v", sent.addr, clientAddr) - } - // targetConn should be scheduled to close after the full timeout. - assertAlmostEqual(t, targetConn.deadline, time.Now().Add(timeout)) -} + // DNS-only connections have a fixed timeout of 17 seconds. + assertAlmostEqual(t, targetConn.getReadDeadline(), time.Now().Add(17*time.Second)) + sent := <-targetConn.send + if !bytes.Equal(sent.payload, buf) { + t.Errorf("Mismatched payload: %v != %v", sent.payload, buf) + } + assertUDPAddrEqual(t, sent.addr, &dnsAddr) + }) -func TestNATNoFastClose_MultipleDNS(t *testing.T) { - clientConn, targetConn, entry := setupNAT() + t.Run("WriteDNSMultiple", func(t *testing.T) { + _, sendPayload, targetConn := startTestHandler() - // Send two DNS packets. - query1 := []byte{1} - entry.WriteTo(query1, &dnsAddr) - <-targetConn.send - query2 := []byte{2} - entry.WriteTo(query2, &dnsAddr) - <-targetConn.send + // Simulate three DNS queries being sent. + buf := []byte{1} + sendPayload(&dnsAddr, buf) + <-targetConn.send + sendPayload(&dnsAddr, buf) + <-targetConn.send + sendPayload(&dnsAddr, buf) + <-targetConn.send - // Send a response. - response := []byte{1, 2, 3, 4, 5} - received := packet{addr: &dnsAddr, payload: response} - targetConn.recv <- received - <-clientConn.send + // DNS-only connections have a fixed timeout of 17 seconds. + assertAlmostEqual(t, targetConn.getReadDeadline(), time.Now().Add(17*time.Second)) + }) - // targetConn should be scheduled to close after the DNS timeout. - assertAlmostEqual(t, targetConn.deadline, time.Now().Add(17*time.Second)) -} + t.Run("WriteMixed", func(t *testing.T) { + _, sendPayload, targetConn := startTestHandler() -// Implements net.Error -type fakeTimeoutError struct { - error -} + // Simulate both non-DNS and DNS packets being sent. + buf := []byte{1} + sendPayload(&targetAddr, buf) + <-targetConn.send + sendPayload(&dnsAddr, buf) + <-targetConn.send -func (e *fakeTimeoutError) Timeout() bool { - return true -} + // Mixed DNS and non-DNS connections should have the user-specified timeout. + assertAlmostEqual(t, targetConn.getReadDeadline(), time.Now().Add(timeout)) + }) -func (e *fakeTimeoutError) Temporary() bool { - return false + t.Run("FastClose", func(t *testing.T) { + ciphers, _ := MakeTestCiphers([]string{"asdf"}) + cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey + handler := NewAssociationHandler(ciphers, nil) + clientConn := makePacketConn() + targetConn := makePacketConn() + handler.SetTargetPacketListener(&packetListener{targetConn}) + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) + + // Send one DNS query. + sendSSPayload(clientConn, &dnsAddr, cipher, []byte{1}) + sent := <-targetConn.send + require.Len(t, sent.payload, 1) + // Send the response. + response := []byte{1, 2, 3, 4, 5} + received := fakePacket{addr: &dnsAddr, payload: response} + targetConn.recv <- received + sent, ok := <-clientConn.send + if !ok { + t.Error("clientConn was closed") + } + + // targetConn should be scheduled to close immediately. + assertAlmostEqual(t, targetConn.getReadDeadline(), time.Now()) + }) + + t.Run("NoFastClose_NotDNS", func(t *testing.T) { + ciphers, _ := MakeTestCiphers([]string{"asdf"}) + cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey + handler := NewAssociationHandler(ciphers, nil) + clientConn := makePacketConn() + targetConn := makePacketConn() + handler.SetTargetPacketListener(&packetListener{targetConn}) + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) + + // Send one non-DNS packet. + sendSSPayload(clientConn, &targetAddr, cipher, []byte{1}) + sent := <-targetConn.send + require.Len(t, sent.payload, 1) + // Send the response. + response := []byte{1, 2, 3, 4, 5} + received := fakePacket{addr: &targetAddr, payload: response} + targetConn.recv <- received + sent, ok := <-clientConn.send + if !ok { + t.Error("clientConn was closed") + } + + // targetConn should be scheduled to close after the full timeout. + assertAlmostEqual(t, targetConn.getReadDeadline(), time.Now().Add(timeout)) + }) + + t.Run("NoFastClose_MultipleDNS", func(t *testing.T) { + ciphers, _ := MakeTestCiphers([]string{"asdf"}) + cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey + handler := NewAssociationHandler(ciphers, nil) + clientConn := makePacketConn() + targetConn := makePacketConn() + handler.SetTargetPacketListener(&packetListener{targetConn}) + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) + + // Send two DNS packets. + sendSSPayload(clientConn, &dnsAddr, cipher, []byte{1}) + <-targetConn.send + sendSSPayload(clientConn, &dnsAddr, cipher, []byte{2}) + <-targetConn.send + + // Send a response. + response := []byte{1, 2, 3, 4, 5} + received := fakePacket{addr: &dnsAddr, payload: response} + targetConn.recv <- received + <-clientConn.send + + // targetConn should be scheduled to close after the DNS timeout. + assertAlmostEqual(t, targetConn.getReadDeadline(), time.Now().Add(17*time.Second)) + }) + + t.Run("Timeout", func(t *testing.T) { + _, sendPayload, targetConn := startTestHandler() + + // Simulate a non-DNS initial packet. + sendPayload(&targetAddr, []byte{1}) + <-targetConn.send + // Simulate a read timeout. + received := fakePacket{err: &fakeTimeoutError{}} + before := time.Now() + targetConn.recv <- received + // Wait for targetConn to close. + if _, ok := <-targetConn.send; ok { + t.Error("targetConn should be closed due to read timeout") + } + + // targetConn should be closed as soon as the timeout error is received. + assertAlmostEqual(t, before, time.Now()) + }) } -func TestNATTimeout(t *testing.T) { - _, targetConn, entry := setupNAT() +func TestNATMap(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + nm := newNATmap() + if nm.Get("foo") != nil { + t.Error("Expected nil value from empty NAT map") + } + }) - // Simulate a non-DNS initial packet. - entry.WriteTo([]byte{1}, &targetAddr) - <-targetConn.send - // Simulate a read timeout. - received := packet{err: &fakeTimeoutError{}} - before := time.Now() - targetConn.recv <- received - // Wait for targetConn to close. - if _, ok := <-targetConn.send; ok { - t.Error("targetConn should be closed due to read timeout") - } - // targetConn should be closed as soon as the timeout error is received. - assertAlmostEqual(t, before, time.Now()) + t.Run("Add", func(t *testing.T) { + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + assoc1 := &association{} + + nm.Add(addr.String(), assoc1) + assert.Equal(t, assoc1, nm.Get(addr.String()), "Get should return the correct connection") + + assoc2 := &association{} + nm.Add(addr.String(), assoc2) + assert.Equal(t, assoc2, nm.Get(addr.String()), "Adding with the same address should overwrite the entry") + }) + + t.Run("Get", func(t *testing.T) { + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + assoc := &association{} + nm.Add(addr.String(), assoc) + + assert.Equal(t, assoc, nm.Get(addr.String()), "Get should return the correct connection for an existing address") + + addr2 := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 5678} + assert.Nil(t, nm.Get(addr2.String()), "Get should return nil for a non-existent address") + }) + + t.Run("Del", func(t *testing.T) { + nm := newNATmap() + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234} + assoc := &association{} + nm.Add(addr.String(), assoc) + + nm.Del(addr.String()) + + assert.Nil(t, nm.Get(addr.String()), "Get should return nil after deleting the entry") + }) } // Simulates receiving invalid UDP packets on a server with 100 ciphers. @@ -485,9 +599,9 @@ func TestUDPEarlyClose(t *testing.T) { if err != nil { t.Fatal(err) } - testMetrics := &natTestMetrics{} const testTimeout = 200 * time.Millisecond - s := NewPacketHandler(testTimeout, cipherList, testMetrics, &fakeShadowsocksMetrics{}) + handler := NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{}) + handler.SetTargetPacketListener(&packetListener{makePacketConn()}) clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { @@ -495,7 +609,9 @@ func TestUDPEarlyClose(t *testing.T) { } require.Nil(t, clientConn.Close()) // This should return quickly without timing out. - s.Handle(clientConn) + go PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + handler.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{}) + }, &natTestMetrics{}) } // Makes sure the UDP listener returns [io.ErrClosed] on reads and writes after Close().