From 0b579cc862d0d21496e5f64d15f852fa7210fc39 Mon Sep 17 00:00:00 2001 From: uoosef Date: Sun, 10 Sep 2023 21:18:16 +0330 Subject: [PATCH] local dns resolve + udp fix + socks4/a rewrite --- doh/doh.go | 2 +- server/handle.go | 36 ++++++++++++++---------- server/server.go | 10 +++++-- socks5/option.go | 7 +++++ socks5/server.go | 73 ++++++++++++++++++++++-------------------------- transport/ws.go | 6 ++-- 6 files changed, 75 insertions(+), 59 deletions(-) diff --git a/doh/doh.go b/doh/doh.go index d1fc5be..e00d7d5 100644 --- a/doh/doh.go +++ b/doh/doh.go @@ -103,7 +103,7 @@ func (c *Client) Exchange(req *dns.Msg, address string) (r *dns.Msg, rtt time.Du base64.RawURLEncoding.Encode(b64, buf) if config.G.WorkerEnabled { - address = "https://8.8.8.8/dns-query" + address = "https://8.8.4.4/dns-query" } content, err := c.HTTPClient(address + "?dns=" + string(b64)) diff --git a/server/handle.go b/server/handle.go index dbae7a3..3799ad7 100644 --- a/server/handle.go +++ b/server/handle.go @@ -68,12 +68,14 @@ func (s *Server) extractHostnameOrChangeHTTPHostHeader(data []byte) ( return []byte(hello.ServerName), data, false, nil } -func (s *Server) processFirstPacket(ctx context.Context, w io.Writer, req *socks5.Request) ( +func (s *Server) processFirstPacket(ctx context.Context, w io.Writer, req *socks5.Request, successReply bool) ( *socks5.Request, string, bool, error, ) { - if err := socks5.SendReply(w, statute.RepSuccess, nil); err != nil { - logger.Errorf("failed to send reply: %v", err) - return nil, "", false, err + if successReply { + if err := socks5.SendReply(w, statute.RepSuccess, nil); err != nil { + logger.Errorf("failed to send reply: %v", err) + return nil, "", false, err + } } firstPacket := make([]byte, 32*1024) @@ -88,23 +90,26 @@ func (s *Server) processFirstPacket(ctx context.Context, w io.Writer, req *socks logger.Infof("Hostname %s", string(hostname)) } - IPPort, err := s.resolveDestination(ctx, req) + dest, err := s.resolveDestination(ctx, req) if err != nil { return nil, "", false, err } + IPPort := net.JoinHostPort(dest.IP.String(), strconv.Itoa(dest.Port)) + // if user has a faulty dns, and it returns dpi ip, // we resolve destination based on extracted tls sni or http hostname if hostname != nil && strings.Contains(IPPort, "10.10.3") { logger.Infof("%s is dpi ip extracting destination host from packets...", IPPort) req.RawDestAddr.FQDN = string(hostname) - IPPort, err = s.resolveDestination(ctx, req) + dest, err = s.resolveDestination(ctx, req) if err != nil { // if destination resolved to dpi and we cant resolve to actual destination // it's pointless to connect to dpi logger.Infof("system was unable to extract destination host from packets!") return nil, "", false, err } + IPPort = net.JoinHostPort(dest.IP.String(), strconv.Itoa(dest.Port)) } req.Reader = &utils.BufferedReader{ @@ -116,11 +121,15 @@ func (s *Server) processFirstPacket(ctx context.Context, w io.Writer, req *socks return req, IPPort, isHTTP, nil } -func (s *Server) HandleTCPTunnel(ctx context.Context, w io.Writer, req *socks5.Request) error { - r, _, _, err := s.processFirstPacket(ctx, w, req) +func (s *Server) HandleTCPTunnel(ctx context.Context, w io.Writer, req *socks5.Request, successReply bool) error { + r, _, _, err := s.processFirstPacket(ctx, w, req, successReply) if err != nil { return err } + dest, err := s.resolveDestination(ctx, req) + if err == nil { + req.RawDestAddr = dest + } return s.Transport.TunnelTCP(w, r) } @@ -129,8 +138,8 @@ func (s *Server) HandleUDPTunnel(_ context.Context, w io.Writer, req *socks5.Req } // HandleTCPFragment handles the SOCKS5 request and forwards traffic to the destination. -func (s *Server) HandleTCPFragment(ctx context.Context, w io.Writer, req *socks5.Request) error { - r, IPPort, isHTTP, err := s.processFirstPacket(ctx, w, req) +func (s *Server) HandleTCPFragment(ctx context.Context, w io.Writer, req *socks5.Request, successReply bool) error { + r, IPPort, isHTTP, err := s.processFirstPacket(ctx, w, req, successReply) if err != nil { return err } @@ -174,13 +183,13 @@ func (s *Server) Copy(reader io.Reader, writer io.Writer) error { return err } -func (s *Server) resolveDestination(_ context.Context, req *socks5.Request) (string, error) { +func (s *Server) resolveDestination(_ context.Context, req *socks5.Request) (*statute.AddrSpec, error) { dest := req.RawDestAddr if dest.FQDN != "" { ip, err := s.Resolve(dest.FQDN) if err != nil { - return "", err + return nil, err } dest.IP = net.ParseIP(ip) logger.Infof("resolved %s to %s", req.RawDestAddr, dest) @@ -188,8 +197,7 @@ func (s *Server) resolveDestination(_ context.Context, req *socks5.Request) (str logger.Infof("skipping resolution for %s", req.RawDestAddr) } - addr := net.JoinHostPort(dest.IP.String(), strconv.Itoa(dest.Port)) - return addr, nil + return dest, nil } // Resolve resolves the FQDN to an IP address using the specified resolution mechanism. diff --git a/server/server.go b/server/server.go index 3a31738..f03d930 100644 --- a/server/server.go +++ b/server/server.go @@ -108,7 +108,10 @@ func Run(captureCTRLC bool) error { if workerConfig.WorkerEnabled && !workerConfig.WorkerDNSOnly { s5 = socks5.NewServer( socks5.WithConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error { - return serverHandler.HandleTCPTunnel(ctx, w, req) + return serverHandler.HandleTCPTunnel(ctx, w, req, true) + }), + socks5.WithSocks4ConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error { + return serverHandler.HandleTCPTunnel(ctx, w, req, false) }), socks5.WithAssociateHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error { return serverHandler.HandleUDPTunnel(ctx, w, req) @@ -117,7 +120,10 @@ func Run(captureCTRLC bool) error { } else { s5 = socks5.NewServer( socks5.WithConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error { - return serverHandler.HandleTCPFragment(ctx, w, req) + return serverHandler.HandleTCPFragment(ctx, w, req, true) + }), + socks5.WithSocks4ConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error { + return serverHandler.HandleTCPFragment(ctx, w, req, false) }), ) } diff --git a/socks5/option.go b/socks5/option.go index b6ad548..8197fb4 100644 --- a/socks5/option.go +++ b/socks5/option.go @@ -94,6 +94,13 @@ func WithConnectHandle(h func(ctx context.Context, writer io.Writer, request *Re } } +// WithSocks4ConnectHandle is used to handle a user's connect command. +func WithSocks4ConnectHandle(h func(ctx context.Context, writer io.Writer, request *Request) error) Option { + return func(s *Server) { + s.userSocks4ConnectHandle = h + } +} + // WithBindHandle is used to handle a user's bind command. func WithBindHandle(h func(ctx context.Context, writer io.Writer, request *Request) error) Option { return func(s *Server) { diff --git a/socks5/server.go b/socks5/server.go index 73da411..9d4cfcb 100644 --- a/socks5/server.go +++ b/socks5/server.go @@ -5,23 +5,19 @@ package socks5 import ( "bepass/bufferpool" + "bepass/logger" + "bepass/socks5/statute" "bufio" "bytes" "context" "encoding/binary" "errors" "fmt" + "github.com/elazarl/goproxy" + "golang.org/x/net/proxy" "io" "net" "net/http" - "strconv" - - "golang.org/x/net/proxy" - - "bepass/logger" - "bepass/socks5/statute" - - "github.com/elazarl/goproxy" ) // GPool is used to implement custom goroutine pool default use goroutine @@ -59,13 +55,14 @@ type Server struct { // goroutine pool gPool GPool // user's handle - userConnectHandle func(ctx context.Context, writer io.Writer, request *Request) error - userBindHandle func(ctx context.Context, writer io.Writer, request *Request) error - userAssociateHandle func(ctx context.Context, writer io.Writer, request *Request) error - done chan bool - listen net.Listener - httpProxyBindAddr string - bindAddress string + userSocks4ConnectHandle func(ctx context.Context, writer io.Writer, request *Request) error + userConnectHandle func(ctx context.Context, writer io.Writer, request *Request) error + userBindHandle func(ctx context.Context, writer io.Writer, request *Request) error + userAssociateHandle func(ctx context.Context, writer io.Writer, request *Request) error + done chan bool + listen net.Listener + httpProxyBindAddr string + bindAddress string } // NewServer creates a new Server @@ -103,7 +100,8 @@ func (sf *Server) ListenAndServe(network, addr string) error { sf.bindAddress = addr // Create a custom dialer with DialContext - dialer, err := proxy.SOCKS5(network, sf.bindAddress, nil, proxy.Direct) + dialer, err := proxy.SOCKS5(network, sf.bindAddress, nil, nil) + if err != nil { return err } @@ -286,7 +284,6 @@ func readAsString(r io.Reader) (string, error) { func (sf *Server) handleSocks4Request(conn net.Conn, bufConn *bufio.Reader) error { var cddstportdstip [1 + 1 + 2 + 4]byte - var destination = "" var dstHost = "" if _, err := io.ReadFull(bufConn, cddstportdstip[:]); err != nil { return err @@ -297,7 +294,6 @@ func (sf *Server) handleSocks4Request(conn net.Conn, bufConn *bufio.Reader) erro if command != uint8(1) { return fmt.Errorf("command %d is not supported", command) } - destination = net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))) // Skip USERID if _, err := readAsString(bufConn); err != nil { return err @@ -311,37 +307,36 @@ func (sf *Server) handleSocks4Request(conn net.Conn, bufConn *bufio.Reader) erro } } + atype := statute.ATYPIPv4 + if dstHost != "" { - destination = net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))) + atype = statute.ATYPDomain } if _, err := conn.Write([]byte{0, 90, 0, 0, 0, 0, 0, 0}); err != nil { return err } - d, err := proxy.SOCKS5("tcp", sf.bindAddress, nil, proxy.Direct) - if err != nil { - return err + request := &Request{ + Request: statute.Request{}, + AuthContext: nil, + LocalAddr: conn.LocalAddr(), + RemoteAddr: conn.RemoteAddr(), + DestAddr: nil, + Reader: bufConn, + RawDestAddr: &statute.AddrSpec{ + FQDN: dstHost, + IP: dstIP, + Port: int(dstPort), + AddrType: atype, + }, } - dstConn, err := d.Dial("tcp", destination) - if err != nil { - return err - } - var errCh = make(chan error, 2) - go func() { - _, err := io.Copy(dstConn, bufConn) - errCh <- err - }() - go func() { - _, err := io.Copy(conn, dstConn) - errCh <- err - }() - err = <-errCh - if err != nil { - return err + if sf.userSocks4ConnectHandle != nil { + return sf.userSocks4ConnectHandle(context.Background(), io.Writer(conn), request) } - return <-errCh + logger.Errorf("socks4/a without user defined handler is unsupported") + return errors.New("unsupported") } // authenticate is used to handle connection authentication diff --git a/transport/ws.go b/transport/ws.go index cccb91f..0b10ef8 100644 --- a/transport/ws.go +++ b/transport/ws.go @@ -75,7 +75,7 @@ func (w *WSTunnel) PersistentDial(tunnelEndpoint string, bindWriteChannel chan U if time.Now().Unix()-lastActivityStamp > w.LinkIdleTimeout { return } - for { + for limit := 0; limit < 10; limit++ { done := make(chan struct{}) doneR := make(chan struct{}) @@ -149,12 +149,12 @@ func (w *WSTunnel) PersistentDial(tunnelEndpoint string, bindWriteChannel chan U if err != nil { if strings.Contains(err.Error(), "websocket: close") || - strings.Contains(err.Error(), "i/o") { + strings.Contains(err.Error(), "limit/o") { logger.Errorf("reading from udp over tcp error: %v\r\n", err) return } logger.Errorf("reading from udp over TCP tunnel packet size error: %v\r\n", err) - continue + return } // The first 2 packets of response are channel ID