diff --git a/crypto/ssh/handshake.go b/crypto/ssh/handshake.go index 6cc6bb3..56cdc7c 100644 --- a/crypto/ssh/handshake.go +++ b/crypto/ssh/handshake.go @@ -699,17 +699,17 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") { supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",") extInfo := &extInfoMsg{ - NumExtensions: 1, + NumExtensions: 2, Payload: make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1), } extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs")) extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...) extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList)) extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...) - // extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com")) - // extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...) - // extInfo.Payload = appendInt(extInfo.Payload, 1) - // extInfo.Payload = append(extInfo.Payload, "0"...) + extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com")) + extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...) + extInfo.Payload = appendInt(extInfo.Payload, 1) + extInfo.Payload = append(extInfo.Payload, "0"...) if err := t.conn.writePacket(Marshal(extInfo)); err != nil { return err } diff --git a/crypto/ssh/pipe.go b/crypto/ssh/pipe.go index 8a991d6..8648e6b 100644 --- a/crypto/ssh/pipe.go +++ b/crypto/ssh/pipe.go @@ -442,12 +442,23 @@ func (s *PipeSession) Close() { } } -func pipe(dst, src packetConn) error { +func pipe(dst, src packetConn, handlePing bool) error { for { msg, err := src.readPacket() if err != nil { return err } + if handlePing && msg[0] == msgPing { + var ping pingMsg + if err := Unmarshal(msg, &ping); err != nil { + return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err) + } + err = src.writePacket(Marshal(pongMsg(ping))) + if err != nil { + return err + } + continue + } err = dst.writePacket(msg) if err != nil { return err @@ -459,11 +470,14 @@ func (s *PipeSession) RunPipe() error { c := make(chan error) go func() { defer s.Downstream.transport.Close() - c <- pipe(s.Downstream.transport, s.Upstream.transport) + c <- pipe(s.Downstream.transport, s.Upstream.transport, false) }() go func() { defer s.Upstream.transport.Close() - c <- pipe(s.Upstream.transport, s.Downstream.transport) + // If the upstream doesn't support ping@openssh.com, short-circuit with a pong response + upstream_ping_version := s.Upstream.extensions["ping@openssh.com"] + upstream_supports_ping := len(upstream_ping_version) == 1 && upstream_ping_version[0] == byte('0') + c <- pipe(s.Upstream.transport, s.Downstream.transport, !upstream_supports_ping) }() return <-c }