Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving regular SSH port forwarding audit logs #50932

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions lib/events/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,13 @@ const (
X11ForwardErr = "error"

// Port forwarding event
PortForwardEvent = "port"
PortForwardAddr = "addr"
PortForwardSuccess = "success"
PortForwardErr = "error"
PortForwardEvent = "port"
PortForwardLocalEvent = "port.local"
PortForwardRemoteEvent = "port.remote"
PortForwardRemoteConnEvent = "port.remote_conn"
PortForwardAddr = "addr"
PortForwardSuccess = "success"
PortForwardErr = "error"

// AuthAttemptEvent is authentication attempt that either
// succeeded or failed based on event status
Expand Down
6 changes: 6 additions & 0 deletions lib/events/dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ func FromEventFields(fields EventFields) (events.AuditEvent, error) {
e = &events.X11Forward{}
case PortForwardEvent:
e = &events.PortForward{}
case PortForwardLocalEvent:
e = &events.PortForward{}
case PortForwardRemoteEvent:
e = &events.PortForward{}
case PortForwardRemoteConnEvent:
e = &events.PortForward{}
case AuthAttemptEvent:
e = &events.AuthAttempt{}
case SCPEvent:
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1269,19 +1269,19 @@ func (c *ServerContext) GetSessionMetadata() apievents.SessionMetadata {
}
}

func (c *ServerContext) GetPortForwardEvent() apievents.PortForward {
func (c *ServerContext) GetPortForwardEvent(evType, code, addr string) apievents.PortForward {
sconn := c.ConnectionContext.ServerConn
return apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardCode,
Type: evType,
Code: code,
},
UserMetadata: c.Identity.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: sconn.LocalAddr().String(),
RemoteAddr: sconn.RemoteAddr().String(),
},
Addr: c.DstAddr,
Addr: addr,
Status: apievents.Status{
Success: true,
},
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha
go io.Copy(io.Discard, ch.Stderr())
ch = scx.TrackActivity(ch)

event := scx.GetPortForwardEvent()
event := scx.GetPortForwardEvent(events.PortForwardEvent, events.PortForwardCode, scx.DstAddr)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err)
}
Expand Down Expand Up @@ -1120,7 +1120,7 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r
}
defer conn.Close()

event := scx.GetPortForwardEvent()
event := scx.GetPortForwardEvent(events.PortForwardEvent, events.PortForwardFailureCode, scx.DstAddr)
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
Expand Down
112 changes: 83 additions & 29 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1489,27 +1489,17 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
return
}

startEvent := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardCode, scx.DstAddr)
s.emitAuditEventWithLog(ctx, &startEvent)

if err := utils.ProxyConn(ctx, conn, channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
errEvent := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode, scx.DstAddr)
s.emitAuditEventWithLog(ctx, &errEvent)
scx.Logger.WarnContext(ctx, "Connection problem in direct-tcpip channel", "error", err)
}

if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardCode,
},
UserMetadata: scx.Identity.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: scx.ServerConn.LocalAddr().String(),
RemoteAddr: scx.ServerConn.RemoteAddr().String(),
},
Addr: scx.DstAddr,
Status: apievents.Status{
Success: true,
},
}); err != nil {
scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
stopEvent := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode, scx.DstAddr)
s.emitAuditEventWithLog(ctx, &stopEvent)
}

// handleSessionRequests handles out of band session requests once the session
Expand Down Expand Up @@ -1868,9 +1858,7 @@ func (s *Server) handleX11Forward(ctx context.Context, ch ssh.Channel, req *ssh.
s.replyError(ctx, ch, req, err)
err = nil
}
if err := s.EmitAuditEvent(s.ctx, event); err != nil {
scx.Logger.WarnContext(s.ctx, "Failed to emit x11-forward event", "error", err)
}
s.emitAuditEventWithLog(s.ctx, event)
}()

// check if X11 forwarding is disabled, or if xauth can't be handled.
Expand Down Expand Up @@ -2162,6 +2150,7 @@ func (s *Server) createForwardingContext(ctx context.Context, ccx *sshutils.Conn
if err != nil {
return nil, nil, trace.Wrap(err)
}

listenAddr := sshutils.JoinHostPort(req.Addr, req.Port)
scx.IsTestStub = s.isTestStub
scx.ExecType = teleport.TCPIPForwardRequest
Expand Down Expand Up @@ -2201,13 +2190,72 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co
}
scx.SrcAddr = sshutils.JoinHostPort(srcHost, listenPort)

