diff --git a/sshmux.go b/sshmux.go index 71de533..c915284 100644 --- a/sshmux.go +++ b/sshmux.go @@ -12,6 +12,7 @@ import ( "net/http" "os" "slices" + "sync" "time" "github.com/pires/go-proxyproto" @@ -359,7 +360,15 @@ func sendLogAndClose(logMessage *LogMessage, session *ssh.PipeSession, logCh cha logCh <- *logMessage } -func sshmuxListenAddr(address string, sshConfig *ssh.ServerConfig, proxy bool, proxyMux bool) { +func sshmuxListenAddr(address string, waitgroup *sync.WaitGroup, sshConfig *ssh.ServerConfig, proxy bool, proxyMux bool) { + // configure waitgroup callback + defer func() { + if waitgroup != nil { + waitgroup.Done() + } + }() + + // set up TCP listener listener, err := net.Listen("tcp", address) if err != nil { log.Fatal(err) @@ -378,8 +387,12 @@ func sshmuxListenAddr(address string, sshConfig *ssh.ServerConfig, proxy bool, p } } defer listener.Close() + + // set up log channel logCh := make(chan LogMessage, 256) go runLogger(logCh) + + // main handler loop for { conn, err := listener.Accept() if err != nil { @@ -427,21 +440,24 @@ func sshmuxServer(configFile string) { } sshConfig.AddHostKey(key) } + waitgroup := sync.WaitGroup{} + waitgroup.Add(1) if config.Address == config.ProxiedAddress { if config.Address == "" { log.Println("No address specified, defaulting to 0.0.0.0:8022") - go sshmuxListenAddr("0.0.0.0:8022", sshConfig, false, false) + go sshmuxListenAddr("0.0.0.0:8022", &waitgroup, sshConfig, false, false) } else { - go sshmuxListenAddr(config.Address, sshConfig, true, true) + go sshmuxListenAddr(config.Address, &waitgroup, sshConfig, true, true) } } else { if config.Address != "" { - go sshmuxListenAddr(config.Address, sshConfig, false, false) + go sshmuxListenAddr(config.Address, &waitgroup, sshConfig, false, false) } if config.ProxiedAddress != "" { - go sshmuxListenAddr(config.ProxiedAddress, sshConfig, true, false) + go sshmuxListenAddr(config.ProxiedAddress, &waitgroup, sshConfig, true, false) } } + waitgroup.Wait() } func main() {