diff --git a/config.example.json b/config.example.json index 4699f07..8a888a9 100644 --- a/config.example.json +++ b/config.example.json @@ -1,5 +1,8 @@ { "address": "0.0.0.0:8022", + "proxy-protocol-allowed-cidrs": [ + "127.0.0.22/32" + ], "host-keys": [ "/tmp/sshmux/ssh_host_ed25519_key", "/tmp/sshmux/ssh_host_ecdsa_key", diff --git a/go.mod b/go.mod index 7071b97..f023e05 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/USTC-vlab/sshmux go 1.21 require golang.org/x/crypto v0.24.0 +require github.com/pires/go-proxyproto v0.7.0 require golang.org/x/sys v0.21.0 // indirect diff --git a/go.sum b/go.sum index 6208220..a9c1221 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs= +github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= diff --git a/sshmux.go b/sshmux.go index e755f10..1156c53 100644 --- a/sshmux.go +++ b/sshmux.go @@ -10,15 +10,18 @@ import ( "log" "net" "net/http" + "net/netip" "os" "slices" "time" + "github.com/pires/go-proxyproto" "golang.org/x/crypto/ssh" ) type Config struct { Address string `json:"address"` + ProxyCIDRs []string `json:"proxy-protocol-allowed-cidrs"` HostKeys []string `json:"host-keys"` API string `json:"api"` Token string `json:"token"` @@ -60,17 +63,19 @@ type AuthRequestPassword struct { } type AuthResponse struct { - Status string `json:"status"` - Address string `json:"address"` - PrivateKey string `json:"private_key"` - Cert string `json:"cert"` - Id int `json:"vmid"` + Status string `json:"status"` + Address string `json:"address"` + PrivateKey string `json:"private_key"` + Cert string `json:"cert"` + Id int `json:"vmid"` + ProxyProtocol byte `json:"proxy_protocol,omitempty"` } type UpstreamInformation struct { - Host string - Signer ssh.Signer - Password *string + Host string + Signer ssh.Signer + Password *string + ProxyProtocol byte } func parsePrivateKey(key string, cert string) ssh.Signer { @@ -127,6 +132,7 @@ func authUser(request any, username string) (*UpstreamInformation, error) { upstream.Host = response.Address } upstream.Signer = parsePrivateKey(response.PrivateKey, response.Cert) + upstream.ProxyProtocol = response.ProxyProtocol return &upstream, nil } @@ -249,6 +255,13 @@ func handshake(session *ssh.PipeSession) error { if err != nil { return err } + if upstream.ProxyProtocol > 0 { + header := proxyproto.HeaderProxyFromAddrs(upstream.ProxyProtocol, session.Downstream.RemoteAddr(), nil) + _, err := header.WriteTo(conn) + if err != nil { + return err + } + } config := &ssh.ClientConfig{ User: user, HostKeyCallback: ssh.InsecureIgnoreHostKey(), @@ -347,37 +360,40 @@ func sendLogAndClose(logMessage *LogMessage, session *ssh.PipeSession, logCh cha logCh <- *logMessage } -func sshmuxServer(configFile string) { - configFileBytes, err := os.ReadFile(configFile) - if err != nil { - log.Fatal(err) - } - err = json.Unmarshal(configFileBytes, &config) +func sshmuxListenAddr(address string, sshConfig *ssh.ServerConfig, proxyUpstreams []netip.Prefix) { + // set up TCP listener + listener, err := net.Listen("tcp", address) if err != nil { log.Fatal(err) } - sshConfig := &ssh.ServerConfig{ - ServerVersion: "SSH-2.0-taokystrong", - PublicKeyAuthAlgorithms: ssh.DefaultPubKeyAuthAlgos(), - } - for _, keyFile := range config.HostKeys { - bytes, err := os.ReadFile(keyFile) - if err != nil { - log.Fatal(err) - } - key, err := ssh.ParsePrivateKey(bytes) - if err != nil { - log.Fatal(err) + if len(proxyUpstreams) > 0 { + listener = &proxyproto.Listener{ + Listener: listener, + Policy: func(upstream net.Addr) (proxyproto.Policy, error) { + // parse upstream address + upstreamAddrPort, err := netip.ParseAddrPort(upstream.String()) + if err != nil { + return proxyproto.SKIP, nil + } + upstreamAddr := upstreamAddrPort.Addr() + // only read PROXY header from allowed CIDRs + for _, network := range proxyUpstreams { + if network.Contains(upstreamAddr) { + return proxyproto.USE, nil + } + } + // do nothing if upstream not in the allow list + return proxyproto.SKIP, nil + }, } - sshConfig.AddHostKey(key) - } - listener, err := net.Listen("tcp", config.Address) - if err != nil { - log.Fatal(err) } defer listener.Close() + + // set up log channel logCh := make(chan LogMessage, 256) go runLogger(logCh) + + // main handler loop for { conn, err := listener.Accept() if err != nil { @@ -401,6 +417,41 @@ func sshmuxServer(configFile string) { } } +func sshmuxServer(configFile string) { + configFileBytes, err := os.ReadFile(configFile) + if err != nil { + log.Fatal(err) + } + err = json.Unmarshal(configFileBytes, &config) + if err != nil { + log.Fatal(err) + } + sshConfig := &ssh.ServerConfig{ + ServerVersion: "SSH-2.0-taokystrong", + PublicKeyAuthAlgorithms: ssh.DefaultPubKeyAuthAlgos(), + } + for _, keyFile := range config.HostKeys { + bytes, err := os.ReadFile(keyFile) + if err != nil { + log.Fatal(err) + } + key, err := ssh.ParsePrivateKey(bytes) + if err != nil { + log.Fatal(err) + } + sshConfig.AddHostKey(key) + } + proxyUpstreams := make([]netip.Prefix, 0) + for _, cidr := range config.ProxyCIDRs { + network, err := netip.ParsePrefix(cidr) + if err != nil { + log.Fatal(err) + } + proxyUpstreams = append(proxyUpstreams, network) + } + sshmuxListenAddr(config.Address, sshConfig, proxyUpstreams) +} + func main() { flag.StringVar(&configFile, "c", "/etc/sshmux/config.json", "config file") flag.Parse() diff --git a/sshmux_test.go b/sshmux_test.go index 350ba1d..20dc6b7 100644 --- a/sshmux_test.go +++ b/sshmux_test.go @@ -2,16 +2,33 @@ package main import ( "encoding/json" + "fmt" "io" "log" + "net" "net/http" "os" "os/exec" "path/filepath" "testing" "time" + + "github.com/pires/go-proxyproto" ) +var sshmuxProxyAddr *net.TCPAddr = localhostTCPAddr(8122) +var sshmuxServerAddr *net.TCPAddr = localhostTCPAddr(8022) +var sshdProxiedAddr *net.TCPAddr = localhostTCPAddr(2332) +var sshdServerAddr *net.TCPAddr = localhostTCPAddr(2333) +var apiServerAddr *net.TCPAddr = localhostTCPAddr(5000) + +func localhostTCPAddr(port int) *net.TCPAddr { + return &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: port, + } +} + func mustGenerateKey(t *testing.T, path, typ string) { err := exec.Command("ssh-keygen", "-t", typ, "-f", path, "-N", "").Run() if err != nil { @@ -20,6 +37,7 @@ func mustGenerateKey(t *testing.T, path, typ string) { } var examplePrivate string +var enableProxy bool func sshAPIHandler(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) @@ -35,10 +53,16 @@ func sshAPIHandler(w http.ResponseWriter, r *http.Request) { res := &AuthResponse{ Status: "ok", - Address: "127.0.0.1:2333", Id: 1141919, PrivateKey: examplePrivate, } + if enableProxy { + res.Address = sshdProxiedAddr.String() + res.ProxyProtocol = 2 + } else { + res.Address = sshdServerAddr.String() + } + jsonRes, err := json.Marshal(res) if err != nil { http.Error(w, "Cannot encode JSON", http.StatusInternalServerError) @@ -51,9 +75,88 @@ func sshAPIHandler(w http.ResponseWriter, r *http.Request) { func initHttp() { http.HandleFunc("/ssh", sshAPIHandler) - if err := http.ListenAndServe("127.0.0.1:5000", nil); err != nil { + if err := http.ListenAndServe(apiServerAddr.String(), nil); err != nil { + log.Fatal(err) + } +} + +func initUpstreamProxyServer() { + listener, err := net.ListenTCP("tcp", sshmuxProxyAddr) + if err != nil { log.Fatal(err) } + defer listener.Close() + + localAddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 22)} + + for { + conn, err := listener.Accept() + if err != nil { + log.Fatal(err) + } + + go func() { + // 1. Set up downstream connection with sshmux + sshmux, err := net.DialTCP("tcp", localAddr, sshmuxServerAddr) + if err != nil { + log.Fatal(err) + } + // 2. Send PROXY header to sshmux + header := proxyproto.HeaderProxyFromAddrs(2, conn.RemoteAddr(), nil) + _, err = header.WriteTo(sshmux) + if err != nil { + log.Fatal(err) + } + // 3. Forward TCP messages in both directions + go func() { + defer sshmux.Close() + io.Copy(sshmux, conn) + }() + go func() { + defer conn.Close() + io.Copy(conn, sshmux) + }() + }() + } +} + +func initDownstreamProxyServer() { + listener, err := net.ListenTCP("tcp", sshdProxiedAddr) + if err != nil { + log.Fatal(err) + } + // Enforce listener to accept PROXY protocol + proxyListener := &proxyproto.Listener{ + Listener: listener, + Policy: func(upstream net.Addr) (proxyproto.Policy, error) { + return proxyproto.REQUIRE, nil + }, + } + defer proxyListener.Close() + + for { + conn, err := proxyListener.Accept() + if err != nil { + log.Fatal(err) + } + + go func() { + // 1. Set up downstream connection with sshd + sshd, err := net.DialTCP("tcp", nil, sshdServerAddr) + if err != nil { + log.Fatal(err) + } + // 2. Forward TCP messages in both directions + go func() { + defer sshd.Close() + io.Copy(sshd, conn) + }() + go func() { + defer conn.Close() + io.Copy(conn, sshd) + }() + }() + } } func initEnv(t *testing.T, baseDir string) { @@ -83,6 +186,8 @@ func initEnv(t *testing.T, baseDir string) { // Setup API Server go initHttp() + go initUpstreamProxyServer() + go initDownstreamProxyServer() } func onetimeSSHDServer(t *testing.T, baseDir string) *exec.Cmd { @@ -93,7 +198,7 @@ func onetimeSSHDServer(t *testing.T, baseDir string) *exec.Cmd { cmd := exec.Command( sshdPath, "-d", "-h", filepath.Join(baseDir, "ssh_host_ed25519_key"), - "-p", "2333", + "-p", fmt.Sprint(sshdServerAddr.Port), "-o", "AuthorizedKeysFile="+filepath.Join(baseDir, "example_rsa.pub"), "-o", "StrictModes=no") cmd.Dir = baseDir @@ -115,38 +220,46 @@ func waitForSSHD(t *testing.T, cmd *exec.Cmd) { } } -func sshCommand(port, privateKeyPath string) *exec.Cmd { +func sshCommand(server *net.TCPAddr, privateKeyPath string) *exec.Cmd { return exec.Command( - "ssh", "-p", port, + "ssh", "-p", fmt.Sprint(server.Port), "-o", "StrictHostKeyChecking=no", "-o", "ControlMaster=no", "-i", privateKeyPath, "-o", "IdentityAgent=no", - "localhost", "uname") + server.IP.String(), "uname") +} + +func testWithSSHClient(t *testing.T, address *net.TCPAddr, description string, proxy bool, baseDir, privateKeyPath string) { + enableProxy = proxy + cmd := onetimeSSHDServer(t, baseDir) + time.Sleep(100 * time.Millisecond) + err := sshCommand(address, privateKeyPath).Run() + if err != nil { + t.Fatal(fmt.Sprintf("%s: ", description), err) + } + waitForSSHD(t, cmd) } func TestSSHClientConnection(t *testing.T) { - sleepDuration := 100 * time.Millisecond baseDir := "/tmp/sshmux" initEnv(t, baseDir) privateKeyPath := filepath.Join(baseDir, "example_rsa") go sshmuxServer("config.example.json") - // Sanity check - cmd := onetimeSSHDServer(t, baseDir) - time.Sleep(sleepDuration) - err := sshCommand("2333", privateKeyPath).Run() - if err != nil { - t.Fatal("sanity check: ", err) - } - waitForSSHD(t, cmd) + // sanity check + testWithSSHClient(t, sshdServerAddr, "sanity check", false, baseDir, privateKeyPath) - cmd = onetimeSSHDServer(t, baseDir) - time.Sleep(sleepDuration) - err = sshCommand("8022", privateKeyPath).Run() - if err != nil { - t.Fatal("ssh: ", err) - } - waitForSSHD(t, cmd) + // test sshmux + testWithSSHClient(t, sshmuxServerAddr, "sshmux", false, baseDir, privateKeyPath) + + // test sshmux with upstream proxy + testWithSSHClient(t, sshmuxProxyAddr, "sshmux (proxied src)", false, baseDir, privateKeyPath) + + // test sshmux with downstream proxy + testWithSSHClient(t, sshmuxServerAddr, "sshmux (proxied dst)", true, baseDir, privateKeyPath) + + // test sshmux with two-way proxy + testWithSSHClient(t, sshmuxProxyAddr, "sshmux (proxied)", true, baseDir, privateKeyPath) }