event := scx.GetPortForwardEvent()
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}
if err := sshutils.StartRemoteListener(ctx, scx.ConnectionContext.ServerConn, scx.SrcAddr, listener); err != nil {
return trace.Wrap(err)
}
event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardCode, scx.SrcAddr)
s.emitAuditEventWithLog(ctx, &event)

// spawn remote forwarding handler to multiplex connections to the forwarded port
go func() {
stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardStopCode, scx.SrcAddr)
defer s.emitAuditEventWithLog(ctx, &stopEvent)

for {
conn, err := listener.Accept()
if err != nil {
if !utils.IsOKNetworkError(err) {
slog.WarnContext(ctx, "failed to accept connection", "error", err)
}
return
}
logger := slog.With(
"src_addr", scx.SrcAddr,
"remote_addr", conn.RemoteAddr().String(),
)

dstHost, dstPort, err := sshutils.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
conn.Close()
logger.WarnContext(ctx, "failed to parse addr", "error", err)
return
}

req := sshutils.ForwardedTCPIPRequest{
Addr: srcHost,
Port: listenPort,
OrigAddr: dstHost,
OrigPort: dstPort,
}
if err := req.CheckAndSetDefaults(); err != nil {
conn.Close()
logger.WarnContext(ctx, "failed to create forwarded tcpip request", "error", err)
return
}
reqBytes := ssh.Marshal(req)

ch, rch, err := scx.ConnectionContext.ServerConn.OpenChannel(teleport.ChanForwardedTCPIP, reqBytes)
if err != nil {
conn.Close()
logger.WarnContext(ctx, "failed to open channel", "error", err)
continue
}
go ssh.DiscardRequests(rch)
go io.Copy(io.Discard, ch.Stderr())
go func() {
startEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode, scx.SrcAddr)
startEvent.RemoteAddr = conn.RemoteAddr().String()
s.emitAuditEventWithLog(ctx, &startEvent)

if err := utils.ProxyConn(ctx, conn, ch); err != nil {
errEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardFailureCode, scx.SrcAddr)
errEvent.RemoteAddr = conn.RemoteAddr().String()
s.emitAuditEventWithLog(ctx, &errEvent)
}

stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode, scx.SrcAddr)
stopEvent.RemoteAddr = conn.RemoteAddr().String()
s.emitAuditEventWithLog(ctx, &stopEvent)
}()
}
}()

// Report addr back to the client.
if r.WantReply {
Expand Down Expand Up @@ -2250,14 +2298,14 @@ func (s *Server) handleCancelTCPIPForwardRequest(ctx context.Context, ccx *sshut
return trace.Wrap(err)
}
defer scx.Close()

listener, ok := s.remoteForwardingMap.LoadAndDelete(scx.SrcAddr)
if !ok {
return trace.NotFound("no remote forwarding listener at %v", scx.SrcAddr)
}
if err := r.Reply(true, nil); err != nil {
s.logger.WarnContext(ctx, "Failed to reply to request", "request_type", r.Type, "error", err)
}

return trace.Wrap(listener.Close())
}

