Skip to content

Commit

Permalink
PROXY protocol support (#4)
Browse files Browse the repository at this point in the history
* Add `go-proxyproto` dependency

* Implement two-way PROXY protocol

* Enable PROXY protocol in tests

* Fix accidental close of TCP handlers

* Fix typo

* Early exit in tests

* Fix unexpected dest address

* Final fix (hopefully)

* Update proxy protocol APIs

* Update sshmux_test.go

* feat: `proxy_protocol` in API response

* refactor test for maintainability

* Try to improve readability

* Restore compatibility with older API servers

* Minor update

* Allow control on PROXY protocol version

* `go-proxyproto` is a direct dependency

* Address feedback from review

* Adjust API according to feedback

* Allow proxy server on the same port

* Split `sshmuxListenAddr`

* Correctly shut down proxy connection in tests

* Minor update

* Add synchronization for TCP handlers

* Let's mux with PROXY protocol!
  • Loading branch information
stevapple authored Jul 29, 2024
1 parent 7da33ed commit 0438321
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 53 deletions.
3 changes: 3 additions & 0 deletions config.example.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
{
"address": "0.0.0.0:8022",
"proxy-protocol-allowed-cidrs": [
"127.0.0.22/32"
],
"host-keys": [
"/tmp/sshmux/ssh_host_ed25519_key",
"/tmp/sshmux/ssh_host_ecdsa_key",
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/USTC-vlab/sshmux
go 1.21

require golang.org/x/crypto v0.24.0
require github.com/pires/go-proxyproto v0.7.0

require golang.org/x/sys v0.21.0 // indirect

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs=
github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
Expand Down
113 changes: 82 additions & 31 deletions sshmux.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@ import (
"log"
"net"
"net/http"
"net/netip"
"os"
"slices"
"time"

"github.com/pires/go-proxyproto"
"golang.org/x/crypto/ssh"
)

type Config struct {
Address string `json:"address"`
ProxyCIDRs []string `json:"proxy-protocol-allowed-cidrs"`
HostKeys []string `json:"host-keys"`
API string `json:"api"`
Token string `json:"token"`
Expand Down Expand Up @@ -60,17 +63,19 @@ type AuthRequestPassword struct {
}

type AuthResponse struct {
Status string `json:"status"`
Address string `json:"address"`
PrivateKey string `json:"private_key"`
Cert string `json:"cert"`
Id int `json:"vmid"`
Status string `json:"status"`
Address string `json:"address"`
PrivateKey string `json:"private_key"`
Cert string `json:"cert"`
Id int `json:"vmid"`
ProxyProtocol byte `json:"proxy_protocol,omitempty"`
}

type UpstreamInformation struct {
Host string
Signer ssh.Signer
Password *string
Host string
Signer ssh.Signer
Password *string
ProxyProtocol byte
}

func parsePrivateKey(key string, cert string) ssh.Signer {
Expand Down Expand Up @@ -127,6 +132,7 @@ func authUser(request any, username string) (*UpstreamInformation, error) {
upstream.Host = response.Address
}
upstream.Signer = parsePrivateKey(response.PrivateKey, response.Cert)
upstream.ProxyProtocol = response.ProxyProtocol
return &upstream, nil
}

Expand Down Expand Up @@ -249,6 +255,13 @@ func handshake(session *ssh.PipeSession) error {
if err != nil {
return err
}
if upstream.ProxyProtocol > 0 {
header := proxyproto.HeaderProxyFromAddrs(upstream.ProxyProtocol, session.Downstream.RemoteAddr(), nil)
_, err := header.WriteTo(conn)
if err != nil {
return err
}
}
config := &ssh.ClientConfig{
User: user,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Expand Down Expand Up @@ -347,37 +360,40 @@ func sendLogAndClose(logMessage *LogMessage, session *ssh.PipeSession, logCh cha
logCh <- *logMessage
}

func sshmuxServer(configFile string) {
configFileBytes, err := os.ReadFile(configFile)
if err != nil {
log.Fatal(err)
}
err = json.Unmarshal(configFileBytes, &config)
func sshmuxListenAddr(address string, sshConfig *ssh.ServerConfig, proxyUpstreams []netip.Prefix) {
// set up TCP listener
listener, err := net.Listen("tcp", address)
if err != nil {
log.Fatal(err)
}
sshConfig := &ssh.ServerConfig{
ServerVersion: "SSH-2.0-taokystrong",
PublicKeyAuthAlgorithms: ssh.DefaultPubKeyAuthAlgos(),
}
for _, keyFile := range config.HostKeys {
bytes, err := os.ReadFile(keyFile)
if err != nil {
log.Fatal(err)
}
key, err := ssh.ParsePrivateKey(bytes)
if err != nil {
log.Fatal(err)
if len(proxyUpstreams) > 0 {
listener = &proxyproto.Listener{
Listener: listener,
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
// parse upstream address
upstreamAddrPort, err := netip.ParseAddrPort(upstream.String())
if err != nil {
return proxyproto.SKIP, nil
}
upstreamAddr := upstreamAddrPort.Addr()
// only read PROXY header from allowed CIDRs
for _, network := range proxyUpstreams {
if network.Contains(upstreamAddr) {
return proxyproto.USE, nil
}
}
// do nothing if upstream not in the allow list
return proxyproto.SKIP, nil
},
}
sshConfig.AddHostKey(key)
}
listener, err := net.Listen("tcp", config.Address)
if err != nil {
log.Fatal(err)
}
defer listener.Close()

// set up log channel
logCh := make(chan LogMessage, 256)
go runLogger(logCh)

// main handler loop
for {
conn, err := listener.Accept()
if err != nil {
Expand All @@ -401,6 +417,41 @@ func sshmuxServer(configFile string) {
}
}

func sshmuxServer(configFile string) {
configFileBytes, err := os.ReadFile(configFile)
if err != nil {
log.Fatal(err)
}
err = json.Unmarshal(configFileBytes, &config)
if err != nil {
log.Fatal(err)
}
sshConfig := &ssh.ServerConfig{
ServerVersion: "SSH-2.0-taokystrong",
PublicKeyAuthAlgorithms: ssh.DefaultPubKeyAuthAlgos(),
}
for _, keyFile := range config.HostKeys {
bytes, err := os.ReadFile(keyFile)
if err != nil {
log.Fatal(err)
}
key, err := ssh.ParsePrivateKey(bytes)
if err != nil {
log.Fatal(err)
}
sshConfig.AddHostKey(key)
}
proxyUpstreams := make([]netip.Prefix, 0)
for _, cidr := range config.ProxyCIDRs {
network, err := netip.ParsePrefix(cidr)
if err != nil {
log.Fatal(err)
}
proxyUpstreams = append(proxyUpstreams, network)
}
sshmuxListenAddr(config.Address, sshConfig, proxyUpstreams)
}

func main() {
flag.StringVar(&configFile, "c", "/etc/sshmux/config.json", "config file")
flag.Parse()
Expand Down
Loading

0 comments on commit 0438321

Please sign in to comment.