Skip to content

Commit

Permalink
adding additional audit log context around SSH port forwarding (#50932)
Browse files Browse the repository at this point in the history
  • Loading branch information
eriktate authored and mvbrock committed Jan 18, 2025
1 parent c22aca4 commit a0b8fb4
Show file tree
Hide file tree
Showing 11 changed files with 247 additions and 204 deletions.
11 changes: 7 additions & 4 deletions lib/events/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,13 @@ const (
X11ForwardErr = "error"

// Port forwarding event
PortForwardEvent = "port"
PortForwardAddr = "addr"
PortForwardSuccess = "success"
PortForwardErr = "error"
PortForwardEvent = "port"
PortForwardLocalEvent = "port.local"
PortForwardRemoteEvent = "port.remote"
PortForwardRemoteConnEvent = "port.remote_conn"
PortForwardAddr = "addr"
PortForwardSuccess = "success"
PortForwardErr = "error"

// AuthAttemptEvent is authentication attempt that either
// succeeded or failed based on event status
Expand Down
6 changes: 6 additions & 0 deletions lib/events/dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ func FromEventFields(fields EventFields) (events.AuditEvent, error) {
e = &events.X11Forward{}
case PortForwardEvent:
e = &events.PortForward{}
case PortForwardLocalEvent:
e = &events.PortForward{}
case PortForwardRemoteEvent:
e = &events.PortForward{}
case PortForwardRemoteConnEvent:
e = &events.PortForward{}
case AuthAttemptEvent:
e = &events.AuthAttempt{}
case SCPEvent:
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1269,19 +1269,19 @@ func (c *ServerContext) GetSessionMetadata() apievents.SessionMetadata {
}
}

func (c *ServerContext) GetPortForwardEvent() apievents.PortForward {
func (c *ServerContext) GetPortForwardEvent(evType, code, addr string) apievents.PortForward {
sconn := c.ConnectionContext.ServerConn
return apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardCode,
Type: evType,
Code: code,
},
UserMetadata: c.Identity.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: sconn.LocalAddr().String(),
RemoteAddr: sconn.RemoteAddr().String(),
},
Addr: c.DstAddr,
Addr: addr,
Status: apievents.Status{
Success: true,
},
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha
go io.Copy(io.Discard, ch.Stderr())
ch = scx.TrackActivity(ch)

event := scx.GetPortForwardEvent()
event := scx.GetPortForwardEvent(events.PortForwardEvent, events.PortForwardCode, scx.DstAddr)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err)
}
Expand Down Expand Up @@ -1120,7 +1120,7 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r
}
defer conn.Close()

event := scx.GetPortForwardEvent()
event := scx.GetPortForwardEvent(events.PortForwardEvent, events.PortForwardFailureCode, scx.DstAddr)
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
Expand Down
112 changes: 83 additions & 29 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1489,27 +1489,17 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
return
}

startEvent := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardCode, scx.DstAddr)
s.emitAuditEventWithLog(ctx, &startEvent)

if err := utils.ProxyConn(ctx, conn, channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
errEvent := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode, scx.DstAddr)
s.emitAuditEventWithLog(ctx, &errEvent)
scx.Logger.WarnContext(ctx, "Connection problem in direct-tcpip channel", "error", err)
}

if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardCode,
},
UserMetadata: scx.Identity.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: scx.ServerConn.LocalAddr().String(),
RemoteAddr: scx.ServerConn.RemoteAddr().String(),
},
Addr: scx.DstAddr,
Status: apievents.Status{
Success: true,
},
}); err != nil {
scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
stopEvent := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode, scx.DstAddr)
s.emitAuditEventWithLog(ctx, &stopEvent)
}

// handleSessionRequests handles out of band session requests once the session
Expand Down Expand Up @@ -1868,9 +1858,7 @@ func (s *Server) handleX11Forward(ctx context.Context, ch ssh.Channel, req *ssh.
s.replyError(ctx, ch, req, err)
err = nil
}
if err := s.EmitAuditEvent(s.ctx, event); err != nil {
scx.Logger.WarnContext(s.ctx, "Failed to emit x11-forward event", "error", err)
}
s.emitAuditEventWithLog(s.ctx, event)
}()

// check if X11 forwarding is disabled, or if xauth can't be handled.
Expand Down Expand Up @@ -2162,6 +2150,7 @@ func (s *Server) createForwardingContext(ctx context.Context, ccx *sshutils.Conn
if err != nil {
return nil, nil, trace.Wrap(err)
}

listenAddr := sshutils.JoinHostPort(req.Addr, req.Port)
scx.IsTestStub = s.isTestStub
scx.ExecType = teleport.TCPIPForwardRequest
Expand Down Expand Up @@ -2201,13 +2190,72 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co
}
scx.SrcAddr = sshutils.JoinHostPort(srcHost, listenPort)

