diff --git a/service/udp.go b/service/udp.go index 18d96a8b..af0beb19 100644 --- a/service/udp.go +++ b/service/udp.go @@ -108,7 +108,7 @@ func NewPacketHandler(cipherList CipherList, ssMetrics ShadowsocksConnMetrics) P // PacketHandler is a handler that handles UDP assocations. type PacketHandler interface { - Handle(pkt []byte, lazySlice slicepool.LazySlice, assoc PacketAssociation) + Handle(pkt []byte, assoc PacketAssociation, lazySlice slicepool.LazySlice) *packetMetrics // SetLogger sets the logger used to log messages. Uses a no-op logger if nil. SetLogger(l *slog.Logger) // SetTargetIPValidator sets the function to be used to validate the target IP addresses. @@ -126,7 +126,7 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali h.targetIPValidator = targetIPValidator } -func (h *packetHandler) Handle(pkt []byte, lazySlice slicepool.LazySlice, assoc PacketAssociation) { +func (h *packetHandler) Handle(pkt []byte, assoc PacketAssociation, lazySlice slicepool.LazySlice) *packetMetrics { l := h.logger.With(slog.Any("association", assoc)) defer lazySlice.Release() defer debugUDP(l, "Done") @@ -138,7 +138,7 @@ func (h *packetHandler) Handle(pkt []byte, lazySlice slicepool.LazySlice, assoc var textData []byte cryptoKey, err := assoc.Authenticate(func() (keyID string, cryptoKey *shadowsocks.EncryptionKey, keyErr error) { - ip := assoc.RemoteAddr().AddrPort().Addr() + ip := assoc.ClientAddr().AddrPort().Addr() textLazySlice := readBufPool.LazySlice() textBuf := textLazySlice.Acquire() unpackStart := time.Now() @@ -147,7 +147,7 @@ func (h *packetHandler) Handle(pkt []byte, lazySlice slicepool.LazySlice, assoc textLazySlice.Release() h.ssm.AddCipherSearch(keyErr == nil, timeToCipher) return keyID, cryptoKey, keyErr - }) + }, h.handleTargetPacket) if err != nil { return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err) @@ -172,7 +172,7 @@ func (h *packetHandler) Handle(pkt []byte, lazySlice slicepool.LazySlice, assoc } debugUDP(l, "Proxy exit.") - proxyTargetBytes, err = assoc.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature + proxyTargetBytes, err = assoc.WriteToTarget(payload, tgtUDPAddr) // accept only UDPAddr despite the signature if err != nil { return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) } @@ -184,7 +184,11 @@ func (h *packetHandler) Handle(pkt []byte, lazySlice slicepool.LazySlice, assoc debugUDP(l, "Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) status = connError.Status } - assoc.Metrics().AddPacketFromClient(status, int64(len(pkt)), int64(proxyTargetBytes)) + return &packetMetrics{ + status: status, + bytesIn: int64(len(pkt)), + bytesOut: int64(proxyTargetBytes), + } } // Given the decrypted contents of a UDP packet, return @@ -208,13 +212,92 @@ func (h *packetHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, * return payload, tgtUDPAddr, nil } +// Get the maximum length of the shadowsocks address header by parsing +// and serializing an IPv6 address from the example range. +var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) + +func (h *packetHandler) handleTargetPacket(pkt []byte, cryptoKey *shadowsocks.EncryptionKey, assoc PacketAssociation) *packetMetrics { + l := h.logger.With(slog.Any("association", assoc)) + + expired := false + + saltSize := cryptoKey.SaltSize() + // Leave enough room at the beginning of the packet for a max-length header (i.e. IPv6). + bodyStart := saltSize + maxAddrLen + + var bodyLen, proxyClientBytes int + connError := func() *onet.ConnectionError { + var ( + raddr net.Addr + err error + ) + // `readBuf` receives the plaintext body in `pkt`: + // [padding?][salt][address][body][tag][unused] + // |-- bodyStart --|[ readBuf ] + readBuf := pkt[bodyStart:] + bodyLen, raddr, err := assoc.ReadFromTarget(readBuf) + if err != nil { + if netErr, ok := err.(net.Error); ok { + if netErr.Timeout() { + expired = true + } + } + return onet.NewConnectionError("ERR_READ", "Failed to read from target", err) + } + + debugUDP(l, "Got response.", slog.Any("rtarget", raddr)) + srcAddr := socks.ParseAddr(raddr.String()) + addrStart := bodyStart - len(srcAddr) + // `plainTextBuf` concatenates the SOCKS address and body: + // [padding?][salt][address][body][tag][unused] + // |-- addrStart -|[plaintextBuf ] + plaintextBuf := pkt[addrStart : bodyStart+bodyLen] + copy(plaintextBuf, srcAddr) + + // saltStart is 0 if raddr is IPv6. + saltStart := addrStart - saltSize + // `packBuf` adds space for the salt and tag. + // `buf` shows the space that was used. + // [padding?][salt][address][body][tag][unused] + // [ packBuf ] + // [ buf ] + packBuf := pkt[saltStart:] + buf, err := shadowsocks.Pack(packBuf, plaintextBuf, cryptoKey) // Encrypt in-place + if err != nil { + return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) + } + proxyClientBytes, err = assoc.WriteToClient(buf) + if err != nil { + return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err) + } + return nil + }() + status := "OK" + if connError != nil { + debugUDP(l, "Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) + status = connError.Status + } + return &packetMetrics{ + status: status, + expired: expired, + bytesIn: int64(bodyLen), + bytesOut: int64(proxyClientBytes), + } +} + +type packetMetrics struct { + status string + expired bool + bytesIn, bytesOut int64 +} + type NewAssociationFunc func(conn net.Conn) (PacketAssociation, error) // PacketServe listens for UDP packets on the provided [net.PacketConn], creates // and manages NAT associations, and invokes the `handle` function for each // packet. It uses a NAT map to track active associations and handles their // lifecycle. -func PacketServe(clientConn net.PacketConn, newAssociation NewAssociationFunc, handle PacketHandleFunc, metrics NATMetrics) { +func PacketServe(clientConn net.PacketConn, newAssociation NewAssociationFunc, handle PacketHandleFuncWithLazySlice, metrics NATMetrics) { nm := newNATmap() defer nm.Close() @@ -260,7 +343,10 @@ func PacketServe(clientConn net.PacketConn, newAssociation NewAssociationFunc, h metrics.RemoveNATEntry() nm.Del(addr.String()) default: - go handle(pkt, lazySlice, assoc) + go func() { + metrics := handle(pkt, assoc, lazySlice) + assoc.Metrics().AddPacketFromClient(metrics.status, metrics.bytesIn, metrics.bytesOut) + }() } return false }() @@ -402,17 +488,20 @@ func (m *natmap) Close() error { } // PacketHandleFunc processes a single incoming packet. +type PacketHandleFunc func(pkt []byte, assoc PacketAssociation) *packetMetrics + +// PacketHandleFuncWithLazySlice processes a single incoming packet. // -// pkt contains the raw packet data. + // lazySlice is the LazySlice that holds the pkt buffer, which should be // released after the packet is processed. -type PacketHandleFunc func(pkt []byte, lazySlice slicepool.LazySlice, assoc PacketAssociation) +type PacketHandleFuncWithLazySlice func(pkt []byte, assoc PacketAssociation, lazySlice slicepool.LazySlice) *packetMetrics -func HandleAssociation(assoc PacketAssociation, handle PacketHandleFunc) { +func HandleAssociation(assoc PacketAssociation, handle PacketHandleFuncWithLazySlice) { for { lazySlice := readBufPool.LazySlice() buf := lazySlice.Acquire() - n, err := assoc.Read(buf) + n, err := assoc.ReadFromClient(buf) if errors.Is(err, net.ErrClosed) { lazySlice.Release() return @@ -423,24 +512,51 @@ func HandleAssociation(assoc PacketAssociation, handle PacketHandleFunc) { lazySlice.Release() return default: - go handle(pkt, lazySlice, assoc) + go func() { + metrics := handle(pkt, assoc, lazySlice) + assoc.Metrics().AddPacketFromClient(metrics.status, metrics.bytesIn, metrics.bytesOut) + }() + } + } +} + +// HandleAssociationTimedCopy handles the target-side of the association by +// copying from target to client until read timeout. +func HandleAssociationTimedCopy(assoc PacketAssociation, handle PacketHandleFunc) { + defer assoc.CloseTarget() + + // pkt is used for in-place encryption of downstream UDP packets. + // Padding is only used if the address is IPv4. + pkt := make([]byte, serverUDPBufferSize) + + for { + metrics := handle(pkt, assoc) + if metrics.expired { + break } + assoc.Metrics().AddPacketFromTarget(metrics.status, metrics.bytesIn, metrics.bytesOut) } } // PacketAssociation represents a UDP association. type PacketAssociation interface { - // Read reads data from the association. - Read(b []byte) (n int, err error) + // ReadFromClient reads data from the client side of the association. + ReadFromClient(b []byte) (n int, err error) - // WriteTo writes data to the association. - WriteTo(b []byte, addr net.Addr) (int, error) + // WriteToClient writes data to the client side of the association. + WriteToClient(b []byte) (n int, err error) + + // ReadFromTarget reads data from the target side of the association. + ReadFromTarget(p []byte) (n int, addr net.Addr, err error) + + // WriteToTarget writes data to the target side of the association. + WriteToTarget(b []byte, addr net.Addr) (int, error) // Authenticate authenticates the association. - Authenticate(authenticateFunc PacketAuthenticateFunc) (*shadowsocks.EncryptionKey, error) + Authenticate(authenticate PacketAuthenticateFunc, handleTarget func(pkt []byte, cryptoKey *shadowsocks.EncryptionKey, assoc PacketAssociation) *packetMetrics) (*shadowsocks.EncryptionKey, error) - // RemoteAddr returns the remote network address of the association, if known. - RemoteAddr() *net.UDPAddr + // ClientAddr returns the remote network address of the client connection, if known. + ClientAddr() *net.UDPAddr // Done returns a channel that is closed when the association is closed. Done() <-chan struct{} @@ -448,15 +564,16 @@ type PacketAssociation interface { // Close closes the association and releases any associated resources. Close() error + // Closes the target side of the association. + CloseTarget() error + // Returns association metrics. // TODO(sbruens): Refactor so this isn't needed. Metrics() UDPAssociationMetrics } -type PacketAuthenticateFunc func() (string, *shadowsocks.EncryptionKey, error) - type association struct { - net.Conn + clientConn net.Conn targetConn net.PacketConn authenticateOnce sync.Once cryptoKey *shadowsocks.EncryptionKey @@ -476,35 +593,49 @@ func NewPacketAssociation(conn net.Conn, listener transport.PacketListener, m UD } return &association{ - Conn: conn, + clientConn: conn, targetConn: targetConn, m: m, doneCh: make(chan struct{}), }, nil } -func (a *association) WriteTo(b []byte, addr net.Addr) (int, error) { +func (a *association) ReadFromClient(b []byte) (n int, err error) { + return a.clientConn.Read(b) +} + +func (a *association) WriteToClient(b []byte) (n int, err error) { + return a.clientConn.Write(b) +} + +func (a *association) ReadFromTarget(p []byte) (n int, addr net.Addr, err error) { + return a.targetConn.ReadFrom(p) +} + +func (a *association) WriteToTarget(b []byte, addr net.Addr) (int, error) { return a.targetConn.WriteTo(b, addr) } -func (a *association) RemoteAddr() *net.UDPAddr { - return a.Conn.RemoteAddr().(*net.UDPAddr) +func (a *association) ClientAddr() *net.UDPAddr { + return a.clientConn.RemoteAddr().(*net.UDPAddr) } -func (a *association) Authenticate(authenticateFunc PacketAuthenticateFunc) (*shadowsocks.EncryptionKey, error) { +type PacketAuthenticateFunc func() (string, *shadowsocks.EncryptionKey, error) + +func (a *association) Authenticate(authenticate PacketAuthenticateFunc, handleTarget func(pkt []byte, cryptoKey *shadowsocks.EncryptionKey, assoc PacketAssociation) *packetMetrics) (*shadowsocks.EncryptionKey, error) { var err error a.authenticateOnce.Do(func() { var keyID string - keyID, a.cryptoKey, err = authenticateFunc() + keyID, a.cryptoKey, err = authenticate() a.m.AddAuthentication(keyID) if err != nil { return } - // TODO(sbruens): Pass in a `handle` function to handle shadowsocks and move it to - // the packet handler. - go a.timedCopy() + go HandleAssociationTimedCopy(a, func(pkt []byte, assoc PacketAssociation) *packetMetrics { + return handleTarget(pkt, a.cryptoKey, assoc) + }) }) return a.cryptoKey, err } @@ -515,7 +646,17 @@ func (a *association) Done() <-chan struct{} { func (a *association) Close() error { now := time.Now() - return a.SetReadDeadline(now) + return a.clientConn.SetReadDeadline(now) +} + +func (a *association) CloseTarget() error { + a.m.AddClose() + err := a.targetConn.Close() + if err != nil { + return err + } + close(a.doneCh) + return nil } func (a *association) Metrics() UDPAssociationMetrics { @@ -524,95 +665,11 @@ func (a *association) Metrics() UDPAssociationMetrics { func (a *association) LogValue() slog.Value { return slog.GroupValue( - slog.Any("client", a.Conn.RemoteAddr()), + slog.Any("client", a.clientConn.RemoteAddr()), slog.Any("ltarget", a.targetConn.LocalAddr()), ) } -// Get the maximum length of the shadowsocks address header by parsing -// and serializing an IPv6 address from the example range. -var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) - -// copy from target to client until read timeout -func (a *association) timedCopy() { - defer func() { - a.m.AddClose() - a.targetConn.Close() - close(a.doneCh) - }() - - // pkt is used for in-place encryption of downstream UDP packets, with the layout - // [padding?][salt][address][body][tag][extra] - // Padding is only used if the address is IPv4. - pkt := make([]byte, serverUDPBufferSize) - - saltSize := a.cryptoKey.SaltSize() - // Leave enough room at the beginning of the packet for a max-length header (i.e. IPv6). - bodyStart := saltSize + maxAddrLen - - expired := false - for { - var bodyLen, proxyClientBytes int - connError := func() (connError *onet.ConnectionError) { - var ( - raddr net.Addr - err error - ) - // `readBuf` receives the plaintext body in `pkt`: - // [padding?][salt][address][body][tag][unused] - // |-- bodyStart --|[ readBuf ] - readBuf := pkt[bodyStart:] - bodyLen, raddr, err = a.targetConn.ReadFrom(readBuf) - if err != nil { - if netErr, ok := err.(net.Error); ok { - if netErr.Timeout() { - expired = true - return nil - } - } - return onet.NewConnectionError("ERR_READ", "Failed to read from target", err) - } - - // TODO(sbruens): Figure out the logger here and below. - // debugUDP(l, "Got response.", slog.Any("rtarget", raddr)) - srcAddr := socks.ParseAddr(raddr.String()) - addrStart := bodyStart - len(srcAddr) - // `plainTextBuf` concatenates the SOCKS address and body: - // [padding?][salt][address][body][tag][unused] - // |-- addrStart -|[plaintextBuf ] - plaintextBuf := pkt[addrStart : bodyStart+bodyLen] - copy(plaintextBuf, srcAddr) - - // saltStart is 0 if raddr is IPv6. - saltStart := addrStart - saltSize - // `packBuf` adds space for the salt and tag. - // `buf` shows the space that was used. - // [padding?][salt][address][body][tag][unused] - // [ packBuf ] - // [ buf ] - packBuf := pkt[saltStart:] - buf, err := shadowsocks.Pack(packBuf, plaintextBuf, a.cryptoKey) // Encrypt in-place - if err != nil { - return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) - } - proxyClientBytes, err = a.Write(buf) - if err != nil { - return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err) - } - return nil - }() - status := "OK" - if connError != nil { - //debugUDP(a.Logger(), "Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) - status = connError.Status - } - if expired { - break - } - a.m.AddPacketFromTarget(status, int64(bodyLen), int64(proxyClientBytes)) - } -} - // NoOpUDPAssociationMetrics is a [UDPAssociationMetrics] that doesn't do anything. Useful in tests // or if you don't want to track metrics. type NoOpUDPAssociationMetrics struct{}