From 8a43dba72020b6af3a5e48782c7e35ffa8aff8a0 Mon Sep 17 00:00:00 2001 From: YR Chen Date: Tue, 27 Aug 2024 03:00:09 +0800 Subject: [PATCH] Add control on PROXY protocol destination --- auth.go | 18 ++++++++++++------ legacy_auth.go | 4 ++-- sshmux.go | 50 ++++++++++++++++++++++++++++++++++++-------------- sshmux_test.go | 16 +++++++++------- 4 files changed, 59 insertions(+), 29 deletions(-) diff --git a/auth.go b/auth.go index 26f77d8..1e83e33 100644 --- a/auth.go +++ b/auth.go @@ -41,12 +41,18 @@ type AuthFailure struct { } type AuthUpstream struct { - Host string `json:"host"` - Port uint16 `json:"port,omitempty"` - PrivateKey string `json:"private_key,omitempty"` - Certificate string `json:"certificate,omitempty"` - Password *string `json:"password,omitempty"` - ProxyProtocol *string `json:"proxy_protocol,omitempty"` + Host string `json:"host"` + Port uint16 `json:"port,omitempty"` + PrivateKey string `json:"private_key,omitempty"` + Certificate string `json:"certificate,omitempty"` + Password *string `json:"password,omitempty"` + Proxy *AuthProxy `json:"proxy,omitempty"` +} + +type AuthProxy struct { + Host string `json:"host,omitempty"` + Port uint16 `json:"port,omitempty"` + Protocol *string `json:"protocol,omitempty"` } type Authenticator interface { diff --git a/legacy_auth.go b/legacy_auth.go index 1d24bdd..eadf435 100644 --- a/legacy_auth.go +++ b/legacy_auth.go @@ -139,8 +139,8 @@ func (auth *LegacyAuthenticator) Auth(request AuthRequest, username string) (int auth_upstream.Password = &unix_password } if upstream.ProxyProtocol > 0 { - proxyProtocol := fmt.Sprintf("v%d", upstream.ProxyProtocol) - auth_upstream.ProxyProtocol = &proxyProtocol + protocolVersion := fmt.Sprintf("v%d", upstream.ProxyProtocol) + auth_upstream.Proxy = &AuthProxy{Protocol: &protocolVersion} } resp := AuthResponse{Upstream: &auth_upstream} return 200, &resp, nil diff --git a/sshmux.go b/sshmux.go index d2ef6a8..274a80c 100644 --- a/sshmux.go +++ b/sshmux.go @@ -33,10 +33,11 @@ type Server struct { } type upstreamInformation struct { - Address string - Signer ssh.Signer - Password *string - ProxyProtocol byte + Address string + Signer ssh.Signer + Password *string + ProxyProtocol *byte + ProxyDestination string } func validateKey(config SSHKeyConfig) (ssh.Signer, error) { @@ -219,15 +220,30 @@ auth_requests: Password: upstreamResp.Password, } upstream.Address = net.JoinHostPort(upstreamResp.Host, fmt.Sprintf("%d", upstreamResp.Port)) - if upstreamResp.ProxyProtocol != nil { - switch *upstreamResp.ProxyProtocol { - case "v1": - upstream.ProxyProtocol = 1 - case "v2": - upstream.ProxyProtocol = 2 - default: - return fmt.Errorf("unknown PROXY protocol version: %s", *upstreamResp.ProxyProtocol) + if upstreamResp.Proxy != nil { + proxyConfig := *upstreamResp.Proxy + // parse protocol version + var protocolVersion byte + if proxyConfig.Protocol != nil { + switch *proxyConfig.Protocol { + case "v1": + protocolVersion = 1 + case "v2": + protocolVersion = 2 + default: + return fmt.Errorf("unknown PROXY protocol version: %s", *proxyConfig.Protocol) + } + } + upstream.ProxyProtocol = &protocolVersion + // parse protocol destination + upstream.ProxyDestination = upstream.Address + if proxyConfig.Host == "" { + proxyConfig.Host = upstreamResp.Host } + if proxyConfig.Port == 0 { + proxyConfig.Port = upstreamResp.Port + } + upstream.Address = net.JoinHostPort(proxyConfig.Host, fmt.Sprintf("%d", proxyConfig.Port)) } break auth_requests case 401: @@ -270,8 +286,14 @@ auth_requests: if err != nil { return err } - if upstream.ProxyProtocol > 0 { - header := proxyproto.HeaderProxyFromAddrs(upstream.ProxyProtocol, session.Downstream.RemoteAddr(), conn.RemoteAddr()) + if upstream.ProxyProtocol != nil { + dest := conn.RemoteAddr() + if upstream.ProxyDestination != upstream.Address { + if addr, err := net.ResolveTCPAddr("tcp", upstream.ProxyDestination); err == nil { + dest = addr + } + } + header := proxyproto.HeaderProxyFromAddrs(*upstream.ProxyProtocol, session.Downstream.RemoteAddr(), dest) _, err := header.WriteTo(conn) if err != nil { return err diff --git a/sshmux_test.go b/sshmux_test.go index 3afff3b..f9467e8 100644 --- a/sshmux_test.go +++ b/sshmux_test.go @@ -78,14 +78,16 @@ func initHttp(sshPrivateKey []byte) { return } - upstream := map[string]any{"private_key": string(sshPrivateKey)} + upstream := map[string]any{ + "host": sshdServerAddr.IP.String(), + "port": sshdServerAddr.Port, + "private_key": string(sshPrivateKey), + } if enableProxy { - upstream["host"] = sshdProxiedAddr.IP.String() - upstream["port"] = sshdProxiedAddr.Port - upstream["proxy_protocol"] = "v2" - } else { - upstream["host"] = sshdServerAddr.IP.String() - upstream["port"] = sshdServerAddr.Port + upstream["proxy"] = map[string]any{ + "host": sshdProxiedAddr.IP.String(), + "port": sshdProxiedAddr.Port, + } } res := map[string]any{"upstream": upstream}