event := scx.GetPortForwardEvent()
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}
if err := sshutils.StartRemoteListener(ctx, scx.ConnectionContext.ServerConn, scx.SrcAddr, listener); err != nil {
return trace.Wrap(err)
}
event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardCode, scx.SrcAddr)
s.emitAuditEventWithLog(ctx, &event)

// spawn remote forwarding handler to multiplex connections to the forwarded port
go func() {
stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardStopCode, scx.SrcAddr)
defer s.emitAuditEventWithLog(ctx, &stopEvent)

for {
conn, err := listener.Accept()
if err != nil {
if !utils.IsOKNetworkError(err) {
slog.WarnContext(ctx, "failed to accept connection", "error", err)
}
return
}
logger := slog.With(
"src_addr", scx.SrcAddr,
"remote_addr", conn.RemoteAddr().String(),
)

dstHost, dstPort, err := sshutils.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
conn.Close()
logger.WarnContext(ctx, "failed to parse addr", "error", err)
return
}

req := sshutils.ForwardedTCPIPRequest{
Addr: srcHost,
Port: listenPort,
OrigAddr: dstHost,
OrigPort: dstPort,
}
if err := req.CheckAndSetDefaults(); err != nil {
conn.Close()
logger.WarnContext(ctx, "failed to create forwarded tcpip request", "error", err)
return
}
reqBytes := ssh.Marshal(req)

ch, rch, err := scx.ConnectionContext.ServerConn.OpenChannel(teleport.ChanForwardedTCPIP, reqBytes)
if err != nil {
conn.Close()
logger.WarnContext(ctx, "failed to open channel", "error", err)
continue
}
go ssh.DiscardRequests(rch)
go io.Copy(io.Discard, ch.Stderr())
go func() {
startEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode, scx.SrcAddr)
startEvent.RemoteAddr = conn.RemoteAddr().String()
s.emitAuditEventWithLog(ctx, &startEvent)

if err := utils.ProxyConn(ctx, conn, ch); err != nil {
errEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardFailureCode, scx.SrcAddr)
errEvent.RemoteAddr = conn.RemoteAddr().String()
s.emitAuditEventWithLog(ctx, &errEvent)
}

stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode, scx.SrcAddr)
stopEvent.RemoteAddr = conn.RemoteAddr().String()
s.emitAuditEventWithLog(ctx, &stopEvent)
}()
}
}()

// Report addr back to the client.
if r.WantReply {
Expand Down Expand Up @@ -2250,14 +2298,14 @@ func (s *Server) handleCancelTCPIPForwardRequest(ctx context.Context, ccx *sshut
return trace.Wrap(err)
}
defer scx.Close()

listener, ok := s.remoteForwardingMap.LoadAndDelete(scx.SrcAddr)
if !ok {
return trace.NotFound("no remote forwarding listener at %v", scx.SrcAddr)
}
if err := r.Reply(true, nil); err != nil {
s.logger.WarnContext(ctx, "Failed to reply to request", "request_type", r.Type, "error", err)
}

return trace.Wrap(listener.Close())
}

