diff --git a/association_test.go b/association_test.go index 33307202..d7d71150 100644 --- a/association_test.go +++ b/association_test.go @@ -2564,16 +2564,62 @@ func TestAssocMaxMessageSize(t *testing.T) { }) } +// udpConnWrapper wraps a *net.UDPConn and implements net.Conn interface. +type udpConnWrapper struct { + conn *net.UDPConn + remoteAddr net.Addr +} + +func newUDPConnWrapper(conn *net.UDPConn, remoteAddr net.Addr) net.Conn { + return &udpConnWrapper{ + conn: conn, + remoteAddr: remoteAddr, + } +} + +// Implement the net.Conn interface methods +func (w *udpConnWrapper) Read(b []byte) (n int, err error) { + // w.conn.ReadFrom(b) + n, _, err = w.conn.ReadFrom(b) + return n, err +} + +func (w *udpConnWrapper) Write(b []byte) (n int, err error) { + return w.conn.WriteTo(b, w.remoteAddr) +} + +func (w *udpConnWrapper) Close() error { + return w.conn.Close() +} + +func (w *udpConnWrapper) LocalAddr() net.Addr { + return w.conn.LocalAddr() +} + +func (w *udpConnWrapper) RemoteAddr() net.Addr { + return w.remoteAddr +} + +func (w *udpConnWrapper) SetDeadline(t time.Time) error { + return w.conn.SetDeadline(t) +} + +func (w *udpConnWrapper) SetReadDeadline(t time.Time) error { + return w.conn.SetReadDeadline(t) +} + +func (w *udpConnWrapper) SetWriteDeadline(t time.Time) error { + return w.conn.SetWriteDeadline(t) +} + // crateUDPConnPair creates a pair of net.UDPConn objects that are connected with each other -func createUDPConnPair(t *testing.T) (*net.UDPConn, *net.UDPConn, error) { +func createUDPConnPair(t *testing.T) (net.Conn, net.Conn, error) { udp1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}) if err != nil { return nil, nil, err } addr1, ok := udp1.LocalAddr().(*net.UDPAddr) require.True(t, ok) - err = udp1.Close() - require.NoError(t, err) udp2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}) if err != nil { @@ -2581,20 +2627,18 @@ func createUDPConnPair(t *testing.T) (*net.UDPConn, *net.UDPConn, error) { } addr2, ok := udp2.LocalAddr().(*net.UDPAddr) require.True(t, ok) - err = udp2.Close() - require.NoError(t, err) - udp1, err = net.DialUDP("udp", addr1, addr2) + conn1 := newUDPConnWrapper(udp1, addr2) if err != nil { return nil, nil, err } - udp2, err = net.DialUDP("udp", addr2, addr1) + conn2 := newUDPConnWrapper(udp2, addr1) if err != nil { return nil, nil, err } - return udp1, udp2, nil + return conn1, conn2, nil } func createAssocs(t *testing.T) (*Association, *Association, error) {