diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 1dc151ed10a64..8afba1b361e45 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -515,17 +515,38 @@ func TestSessionAuditLog(t *testing.T) { x11Event := nextEvent() require.IsType(t, &apievents.X11Forward{}, x11Event, "expected X11Forward event but got event of tgsype %T", x11Event) - // Request a remote port forwarding listener. - listener, err := f.ssh.clt.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - // Start up a test server that uses the port forwarded listener. - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // LOCAL PORT FORWARDING + // Start up a test server that doesn't do any remote port forwarding + nonForwardServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "hello, world") })) - t.Cleanup(ts.Close) - ts.Listener = listener - ts.Start() + t.Cleanup(nonForwardServer.Close) + nonForwardServer.Start() + + // Each locally forwarded dial should result in a new "start" event and each closed connection should result in a "stop" + // event. Note that we don't know what port the server will forward the connection on, so we don't have an easy way to validate the + // event's addr field. + localConn, err := f.ssh.clt.DialContext(context.Background(), "tcp", nonForwardServer.Listener.Addr().String()) + require.NoError(t, err) + + e = nextEvent() + localForwardStart, ok := e.(*apievents.PortForward) + require.True(t, ok, "expected PortForward event but got event of type %T", e) + require.Equal(t, events.PortForwardLocalEvent, localForwardStart.GetType()) + require.Equal(t, events.PortForwardCode, localForwardStart.GetCode()) + + // closed connections should result in PortForwardLocal stop events + localConn.Close() + e = nextEvent() + localForwardStop, ok := e.(*apievents.PortForward) + require.True(t, ok, "expected PortForward event but got event of type %T", e) + require.Equal(t, events.PortForwardLocalEvent, localForwardStop.GetType()) + require.Equal(t, events.PortForwardStopCode, localForwardStop.GetCode()) + + // REMOTE PORT FORWARDING + // Creation of a port forwarded listener should generate PortForwardRemote start events + listener, err := f.ssh.clt.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) e = nextEvent() remoteForwardStart, ok := e.(*apievents.PortForward) @@ -534,49 +555,34 @@ func TestSessionAuditLog(t *testing.T) { require.Equal(t, events.PortForwardRemoteEvent, remoteForwardStart.GetType()) require.Equal(t, events.PortForwardCode, remoteForwardStart.GetCode()) - // Request forward to remote port. Each dial should result in a new event. Note that we don't - // know what port the server will forward the connection on, so we don't have an easy way to - // validate the event's addr field. - conn, err := f.ssh.clt.DialContext(context.Background(), "tcp", listener.Addr().String()) - require.NoError(t, err) - - // the order of PortForwardLocal events and PortForwardRemoteConn events are sometimes swapped but order doesn't matter, so we just - // need to ensure that we receive both - foundLocalForwardStart := false - foundConnForwardStart := false - for i := 0; i < 2; i += 1 { - e = nextEvent() - require.IsType(t, &apievents.PortForward{}, e, "expected PortForward event but got event of type %T", e) - if !foundLocalForwardStart && e.GetType() == events.PortForwardLocalEvent { - foundLocalForwardStart = e.GetCode() == events.PortForwardCode - continue - } + // Start up a test server that uses the remote port forwarded listener. + remoteForwardServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "hello, world") + })) + t.Cleanup(remoteForwardServer.Close) + remoteForwardServer.Listener = listener + remoteForwardServer.Start() - if !foundConnForwardStart && e.GetType() == events.PortForwardRemoteConnEvent { - foundConnForwardStart = e.GetCode() == events.PortForwardCode - } - } - require.True(t, foundLocalForwardStart && foundConnForwardStart) - - conn.Close() - // similar to above, order of stop events received is inconsistent and mostly irrelevant here - foundLocalForwardStop := false - foundConnForwardStop := false - for i := 0; i < 2; i += 1 { - e = nextEvent() - require.IsType(t, &apievents.PortForward{}, e, "expected PortForward event but got event of type %T", e) - if !foundLocalForwardStop && e.GetType() == events.PortForwardLocalEvent { - foundLocalForwardStop = e.GetCode() == events.PortForwardStopCode - continue - } + // Each dial to the remote listener should result in a new "start" event and each closed connection should result in a "stop" event. + // Note that we don't know what port the server will forward the connection on, so we don't have an easy way to validate the event's + // addr field. + remoteConn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + e = nextEvent() + remoteConnStart, ok := e.(*apievents.PortForward) + require.True(t, ok, "expected PortForward event but got event of type %T", e) + require.Equal(t, events.PortForwardRemoteConnEvent, remoteConnStart.GetType()) + require.Equal(t, events.PortForwardCode, remoteConnStart.GetCode()) - if !foundConnForwardStop && e.GetType() == events.PortForwardRemoteConnEvent { - foundConnForwardStop = e.GetCode() == events.PortForwardStopCode - } - } - require.True(t, foundLocalForwardStop && foundConnForwardStop) + remoteConn.Close() + e = nextEvent() + remoteConnStop, ok := e.(*apievents.PortForward) + require.True(t, ok, "expected PortForward event but got event of type %T", e) + require.Equal(t, events.PortForwardRemoteConnEvent, remoteConnStop.GetType()) + require.Equal(t, events.PortForwardStopCode, remoteConnStop.GetCode()) - ts.Close() + // Closing the server (and therefore the listener) should generate an PortForwardRemote stop event + remoteForwardServer.Close() e = nextEvent() remoteForwardStop, ok := e.(*apievents.PortForward) require.True(t, ok, "expected PortForward event but got event of type %T", e)