diff --git a/sshmux.go b/sshmux.go index c1c686b..d2ef6a8 100644 --- a/sshmux.go +++ b/sshmux.go @@ -32,8 +32,8 @@ type Server struct { ProxyPolicy ProxyPolicyConfig } -type UpstreamInformation struct { - Host string +type upstreamInformation struct { + Address string Signer ssh.Signer Password *string ProxyProtocol byte @@ -180,7 +180,7 @@ func (s *Server) handler(conn net.Conn) { func (s *Server) Handshake(session *ssh.PipeSession) error { hasSetUser := false var user string - var upstream *UpstreamInformation + var upstream *upstreamInformation if s.Banner != "" { err := session.Downstream.SendBanner(s.Banner) if err != nil { @@ -214,15 +214,11 @@ auth_requests: if upstreamResp.Port == 0 { upstreamResp.Port = 22 } - upstream = &UpstreamInformation{ + upstream = &upstreamInformation{ Signer: parsePrivateKey(upstreamResp.PrivateKey, upstreamResp.Certificate), Password: upstreamResp.Password, } - if host, err := netip.ParseAddr(upstreamResp.Host); err == nil { - upstream.Host = netip.AddrPortFrom(host, upstreamResp.Port).String() - } else { - upstream.Host = fmt.Sprintf("%s:%d", upstreamResp.Host, upstreamResp.Port) - } + upstream.Address = net.JoinHostPort(upstreamResp.Host, fmt.Sprintf("%d", upstreamResp.Port)) if upstreamResp.ProxyProtocol != nil { switch *upstreamResp.ProxyProtocol { case "v1": @@ -270,7 +266,7 @@ auth_requests: } } // Stage 2: connect to upstream - conn, err := net.Dial("tcp", upstream.Host) + conn, err := net.Dial("tcp", upstream.Address) if err != nil { return err } @@ -285,7 +281,7 @@ auth_requests: User: user, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - err = session.InitUpstream(conn, upstream.Host, sshConfig) + err = session.InitUpstream(conn, upstream.Address, sshConfig) if err != nil { return err }