Skip to content

Commit

Permalink
adding a second test server to ensure port forwarding audit events ar…
Browse files Browse the repository at this point in the history
…e deterministic
  • Loading branch information
eriktate committed Jan 16, 2025
1 parent 6c82c2c commit 7dad152
Showing 1 changed file with 55 additions and 49 deletions.
104 changes: 55 additions & 49 deletions lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,17 +515,38 @@ func TestSessionAuditLog(t *testing.T) {
x11Event := nextEvent()
require.IsType(t, &apievents.X11Forward{}, x11Event, "expected X11Forward event but got event of tgsype %T", x11Event)

// Request a remote port forwarding listener.
listener, err := f.ssh.clt.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

// Start up a test server that uses the port forwarded listener.
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// LOCAL PORT FORWARDING
// Start up a test server that doesn't do any remote port forwarding
nonForwardServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello, world")
}))
t.Cleanup(ts.Close)
ts.Listener = listener
ts.Start()
t.Cleanup(nonForwardServer.Close)
nonForwardServer.Start()

// Each locally forwarded dial should result in a new "start" event and each closed connection should result in a "stop"
// event. Note that we don't know what port the server will forward the connection on, so we don't have an easy way to validate the
// event's addr field.
localConn, err := f.ssh.clt.DialContext(context.Background(), "tcp", nonForwardServer.Listener.Addr().String())
require.NoError(t, err)

e = nextEvent()
localForwardStart, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardLocalEvent, localForwardStart.GetType())
require.Equal(t, events.PortForwardCode, localForwardStart.GetCode())

// closed connections should result in PortForwardLocal stop events
localConn.Close()
e = nextEvent()
localForwardStop, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardLocalEvent, localForwardStop.GetType())
require.Equal(t, events.PortForwardStopCode, localForwardStop.GetCode())

// REMOTE PORT FORWARDING
// Creation of a port forwarded listener should generate PortForwardRemote start events
listener, err := f.ssh.clt.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

e = nextEvent()
remoteForwardStart, ok := e.(*apievents.PortForward)
Expand All @@ -534,49 +555,34 @@ func TestSessionAuditLog(t *testing.T) {
require.Equal(t, events.PortForwardRemoteEvent, remoteForwardStart.GetType())
require.Equal(t, events.PortForwardCode, remoteForwardStart.GetCode())

// Request forward to remote port. Each dial should result in a new event. Note that we don't
// know what port the server will forward the connection on, so we don't have an easy way to
// validate the event's addr field.
conn, err := f.ssh.clt.DialContext(context.Background(), "tcp", listener.Addr().String())
require.NoError(t, err)

// the order of PortForwardLocal events and PortForwardRemoteConn events are sometimes swapped but order doesn't matter, so we just
// need to ensure that we receive both
foundLocalForwardStart := false
foundConnForwardStart := false
for i := 0; i < 2; i += 1 {
e = nextEvent()
require.IsType(t, &apievents.PortForward{}, e, "expected PortForward event but got event of type %T", e)
if !foundLocalForwardStart && e.GetType() == events.PortForwardLocalEvent {
foundLocalForwardStart = e.GetCode() == events.PortForwardCode
continue
}
// Start up a test server that uses the remote port forwarded listener.
remoteForwardServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello, world")
}))
t.Cleanup(remoteForwardServer.Close)
remoteForwardServer.Listener = listener
remoteForwardServer.Start()

if !foundConnForwardStart && e.GetType() == events.PortForwardRemoteConnEvent {
foundConnForwardStart = e.GetCode() == events.PortForwardCode
}
}
require.True(t, foundLocalForwardStart && foundConnForwardStart)

conn.Close()
// similar to above, order of stop events received is inconsistent and mostly irrelevant here
foundLocalForwardStop := false
foundConnForwardStop := false
for i := 0; i < 2; i += 1 {
e = nextEvent()
require.IsType(t, &apievents.PortForward{}, e, "expected PortForward event but got event of type %T", e)
if !foundLocalForwardStop && e.GetType() == events.PortForwardLocalEvent {
foundLocalForwardStop = e.GetCode() == events.PortForwardStopCode
continue
}
// Each dial to the remote listener should result in a new "start" event and each closed connection should result in a "stop" event.
// Note that we don't know what port the server will forward the connection on, so we don't have an easy way to validate the event's
// addr field.
remoteConn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
e = nextEvent()
remoteConnStart, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardRemoteConnEvent, remoteConnStart.GetType())
require.Equal(t, events.PortForwardCode, remoteConnStart.GetCode())

if !foundConnForwardStop && e.GetType() == events.PortForwardRemoteConnEvent {
foundConnForwardStop = e.GetCode() == events.PortForwardStopCode
}
}
require.True(t, foundLocalForwardStop && foundConnForwardStop)
remoteConn.Close()
e = nextEvent()
remoteConnStop, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardRemoteConnEvent, remoteConnStop.GetType())
require.Equal(t, events.PortForwardStopCode, remoteConnStop.GetCode())

ts.Close()
// Closing the server (and therefore the listener) should generate an PortForwardRemote stop event
remoteForwardServer.Close()
e = nextEvent()
remoteForwardStop, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
Expand Down

0 comments on commit 7dad152

Please sign in to comment.