diff --git a/pkg/core/client.go b/pkg/core/client.go index 1d6655e1b2..e54b41a896 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -209,31 +209,15 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) { if err != nil { return nil, err } - // Send request - err = struc.Pack(stream, &clientRequest{ - UDP: false, - Host: host, - Port: port, - }) - if err != nil { - _ = stream.Close() - return nil, err - } - // Read response - var sr serverResponse - err = struc.Unpack(stream, &sr) - if err != nil { - _ = stream.Close() - return nil, err - } - if !sr.OK { - _ = stream.Close() - return nil, fmt.Errorf("connection rejected: %s", sr.Message) - } return &hyTCPConn{ Orig: stream, PseudoLocalAddr: session.LocalAddr(), PseudoRemoteAddr: session.RemoteAddr(), + handshakeRequest: &clientRequest{ + UDP: false, + Host: host, + Port: port, + }, }, nil } @@ -304,13 +288,48 @@ type hyTCPConn struct { Orig quic.Stream PseudoLocalAddr net.Addr PseudoRemoteAddr net.Addr + // For TCP connections + handshakeRead sync.Once + handshakeWrite sync.Once + handshakeRequest *clientRequest } func (w *hyTCPConn) Read(b []byte) (n int, err error) { + w.handshakeRead.Do(func() { + if w.handshakeRequest == nil { + return + } + defer func() { w.handshakeRequest = nil }() + // Read response + var sr serverResponse + err = struc.Unpack(w.Orig, &sr) + if err != nil { + _ = w.Close() + err = fmt.Errorf("handshake %s read failed: %s", w.PseudoRemoteAddr.String(), err) + return + } + if !sr.OK { + _ = w.Close() + err = fmt.Errorf("handshake %s connection rejected: %s", w.PseudoRemoteAddr.String(), sr.Message) + return + } + }) return w.Orig.Read(b) } func (w *hyTCPConn) Write(b []byte) (n int, err error) { + w.handshakeWrite.Do(func() { + if w.handshakeRequest == nil { + return + } + // Send request + err = struc.Pack(w.Orig, w.handshakeRequest) + if err != nil { + _ = w.Close() + err = fmt.Errorf("handshake %s write failed: %s", w.PseudoRemoteAddr.String(), err) + return + } + }) return w.Orig.Write(b) }