From e2915c6b0a312909161b4ca48e5143f6bc0ff6b7 Mon Sep 17 00:00:00 2001 From: Lukas Rist Date: Sun, 29 Dec 2024 12:32:37 +0100 Subject: [PATCH] let parent process clean up --- app/server.go | 20 +++++++-------- glutton.go | 67 ++++++++++++++++++++++++++------------------------ server.go | 18 +++----------- server_test.go | 2 -- 4 files changed, 48 insertions(+), 59 deletions(-) diff --git a/app/server.go b/app/server.go index cf245ea..b0ee91f 100644 --- a/app/server.go +++ b/app/server.go @@ -7,10 +7,10 @@ import ( "os" "os/signal" "runtime/debug" - "sync" "syscall" "github.com/mushorg/glutton" + "github.com/spf13/pflag" "github.com/spf13/viper" ) @@ -32,7 +32,7 @@ func main() { \_____|_|\__,_|\__|\__\___/|_| |_| `) - fmt.Printf("%s %s\n", VERSION, BUILDDATE) + fmt.Printf("%s %s\n\n", VERSION, BUILDDATE) pflag.StringP("interface", "i", "eth0", "Bind to this interface") pflag.IntP("ssh", "s", 0, "Override SSH port") @@ -53,39 +53,37 @@ func main() { return } - gtn, err := glutton.New(context.Background()) + g, err := glutton.New(context.Background()) if err != nil { log.Fatal(err) } - if err := gtn.Init(); err != nil { + if err := g.Init(); err != nil { log.Fatal("Failed to initialize Glutton:", err) } - exitMtx := sync.RWMutex{} exit := func() { // See if there was a panic... if r := recover(); r != nil { fmt.Fprintln(os.Stderr, r) fmt.Println("stacktrace from panic: \n" + string(debug.Stack())) } - exitMtx.Lock() - gtn.Shutdown() - exitMtx.Unlock() + g.Shutdown() } - defer exit() // capture and handle signals sig := make(chan os.Signal, 1) signal.Notify(sig, os.Interrupt, syscall.SIGTERM) go func() { <-sig + fmt.Print("\r") exit() - fmt.Println("\nleaving...") + fmt.Println() os.Exit(0) }() - if err := gtn.Start(); err != nil { + if err := g.Start(); err != nil { + exit() log.Fatal("Failed to start Glutton server:", err) } } diff --git a/glutton.go b/glutton.go index 7a54247..f7f3dda 100644 --- a/glutton.go +++ b/glutton.go @@ -148,55 +148,71 @@ func (g *Glutton) Init() error { } func (g *Glutton) udpListen(wg *sync.WaitGroup) { - defer wg.Done() + defer func() { + wg.Done() + }() buffer := make([]byte, 1024) for { - n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(g.Server.udpListener, buffer) + select { + case <-g.ctx.Done(): + if err := g.Server.udpConn.Close(); err != nil { + g.Logger.Error("Failed to close UDP listener", producer.ErrAttr(err)) + } + return + default: + } + g.Server.udpConn.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(g.Server.udpConn, buffer) if err != nil { - g.Logger.Error("failed to read UDP packet", producer.ErrAttr(err)) + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + g.Logger.Error("Failed to read UDP packet", producer.ErrAttr(err)) } rule, err := g.applyRules("udp", srcAddr, dstAddr) if err != nil { - g.Logger.Error("failed to apply rules", producer.ErrAttr(err)) + g.Logger.Error("Failed to apply rules", producer.ErrAttr(err)) } if rule == nil { rule = &rules.Rule{Target: "udp"} } md, err := g.connTable.Register(srcAddr.IP.String(), strconv.Itoa(int(srcAddr.AddrPort().Port())), dstAddr.AddrPort().Port(), rule) if err != nil { - g.Logger.Error("failed to register UDP packet", producer.ErrAttr(err)) + g.Logger.Error("Failed to register UDP packet", producer.ErrAttr(err)) } if hfunc, ok := g.udpProtocolHandlers[rule.Target]; ok { data := buffer[:n] go func() { if err := hfunc(g.ctx, srcAddr, dstAddr, data, md); err != nil { - g.Logger.Error("failed to handle UDP payload", producer.ErrAttr(err)) + g.Logger.Error("Failed to handle UDP payload", producer.ErrAttr(err)) } }() } } } -func (g *Glutton) tcpListen(wg *sync.WaitGroup) { - defer wg.Done() +func (g *Glutton) tcpListen() { for { select { case <-g.ctx.Done(): + if err := g.Server.tcpListener.Close(); err != nil { + g.Logger.Error("Failed to close TCP listener", producer.ErrAttr(err)) + } return default: } conn, err := g.Server.tcpListener.Accept() if err != nil { - g.Logger.Error("failed to accept connection", producer.ErrAttr(err)) + g.Logger.Error("Failed to accept connection", producer.ErrAttr(err)) continue } rule, err := g.applyRulesOnConn(conn) if err != nil { - g.Logger.Error("failed to apply rules", producer.ErrAttr(err)) + g.Logger.Error("Failed to apply rules", producer.ErrAttr(err)) continue } if rule == nil { @@ -205,7 +221,7 @@ func (g *Glutton) tcpListen(wg *sync.WaitGroup) { md, err := g.connTable.RegisterConn(conn, rule) if err != nil { - g.Logger.Error("failed to register connection", producer.ErrAttr(err)) + g.Logger.Error("Failed to register connection", producer.ErrAttr(err)) continue } @@ -213,13 +229,13 @@ func (g *Glutton) tcpListen(wg *sync.WaitGroup) { g.ctx = context.WithValue(g.ctx, ctxTimeout("timeout"), int64(viper.GetInt("conn_timeout"))) if err := g.UpdateConnectionTimeout(g.ctx, conn); err != nil { - g.Logger.Error("failed to set connection timeout", producer.ErrAttr(err)) + g.Logger.Error("Failed to set connection timeout", producer.ErrAttr(err)) } if hfunc, ok := g.tcpProtocolHandlers[rule.Target]; ok { go func() { if err := hfunc(g.ctx, conn, md); err != nil { - g.Logger.Error("failed to handle TCP connection", producer.ErrAttr(err), slog.String("handler", rule.Target)) + g.Logger.Error("Failed to handle TCP connection", producer.ErrAttr(err), slog.String("handler", rule.Target)) } }() } @@ -228,13 +244,7 @@ func (g *Glutton) tcpListen(wg *sync.WaitGroup) { // Start the listener, this blocks for new connections func (g *Glutton) Start() error { - quit := make(chan struct{}) // stop monitor on shutdown - defer func() { - quit <- struct{}{} - g.Shutdown() - }() - - g.startMonitor(quit) + g.startMonitor() sshPort := viper.GetUint32("ports.ssh") if err := setTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "tcp", uint32(g.Server.tcpPort), sshPort); err != nil { @@ -249,9 +259,7 @@ func (g *Glutton) Start() error { wg.Add(1) go g.udpListen(wg) - - wg.Add(1) - go g.tcpListen(wg) + go g.tcpListen() wg.Wait() @@ -350,18 +358,13 @@ func (g *Glutton) ProduceUDP(handler string, srcAddr, dstAddr *net.UDPAddr, md c func (g *Glutton) Shutdown() { g.cancel() // close all connection - g.Logger.Info("Shutting down listeners") - if err := g.Server.Shutdown(); err != nil { - g.Logger.Error("failed to shutdown server", producer.ErrAttr(err)) - } - - g.Logger.Info("FLushing TCP iptables") + g.Logger.Info("Flushing TCP iptables") if err := flushTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "tcp", uint32(g.Server.tcpPort), uint32(viper.GetInt("ports.ssh"))); err != nil { - g.Logger.Error("failed to drop tcp iptables", producer.ErrAttr(err)) + g.Logger.Error("Failed to drop tcp iptables", producer.ErrAttr(err)) } - g.Logger.Info("FLushing UDP iptables") + g.Logger.Info("Flushing UDP iptables") if err := flushTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "udp", uint32(g.Server.udpPort), uint32(viper.GetInt("ports.ssh"))); err != nil { - g.Logger.Error("failed to drop udp iptables", producer.ErrAttr(err)) + g.Logger.Error("Failed to drop udp iptables", producer.ErrAttr(err)) } g.Logger.Info("All done") diff --git a/server.go b/server.go index f75a2cd..bb8087b 100644 --- a/server.go +++ b/server.go @@ -10,7 +10,7 @@ import ( type Server struct { tcpListener net.Listener - udpListener *net.UDPConn + udpConn *net.UDPConn tcpPort uint udpPort uint } @@ -31,26 +31,16 @@ func (s *Server) Start() error { if s.tcpListener, err = tproxy.ListenTCP("tcp4", tcpAddr); err != nil { return err } + udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", s.udpPort)) if err != nil { return err } - if s.udpListener, err = tproxy.ListenUDP("udp4", udpAddr); err != nil { + if s.udpConn, err = tproxy.ListenUDP("udp4", udpAddr); err != nil { return err } - if s.udpListener == nil { + if s.udpConn == nil { return errors.New("nil udp listener") } return nil } - -func (s *Server) Shutdown() error { - var err error - if s.tcpListener != nil { - err = s.tcpListener.Close() - } - if s.udpListener != nil { - err = s.udpListener.Close() - } - return err -} diff --git a/server_test.go b/server_test.go index 515aa23..fd935b5 100644 --- a/server_test.go +++ b/server_test.go @@ -9,6 +9,4 @@ import ( func TestServer(t *testing.T) { server := NewServer(1234, 1235) require.NotNil(t, server) - err := server.Shutdown() - require.NoError(t, err) }