Skip to content

Commit

Permalink
local dns resolve + udp fix + socks4/a rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
uoosef committed Sep 10, 2023
1 parent e15af09 commit 0b579cc
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 59 deletions.
2 changes: 1 addition & 1 deletion doh/doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (c *Client) Exchange(req *dns.Msg, address string) (r *dns.Msg, rtt time.Du
base64.RawURLEncoding.Encode(b64, buf)

if config.G.WorkerEnabled {
address = "https://8.8.8.8/dns-query"
address = "https://8.8.4.4/dns-query"
}

content, err := c.HTTPClient(address + "?dns=" + string(b64))
Expand Down
36 changes: 22 additions & 14 deletions server/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,14 @@ func (s *Server) extractHostnameOrChangeHTTPHostHeader(data []byte) (
return []byte(hello.ServerName), data, false, nil
}

func (s *Server) processFirstPacket(ctx context.Context, w io.Writer, req *socks5.Request) (
func (s *Server) processFirstPacket(ctx context.Context, w io.Writer, req *socks5.Request, successReply bool) (
*socks5.Request, string, bool, error,
) {
if err := socks5.SendReply(w, statute.RepSuccess, nil); err != nil {
logger.Errorf("failed to send reply: %v", err)
return nil, "", false, err
if successReply {
if err := socks5.SendReply(w, statute.RepSuccess, nil); err != nil {
logger.Errorf("failed to send reply: %v", err)
return nil, "", false, err
}
}

firstPacket := make([]byte, 32*1024)
Expand All @@ -88,23 +90,26 @@ func (s *Server) processFirstPacket(ctx context.Context, w io.Writer, req *socks
logger.Infof("Hostname %s", string(hostname))
}

IPPort, err := s.resolveDestination(ctx, req)
dest, err := s.resolveDestination(ctx, req)
if err != nil {
return nil, "", false, err
}

IPPort := net.JoinHostPort(dest.IP.String(), strconv.Itoa(dest.Port))

// if user has a faulty dns, and it returns dpi ip,
// we resolve destination based on extracted tls sni or http hostname
if hostname != nil && strings.Contains(IPPort, "10.10.3") {
logger.Infof("%s is dpi ip extracting destination host from packets...", IPPort)
req.RawDestAddr.FQDN = string(hostname)
IPPort, err = s.resolveDestination(ctx, req)
dest, err = s.resolveDestination(ctx, req)
if err != nil {
// if destination resolved to dpi and we cant resolve to actual destination
// it's pointless to connect to dpi
logger.Infof("system was unable to extract destination host from packets!")
return nil, "", false, err
}
IPPort = net.JoinHostPort(dest.IP.String(), strconv.Itoa(dest.Port))
}

req.Reader = &utils.BufferedReader{
Expand All @@ -116,11 +121,15 @@ func (s *Server) processFirstPacket(ctx context.Context, w io.Writer, req *socks
return req, IPPort, isHTTP, nil
}

func (s *Server) HandleTCPTunnel(ctx context.Context, w io.Writer, req *socks5.Request) error {
r, _, _, err := s.processFirstPacket(ctx, w, req)
func (s *Server) HandleTCPTunnel(ctx context.Context, w io.Writer, req *socks5.Request, successReply bool) error {
r, _, _, err := s.processFirstPacket(ctx, w, req, successReply)
if err != nil {
return err
}
dest, err := s.resolveDestination(ctx, req)
if err == nil {
req.RawDestAddr = dest
}
return s.Transport.TunnelTCP(w, r)
}

Expand All @@ -129,8 +138,8 @@ func (s *Server) HandleUDPTunnel(_ context.Context, w io.Writer, req *socks5.Req
}

// HandleTCPFragment handles the SOCKS5 request and forwards traffic to the destination.
func (s *Server) HandleTCPFragment(ctx context.Context, w io.Writer, req *socks5.Request) error {
r, IPPort, isHTTP, err := s.processFirstPacket(ctx, w, req)
func (s *Server) HandleTCPFragment(ctx context.Context, w io.Writer, req *socks5.Request, successReply bool) error {
r, IPPort, isHTTP, err := s.processFirstPacket(ctx, w, req, successReply)
if err != nil {
return err
}
Expand Down Expand Up @@ -174,22 +183,21 @@ func (s *Server) Copy(reader io.Reader, writer io.Writer) error {
return err
}

func (s *Server) resolveDestination(_ context.Context, req *socks5.Request) (string, error) {
func (s *Server) resolveDestination(_ context.Context, req *socks5.Request) (*statute.AddrSpec, error) {
dest := req.RawDestAddr

if dest.FQDN != "" {
ip, err := s.Resolve(dest.FQDN)
if err != nil {
return "", err
return nil, err
}
dest.IP = net.ParseIP(ip)
logger.Infof("resolved %s to %s", req.RawDestAddr, dest)
} else {
logger.Infof("skipping resolution for %s", req.RawDestAddr)
}

addr := net.JoinHostPort(dest.IP.String(), strconv.Itoa(dest.Port))
return addr, nil
return dest, nil
}

// Resolve resolves the FQDN to an IP address using the specified resolution mechanism.
Expand Down
10 changes: 8 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ func Run(captureCTRLC bool) error {
if workerConfig.WorkerEnabled && !workerConfig.WorkerDNSOnly {
s5 = socks5.NewServer(
socks5.WithConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error {
return serverHandler.HandleTCPTunnel(ctx, w, req)
return serverHandler.HandleTCPTunnel(ctx, w, req, true)
}),
socks5.WithSocks4ConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error {
return serverHandler.HandleTCPTunnel(ctx, w, req, false)
}),
socks5.WithAssociateHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error {
return serverHandler.HandleUDPTunnel(ctx, w, req)
Expand All @@ -117,7 +120,10 @@ func Run(captureCTRLC bool) error {
} else {
s5 = socks5.NewServer(
socks5.WithConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error {
return serverHandler.HandleTCPFragment(ctx, w, req)
return serverHandler.HandleTCPFragment(ctx, w, req, true)
}),
socks5.WithSocks4ConnectHandle(func(ctx context.Context, w io.Writer, req *socks5.Request) error {
return serverHandler.HandleTCPFragment(ctx, w, req, false)
}),
)
}
Expand Down
7 changes: 7 additions & 0 deletions socks5/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ func WithConnectHandle(h func(ctx context.Context, writer io.Writer, request *Re
}
}

// WithSocks4ConnectHandle is used to handle a user's connect command.
func WithSocks4ConnectHandle(h func(ctx context.Context, writer io.Writer, request *Request) error) Option {
return func(s *Server) {
s.userSocks4ConnectHandle = h
}
}

// WithBindHandle is used to handle a user's bind command.
func WithBindHandle(h func(ctx context.Context, writer io.Writer, request *Request) error) Option {
return func(s *Server) {
Expand Down
73 changes: 34 additions & 39 deletions socks5/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,19 @@ package socks5

import (
"bepass/bufferpool"
"bepass/logger"
"bepass/socks5/statute"
"bufio"
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"github.com/elazarl/goproxy"
"golang.org/x/net/proxy"
"io"
"net"
"net/http"
"strconv"

"golang.org/x/net/proxy"

"bepass/logger"
"bepass/socks5/statute"

"github.com/elazarl/goproxy"
)

// GPool is used to implement custom goroutine pool default use goroutine
Expand Down Expand Up @@ -59,13 +55,14 @@ type Server struct {
// goroutine pool
gPool GPool
// user's handle
userConnectHandle func(ctx context.Context, writer io.Writer, request *Request) error
userBindHandle func(ctx context.Context, writer io.Writer, request *Request) error
userAssociateHandle func(ctx context.Context, writer io.Writer, request *Request) error
done chan bool
listen net.Listener
httpProxyBindAddr string
bindAddress string
userSocks4ConnectHandle func(ctx context.Context, writer io.Writer, request *Request) error
userConnectHandle func(ctx context.Context, writer io.Writer, request *Request) error
userBindHandle func(ctx context.Context, writer io.Writer, request *Request) error
userAssociateHandle func(ctx context.Context, writer io.Writer, request *Request) error
done chan bool
listen net.Listener
httpProxyBindAddr string
bindAddress string
}

// NewServer creates a new Server
Expand Down Expand Up @@ -103,7 +100,8 @@ func (sf *Server) ListenAndServe(network, addr string) error {
sf.bindAddress = addr

// Create a custom dialer with DialContext
dialer, err := proxy.SOCKS5(network, sf.bindAddress, nil, proxy.Direct)
dialer, err := proxy.SOCKS5(network, sf.bindAddress, nil, nil)

if err != nil {
return err
}
Expand Down Expand Up @@ -286,7 +284,6 @@ func readAsString(r io.Reader) (string, error) {

func (sf *Server) handleSocks4Request(conn net.Conn, bufConn *bufio.Reader) error {
var cddstportdstip [1 + 1 + 2 + 4]byte
var destination = ""
var dstHost = ""
if _, err := io.ReadFull(bufConn, cddstportdstip[:]); err != nil {
return err
Expand All @@ -297,7 +294,6 @@ func (sf *Server) handleSocks4Request(conn net.Conn, bufConn *bufio.Reader) erro
if command != uint8(1) {
return fmt.Errorf("command %d is not supported", command)
}
destination = net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
// Skip USERID
if _, err := readAsString(bufConn); err != nil {
return err
Expand All @@ -311,37 +307,36 @@ func (sf *Server) handleSocks4Request(conn net.Conn, bufConn *bufio.Reader) erro
}
}

atype := statute.ATYPIPv4

if dstHost != "" {
destination = net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort)))
atype = statute.ATYPDomain
}

if _, err := conn.Write([]byte{0, 90, 0, 0, 0, 0, 0, 0}); err != nil {
return err
}

d, err := proxy.SOCKS5("tcp", sf.bindAddress, nil, proxy.Direct)
if err != nil {
return err
request := &Request{
Request: statute.Request{},
AuthContext: nil,
LocalAddr: conn.LocalAddr(),
RemoteAddr: conn.RemoteAddr(),
DestAddr: nil,
Reader: bufConn,
RawDestAddr: &statute.AddrSpec{
FQDN: dstHost,
IP: dstIP,
Port: int(dstPort),
AddrType: atype,
},
}
dstConn, err := d.Dial("tcp", destination)

if err != nil {
return err
}
var errCh = make(chan error, 2)
go func() {
_, err := io.Copy(dstConn, bufConn)
errCh <- err
}()
go func() {
_, err := io.Copy(conn, dstConn)
errCh <- err
}()
err = <-errCh
if err != nil {
return err
if sf.userSocks4ConnectHandle != nil {
return sf.userSocks4ConnectHandle(context.Background(), io.Writer(conn), request)
}
return <-errCh
logger.Errorf("socks4/a without user defined handler is unsupported")
return errors.New("unsupported")
}

// authenticate is used to handle connection authentication
Expand Down
6 changes: 3 additions & 3 deletions transport/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (w *WSTunnel) PersistentDial(tunnelEndpoint string, bindWriteChannel chan U
if time.Now().Unix()-lastActivityStamp > w.LinkIdleTimeout {
return
}
for {
for limit := 0; limit < 10; limit++ {
done := make(chan struct{})
doneR := make(chan struct{})

Expand Down Expand Up @@ -149,12 +149,12 @@ func (w *WSTunnel) PersistentDial(tunnelEndpoint string, bindWriteChannel chan U

if err != nil {
if strings.Contains(err.Error(), "websocket: close") ||
strings.Contains(err.Error(), "i/o") {
strings.Contains(err.Error(), "limit/o") {
logger.Errorf("reading from udp over tcp error: %v\r\n", err)
return
}
logger.Errorf("reading from udp over TCP tunnel packet size error: %v\r\n", err)
continue
return
}

// The first 2 packets of response are channel ID
Expand Down

0 comments on commit 0b579cc

Please sign in to comment.