Expand Down Expand Up @@ -2300,7 +2348,7 @@ func (s *Server) parseSubsystemRequest(ctx context.Context, req *ssh.Request, se
case r.Name == teleport.SFTPSubsystem:
err := serverContext.CheckSFTPAllowed(s.reg)
if err != nil {
s.EmitAuditEvent(context.Background(), &apievents.SFTP{
s.emitAuditEventWithLog(context.Background(), &apievents.SFTP{
Metadata: apievents.Metadata{
Code: events.SFTPDisallowedCode,
Type: events.SFTPEvent,
Expand Down Expand Up @@ -2347,3 +2395,9 @@ func (s *Server) handlePuTTYWinadj(ctx context.Context, req *ssh.Request) error
req.WantReply = false
return nil
}

func (s *Server) emitAuditEventWithLog(ctx context.Context, event apievents.AuditEvent) {
if err := s.EmitAuditEvent(ctx, event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit event", "type", event.GetType(), "code", event.GetCode())
}
}
82 changes: 65 additions & 17 deletions lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,6 @@ func TestSessionAuditLog(t *testing.T) {
roleOptions := role.GetOptions()
roleOptions.PermitX11Forwarding = types.NewBool(true)
roleOptions.ForwardAgent = types.NewBool(true)
//nolint:staticcheck // this field is preserved for existing deployments, but shouldn't be used going forward
roleOptions.PortForwarding = types.NewBoolOption(true)
role.SetOptions(roleOptions)
_, err = f.testSrv.Auth().UpsertRole(ctx, role)
require.NoError(t, err)
Expand Down Expand Up @@ -517,32 +515,82 @@ func TestSessionAuditLog(t *testing.T) {
x11Event := nextEvent()
require.IsType(t, &apievents.X11Forward{}, x11Event, "expected X11Forward event but got event of tgsype %T", x11Event)

// Request a remote port forwarding listener.
// LOCAL PORT FORWARDING
// Start up a test server that doesn't do any remote port forwarding
nonForwardServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello, world")
}))
t.Cleanup(nonForwardServer.Close)
nonForwardServer.Start()

// Each locally forwarded dial should result in a new "start" event and each closed connection should result in a "stop"
// event. Note that we don't know what port the server will forward the connection on, so we don't have an easy way to validate the
// event's addr field.
localConn, err := f.ssh.clt.DialContext(context.Background(), "tcp", nonForwardServer.Listener.Addr().String())
require.NoError(t, err)

e = nextEvent()
localForwardStart, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardLocalEvent, localForwardStart.GetType())
require.Equal(t, events.PortForwardCode, localForwardStart.GetCode())
require.Equal(t, nonForwardServer.Listener.Addr().String(), localForwardStart.Addr)

// closed connections should result in PortForwardLocal stop events
localConn.Close()
e = nextEvent()
localForwardStop, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardLocalEvent, localForwardStop.GetType())
require.Equal(t, events.PortForwardStopCode, localForwardStop.GetCode())
require.Equal(t, nonForwardServer.Listener.Addr().String(), localForwardStop.Addr)

// REMOTE PORT FORWARDING
// Creation of a port forwarded listener should generate PortForwardRemote start events
listener, err := f.ssh.clt.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

// Start up a test server that uses the port forwarded listener.
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
e = nextEvent()
remoteForwardStart, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, listener.Addr().String(), remoteForwardStart.Addr)
require.Equal(t, events.PortForwardRemoteEvent, remoteForwardStart.GetType())
require.Equal(t, events.PortForwardCode, remoteForwardStart.GetCode())

// Start up a test server that uses the remote port forwarded listener.
remoteForwardServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello, world")
}))
t.Cleanup(ts.Close)
ts.Listener = listener
ts.Start()
t.Cleanup(remoteForwardServer.Close)
remoteForwardServer.Listener = listener
remoteForwardServer.Start()

// Request forward to remote port. Each dial should result in a new event. Note that we don't
// know what port the server will forward the connection on, so we don't have an easy way to
// validate the event's addr field.
conn, err := f.ssh.clt.DialContext(context.Background(), "tcp", listener.Addr().String())
// Each dial to the remote listener should result in a new "start" event and each closed connection should result in a "stop" event.
// Note that we don't know what port the server will forward the connection on, so we don't have an easy way to validate the event's
// addr field.
remoteConn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
conn.Close()
e = nextEvent()
remoteConnStart, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardRemoteConnEvent, remoteConnStart.GetType())
require.Equal(t, events.PortForwardCode, remoteConnStart.GetCode())

directPortForwardEvent := nextEvent()
require.IsType(t, &apievents.PortForward{}, directPortForwardEvent, "expected PortForward event but got event of type %T", directPortForwardEvent)
remoteConn.Close()
e = nextEvent()
remoteConnStop, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, events.PortForwardRemoteConnEvent, remoteConnStop.GetType())
require.Equal(t, events.PortForwardStopCode, remoteConnStop.GetCode())

// Closing the server (and therefore the listener) should generate an PortForwardRemote stop event
remoteForwardServer.Close()
e = nextEvent()
remotePortForwardEvent, ok := e.(*apievents.PortForward)
remoteForwardStop, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, listener.Addr().String(), remotePortForwardEvent.Addr)
require.Equal(t, events.PortForwardRemoteEvent, remoteForwardStop.GetType())
require.Equal(t, events.PortForwardStopCode, remoteForwardStop.Code)
require.Equal(t, listener.Addr().String(), remoteForwardStop.Addr)

// End the session. Session leave, data, and end events should be emitted.
se.Close()
Expand Down
8 changes: 0 additions & 8 deletions lib/sshutils/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,6 @@ func (mc *mockChannel) Stderr() io.ReadWriter {
return fakeReaderWriter{}
}

type mockSSHConn struct {
mockChan *mockChannel
}

func (mc *mockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
return mc.mockChan, make(<-chan *ssh.Request), nil
}

type mockSSHNewChannel struct {
mock.Mock
ssh.NewChannel
Expand Down
Loading
Loading