Expand Down Expand Up @@ -2300,7 +2348,7 @@ func (s *Server) parseSubsystemRequest(ctx context.Context, req *ssh.Request, se
case r.Name == teleport.SFTPSubsystem:
err := serverContext.CheckSFTPAllowed(s.reg)
if err != nil {
s.EmitAuditEvent(context.Background(), &apievents.SFTP{
s.emitAuditEventWithLog(context.Background(), &apievents.SFTP{
Metadata: apievents.Metadata{
Code: events.SFTPDisallowedCode,
Type: events.SFTPEvent,
Expand Down Expand Up @@ -2347,3 +2395,9 @@ func (s *Server) handlePuTTYWinadj(ctx context.Context, req *ssh.Request) error
req.WantReply = false
return nil
}

func (s *Server) emitAuditEventWithLog(ctx context.Context, event apievents.AuditEvent) {
if err := s.EmitAuditEvent(ctx, event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit event", "type", event.GetType(), "code", event.GetCode())
}
}
82 changes: 65 additions & 17 deletions lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,6 @@ func TestSessionAuditLog(t *testing.T) {
roleOptions := role.GetOptions()
roleOptions.PermitX11Forwarding = types.NewBool(true)
roleOptions.ForwardAgent = types.NewBool(true)
//nolint:staticcheck // this field is preserved for existing deployments, but shouldn't be used going forward
roleOptions.PortForwarding = types.NewBoolOption(true)
role.SetOptions(roleOptions)
_, err = f.testSrv.Auth().UpsertRole(ctx, role)
require.NoError(t, err)
Expand Down Expand Up @@ -517,32 +515,82 @@ func TestSessionAuditLog(t *testing.T) {
x11Event := nextEvent()
require.IsType(t, &apievents.X11Forward{}, x11Event, "expected X11Forward event but got event of tgsype %T", x11Event)

// Request a remote port forwarding listener.
// LOCAL PORT FORWARDING
// Start up a test server that doesn't do any remote port forwarding
nonForwardServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello, world")
}))
t.Cleanup(nonForwardServer.Close)
nonForwardServer.Start()

// Each locally forwarded dial should result in a new "start" event and each closed connection should result in a "stop"
// event. Note that we don't know what port the server will forward the connection on, so we don't have an easy way to validate the
// event's addr field.
localConn, err := f.ssh.clt.DialContext(context.Background(), "tcp", nonForwardServer.Listener.Addr().String())
require.NoError(t, err)

e = nextEvent()
localForwardStart, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardLocalEvent, localForwardStart.GetType())
require.Equal(t, events.PortForwardCode, localForwardStart.GetCode())
require.Equal(t, nonForwardServer.Listener.Addr().String(), localForwardStart.Addr)

// closed connections should result in PortForwardLocal stop events
localConn.Close()
e = nextEvent()
localForwardStop, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardLocalEvent, localForwardStop.GetType())
require.Equal(t, events.PortForwardStopCode, localForwardStop.GetCode())
require.Equal(t, nonForwardServer.Listener.Addr().String(), localForwardStop.Addr)

// REMOTE PORT FORWARDING
// Creation of a port forwarded listener should generate PortForwardRemote start events
listener, err := f.ssh.clt.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

// Start up a test server that uses the port forwarded listener.
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
e = nextEvent()
remoteForwardStart, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, listener.Addr().String(), remoteForwardStart.Addr)
require.Equal(t, events.PortForwardRemoteEvent, remoteForwardStart.GetType())
require.Equal(t, events.PortForwardCode, remoteForwardStart.GetCode())

// Start up a test server that uses the remote port forwarded listener.
remoteForwardServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello, world")
}))
t.Cleanup(ts.Close)
ts.Listener = listener
ts.Start()
t.Cleanup(remoteForwardServer.Close)
remoteForwardServer.Listener = listener
remoteForwardServer.Start()

// Request forward to remote port. Each dial should result in a new event. Note that we don't
// know what port the server will forward the connection on, so we don't have an easy way to
// validate the event's addr field.
conn, err := f.ssh.clt.DialContext(context.Background(), "tcp", listener.Addr().String())
// Each dial to the remote listener should result in a new "start" event and each closed connection should result in a "stop" event.
// Note that we don't know what port the server will forward the connection on, so we don't have an easy way to validate the event's
// addr field.
remoteConn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
conn.Close()
e = nextEvent()
remoteConnStart, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardRemoteConnEvent, remoteConnStart.GetType())
require.Equal(t, events.PortForwardCode, remoteConnStart.GetCode())

directPortForwardEvent := nextEvent()
require.IsType(t, &apievents.PortForward{}, directPortForwardEvent, "expected PortForward event but got event of type %T", directPortForwardEvent)
remoteConn.Close()
e = nextEvent()
remoteConnStop, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardRemoteConnEvent, remoteConnStop.GetType())
require.Equal(t, events.PortForwardStopCode, remoteConnStop.GetCode())

// Closing the server (and therefore the listener) should generate an PortForwardRemote stop event
remoteForwardServer.Close()
e = nextEvent()
remotePortForwardEvent, ok := e.(*apievents.PortForward)
remoteForwardStop, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, listener.Addr().String(), remotePortForwardEvent.Addr)
require.Equal(t, events.PortForwardRemoteEvent, remoteForwardStop.GetType())
require.Equal(t, events.PortForwardStopCode, remoteForwardStop.Code)
require.Equal(t, listener.Addr().String(), remoteForwardStop.Addr)

// End the session. Session leave, data, and end events should be emitted.
se.Close()
Expand Down
8 changes: 0 additions & 8 deletions lib/sshutils/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,6 @@ func (mc *mockChannel) Stderr() io.ReadWriter {
return fakeReaderWriter{}
}

type mockSSHConn struct {
mockChan *mockChannel
}

func (mc *mockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
return mc.mockChan, make(<-chan *ssh.Request), nil
}

type mockSSHNewChannel struct {
mock.Mock
ssh.NewChannel
Expand Down
Loading

0 comments on commit a0b8fb4

Please sign in to comment.