Skip to content

Commit

Permalink
Allow fine-grained control on PROXY protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
stevapple committed Aug 26, 2024
1 parent a6d05fc commit c3ab411
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 38 deletions.
18 changes: 12 additions & 6 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type AuthResponse struct {
Challenges []AuthChallenge `json:"challenges,omitempty"`
Failure *AuthFailure `json:"failure,omitempty"`
Upstream *AuthUpstream `json:"upstream,omitempty"`
Proxy *AuthProxy `json:"proxy,omitempty"`
}

type AuthChallenge struct {
Expand All @@ -41,12 +42,17 @@ type AuthFailure struct {
}

type AuthUpstream struct {
Host string `json:"host"`
Port uint16 `json:"port,omitempty"`
PrivateKey string `json:"private_key,omitempty"`
Certificate string `json:"certificate,omitempty"`
Password *string `json:"password,omitempty"`
ProxyProtocol *string `json:"proxy_protocol,omitempty"`
Host string `json:"host"`
Port uint16 `json:"port,omitempty"`
PrivateKey string `json:"private_key,omitempty"`
Certificate string `json:"certificate,omitempty"`
Password *string `json:"password,omitempty"`
}

type AuthProxy struct {
Host string `json:"host,omitempty"`
Port uint16 `json:"port,omitempty"`
Protocol *string `json:"protocol,omitempty"`
}

type Authenticator interface {
Expand Down
21 changes: 11 additions & 10 deletions legacy_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,22 +127,23 @@ func (auth *LegacyAuthenticator) Auth(request AuthRequest, username string) (int
if err != nil {
return 500, nil, err
}
auth_upstream := AuthUpstream{
Host: address.Addr().String(),
Port: address.Port(),
PrivateKey: upstream.PrivateKey,
Certificate: upstream.Certificate,
Password: upstream.Password,
resp := AuthResponse{
Upstream: &AuthUpstream{
Host: address.Addr().String(),
Port: address.Port(),
PrivateKey: upstream.PrivateKey,
Certificate: upstream.Certificate,
Password: upstream.Password,
},
}
unix_password, has_unix_password := request.Payload["unix_password"]
if has_unix_password {
auth_upstream.Password = &unix_password
resp.Upstream.Password = &unix_password
}
if upstream.ProxyProtocol > 0 {
proxyProtocol := fmt.Sprintf("v%d", upstream.ProxyProtocol)
auth_upstream.ProxyProtocol = &proxyProtocol
protocolVersion := fmt.Sprintf("v%d", upstream.ProxyProtocol)
resp.Proxy = &AuthProxy{Protocol: &protocolVersion}
}
resp := AuthResponse{Upstream: &auth_upstream}
return 200, &resp, nil
}
return 403, &AuthResponse{}, nil
Expand Down
50 changes: 36 additions & 14 deletions sshmux.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ type Server struct {
}

type upstreamInformation struct {
Address string
Signer ssh.Signer
Password *string
ProxyProtocol byte
Address string
Signer ssh.Signer
Password *string
ProxyProtocol *byte
ProxyDestination string
}

func validateKey(config SSHKeyConfig) (ssh.Signer, error) {
Expand Down Expand Up @@ -219,15 +220,30 @@ auth_requests:
Password: upstreamResp.Password,
}
upstream.Address = net.JoinHostPort(upstreamResp.Host, fmt.Sprintf("%d", upstreamResp.Port))
if upstreamResp.ProxyProtocol != nil {
switch *upstreamResp.ProxyProtocol {
case "v1":
upstream.ProxyProtocol = 1
case "v2":
upstream.ProxyProtocol = 2
default:
return fmt.Errorf("unknown PROXY protocol version: %s", *upstreamResp.ProxyProtocol)
if resp.Proxy != nil {
proxyConfig := *resp.Proxy
// parse protocol version
var protocolVersion byte
if proxyConfig.Protocol != nil {
switch *proxyConfig.Protocol {
case "v1":
protocolVersion = 1
case "v2":
protocolVersion = 2
default:
return fmt.Errorf("unknown PROXY protocol version: %s", *proxyConfig.Protocol)
}
}
upstream.ProxyProtocol = &protocolVersion
// parse protocol destination
upstream.ProxyDestination = upstream.Address
if proxyConfig.Host == "" {
proxyConfig.Host = upstreamResp.Host
}
if proxyConfig.Port == 0 {
proxyConfig.Port = upstreamResp.Port
}
upstream.Address = net.JoinHostPort(proxyConfig.Host, fmt.Sprintf("%d", proxyConfig.Port))
}
break auth_requests
case 401:
Expand Down Expand Up @@ -270,8 +286,14 @@ auth_requests:
if err != nil {
return err
}
if upstream.ProxyProtocol > 0 {
header := proxyproto.HeaderProxyFromAddrs(upstream.ProxyProtocol, session.Downstream.RemoteAddr(), conn.RemoteAddr())
if upstream.ProxyProtocol != nil {
dest := conn.RemoteAddr()
if upstream.ProxyDestination != upstream.Address {
if addr, err := net.ResolveTCPAddr("tcp", upstream.ProxyDestination); err == nil {
dest = addr
}
}
header := proxyproto.HeaderProxyFromAddrs(*upstream.ProxyProtocol, session.Downstream.RemoteAddr(), dest)
_, err := header.WriteTo(conn)
if err != nil {
return err
Expand Down
19 changes: 11 additions & 8 deletions sshmux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,20 @@ func initHttp(sshPrivateKey []byte) {
return
}

upstream := map[string]any{"private_key": string(sshPrivateKey)}
res := map[string]any{
"upstream": map[string]any{
"host": sshdServerAddr.IP.String(),
"port": sshdServerAddr.Port,
"private_key": string(sshPrivateKey),
},
}
if enableProxy {
upstream["host"] = sshdProxiedAddr.IP.String()
upstream["port"] = sshdProxiedAddr.Port
upstream["proxy_protocol"] = "v2"
} else {
upstream["host"] = sshdServerAddr.IP.String()
upstream["port"] = sshdServerAddr.Port
res["proxy"] = map[string]any{
"host": sshdProxiedAddr.IP.String(),
"port": sshdProxiedAddr.Port,
}
}

res := map[string]any{"upstream": upstream}
jsonRes, err := json.Marshal(res)
if err != nil {
http.Error(w, "Cannot encode JSON", http.StatusInternalServerError)
Expand Down

0 comments on commit c3ab411

Please sign in to comment.