Skip to content

Commit

Permalink
let parent process clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
glaslos committed Dec 29, 2024
1 parent 52e236f commit e2915c6
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 59 deletions.
20 changes: 9 additions & 11 deletions app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import (
"os"
"os/signal"
"runtime/debug"
"sync"
"syscall"

"github.com/mushorg/glutton"

"github.com/spf13/pflag"
"github.com/spf13/viper"
)
Expand All @@ -32,7 +32,7 @@ func main() {
\_____|_|\__,_|\__|\__\___/|_| |_|
`)
fmt.Printf("%s %s\n", VERSION, BUILDDATE)
fmt.Printf("%s %s\n\n", VERSION, BUILDDATE)

pflag.StringP("interface", "i", "eth0", "Bind to this interface")
pflag.IntP("ssh", "s", 0, "Override SSH port")
Expand All @@ -53,39 +53,37 @@ func main() {
return
}

gtn, err := glutton.New(context.Background())
g, err := glutton.New(context.Background())
if err != nil {
log.Fatal(err)
}

if err := gtn.Init(); err != nil {
if err := g.Init(); err != nil {
log.Fatal("Failed to initialize Glutton:", err)
}

exitMtx := sync.RWMutex{}
exit := func() {
// See if there was a panic...
if r := recover(); r != nil {
fmt.Fprintln(os.Stderr, r)
fmt.Println("stacktrace from panic: \n" + string(debug.Stack()))
}
exitMtx.Lock()
gtn.Shutdown()
exitMtx.Unlock()
g.Shutdown()
}
defer exit()

// capture and handle signals
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
go func() {
<-sig
fmt.Print("\r")
exit()
fmt.Println("\nleaving...")
fmt.Println()
os.Exit(0)
}()

if err := gtn.Start(); err != nil {
if err := g.Start(); err != nil {
exit()
log.Fatal("Failed to start Glutton server:", err)
}
}
67 changes: 35 additions & 32 deletions glutton.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,55 +148,71 @@ func (g *Glutton) Init() error {
}

func (g *Glutton) udpListen(wg *sync.WaitGroup) {
defer wg.Done()
defer func() {
wg.Done()
}()
buffer := make([]byte, 1024)
for {
n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(g.Server.udpListener, buffer)
select {
case <-g.ctx.Done():
if err := g.Server.udpConn.Close(); err != nil {
g.Logger.Error("Failed to close UDP listener", producer.ErrAttr(err))
}
return
default:
}
g.Server.udpConn.SetReadDeadline(time.Now().Add(1 * time.Second))
n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(g.Server.udpConn, buffer)
if err != nil {
g.Logger.Error("failed to read UDP packet", producer.ErrAttr(err))
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
continue
}
g.Logger.Error("Failed to read UDP packet", producer.ErrAttr(err))
}

rule, err := g.applyRules("udp", srcAddr, dstAddr)
if err != nil {
g.Logger.Error("failed to apply rules", producer.ErrAttr(err))
g.Logger.Error("Failed to apply rules", producer.ErrAttr(err))
}
if rule == nil {
rule = &rules.Rule{Target: "udp"}
}
md, err := g.connTable.Register(srcAddr.IP.String(), strconv.Itoa(int(srcAddr.AddrPort().Port())), dstAddr.AddrPort().Port(), rule)
if err != nil {
g.Logger.Error("failed to register UDP packet", producer.ErrAttr(err))
g.Logger.Error("Failed to register UDP packet", producer.ErrAttr(err))
}

if hfunc, ok := g.udpProtocolHandlers[rule.Target]; ok {
data := buffer[:n]
go func() {
if err := hfunc(g.ctx, srcAddr, dstAddr, data, md); err != nil {
g.Logger.Error("failed to handle UDP payload", producer.ErrAttr(err))
g.Logger.Error("Failed to handle UDP payload", producer.ErrAttr(err))
}
}()
}
}
}

func (g *Glutton) tcpListen(wg *sync.WaitGroup) {
defer wg.Done()
func (g *Glutton) tcpListen() {
for {
select {
case <-g.ctx.Done():
if err := g.Server.tcpListener.Close(); err != nil {
g.Logger.Error("Failed to close TCP listener", producer.ErrAttr(err))
}
return
default:
}

conn, err := g.Server.tcpListener.Accept()
if err != nil {
g.Logger.Error("failed to accept connection", producer.ErrAttr(err))
g.Logger.Error("Failed to accept connection", producer.ErrAttr(err))
continue
}

rule, err := g.applyRulesOnConn(conn)
if err != nil {
g.Logger.Error("failed to apply rules", producer.ErrAttr(err))
g.Logger.Error("Failed to apply rules", producer.ErrAttr(err))
continue
}
if rule == nil {
Expand All @@ -205,21 +221,21 @@ func (g *Glutton) tcpListen(wg *sync.WaitGroup) {

md, err := g.connTable.RegisterConn(conn, rule)
if err != nil {
g.Logger.Error("failed to register connection", producer.ErrAttr(err))
g.Logger.Error("Failed to register connection", producer.ErrAttr(err))
continue
}

g.Logger.Debug("new connection", slog.String("addr", conn.LocalAddr().String()), slog.String("handler", rule.Target))

g.ctx = context.WithValue(g.ctx, ctxTimeout("timeout"), int64(viper.GetInt("conn_timeout")))
if err := g.UpdateConnectionTimeout(g.ctx, conn); err != nil {
g.Logger.Error("failed to set connection timeout", producer.ErrAttr(err))
g.Logger.Error("Failed to set connection timeout", producer.ErrAttr(err))
}

if hfunc, ok := g.tcpProtocolHandlers[rule.Target]; ok {
go func() {
if err := hfunc(g.ctx, conn, md); err != nil {
g.Logger.Error("failed to handle TCP connection", producer.ErrAttr(err), slog.String("handler", rule.Target))
g.Logger.Error("Failed to handle TCP connection", producer.ErrAttr(err), slog.String("handler", rule.Target))
}
}()
}
Expand All @@ -228,13 +244,7 @@ func (g *Glutton) tcpListen(wg *sync.WaitGroup) {

// Start the listener, this blocks for new connections
func (g *Glutton) Start() error {
quit := make(chan struct{}) // stop monitor on shutdown
defer func() {
quit <- struct{}{}
g.Shutdown()
}()

g.startMonitor(quit)
g.startMonitor()

sshPort := viper.GetUint32("ports.ssh")
if err := setTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "tcp", uint32(g.Server.tcpPort), sshPort); err != nil {
Expand All @@ -249,9 +259,7 @@ func (g *Glutton) Start() error {

wg.Add(1)
go g.udpListen(wg)

wg.Add(1)
go g.tcpListen(wg)
go g.tcpListen()

wg.Wait()

Expand Down Expand Up @@ -350,18 +358,13 @@ func (g *Glutton) ProduceUDP(handler string, srcAddr, dstAddr *net.UDPAddr, md c
func (g *Glutton) Shutdown() {
g.cancel() // close all connection

g.Logger.Info("Shutting down listeners")
if err := g.Server.Shutdown(); err != nil {
g.Logger.Error("failed to shutdown server", producer.ErrAttr(err))
}

g.Logger.Info("FLushing TCP iptables")
g.Logger.Info("Flushing TCP iptables")
if err := flushTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "tcp", uint32(g.Server.tcpPort), uint32(viper.GetInt("ports.ssh"))); err != nil {
g.Logger.Error("failed to drop tcp iptables", producer.ErrAttr(err))
g.Logger.Error("Failed to drop tcp iptables", producer.ErrAttr(err))
}
g.Logger.Info("FLushing UDP iptables")
g.Logger.Info("Flushing UDP iptables")
if err := flushTProxyIPTables(viper.GetString("interface"), g.publicAddrs[0].String(), "udp", uint32(g.Server.udpPort), uint32(viper.GetInt("ports.ssh"))); err != nil {
g.Logger.Error("failed to drop udp iptables", producer.ErrAttr(err))
g.Logger.Error("Failed to drop udp iptables", producer.ErrAttr(err))
}

g.Logger.Info("All done")
Expand Down
18 changes: 4 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

type Server struct {
tcpListener net.Listener
udpListener *net.UDPConn
udpConn *net.UDPConn
tcpPort uint
udpPort uint
}
Expand All @@ -31,26 +31,16 @@ func (s *Server) Start() error {
if s.tcpListener, err = tproxy.ListenTCP("tcp4", tcpAddr); err != nil {
return err
}

udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", s.udpPort))
if err != nil {
return err
}
if s.udpListener, err = tproxy.ListenUDP("udp4", udpAddr); err != nil {
if s.udpConn, err = tproxy.ListenUDP("udp4", udpAddr); err != nil {
return err
}
if s.udpListener == nil {
if s.udpConn == nil {
return errors.New("nil udp listener")
}
return nil
}

func (s *Server) Shutdown() error {
var err error
if s.tcpListener != nil {
err = s.tcpListener.Close()
}
if s.udpListener != nil {
err = s.udpListener.Close()
}
return err
}
2 changes: 0 additions & 2 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,4 @@ import (
func TestServer(t *testing.T) {
server := NewServer(1234, 1235)
require.NotNil(t, server)
err := server.Shutdown()
require.NoError(t, err)
}

0 comments on commit e2915c6

Please sign in to comment.