Skip to content

Commit

Permalink
Convert lib/srv to use slog (#49913)
Browse files Browse the repository at this point in the history
This migrates the rest of the srv package to use slog for logging.
Most sub-packages still however rely on logrus.
  • Loading branch information
rosstimothy authored Dec 11, 2024
1 parent 0389d5d commit a191cc1
Show file tree
Hide file tree
Showing 14 changed files with 204 additions and 213 deletions.
82 changes: 38 additions & 44 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
Expand Down Expand Up @@ -265,7 +264,7 @@ type ServerContext struct {
// ConnectionContext is the parent context which manages connection-level
// resources.
*sshutils.ConnectionContext
*log.Entry
Logger *slog.Logger

mu sync.RWMutex

Expand Down Expand Up @@ -434,17 +433,14 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
ServerSubKind: srv.TargetMetadata().ServerSubKind,
}

fields := log.Fields{
"local": child.ServerConn.LocalAddr(),
"remote": child.ServerConn.RemoteAddr(),
"login": child.Identity.Login,
"teleportUser": child.Identity.TeleportUser,
"id": child.id,
}
child.Entry = log.WithFields(log.Fields{
teleport.ComponentKey: child.srv.Component(),
teleport.ComponentFields: fields,
})
child.Logger = slog.With(
teleport.ComponentKey, srv.Component(),
"local_addr", child.ServerConn.LocalAddr(),
"remote_addr", child.ServerConn.RemoteAddr(),
"login", child.Identity.Login,
"teleport_user", child.Identity.TeleportUser,
"id", child.id,
)

if identityContext.Login == teleport.SSHSessionJoinPrincipal {
child.JoinOnly = true
Expand All @@ -462,15 +458,11 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s

// Update log entry fields.
if !child.disconnectExpiredCert.IsZero() {
fields["cert"] = child.disconnectExpiredCert
child.Logger = child.Logger.With("cert", child.disconnectExpiredCert)
}
if child.clientIdleTimeout != 0 {
fields["idle"] = child.clientIdleTimeout
child.Logger = child.Logger.With("idle", child.clientIdleTimeout)
}
child.Entry = log.WithFields(log.Fields{
teleport.ComponentKey: srv.Component(),
teleport.ComponentFields: fields,
})

clusterName, err := srv.GetAccessPoint().GetClusterName()
if err != nil {
Expand All @@ -491,11 +483,9 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
TeleportUser: child.Identity.TeleportUser,
Login: child.Identity.Login,
ServerID: child.srv.ID(),
// TODO(tross) update this to use the child logger
// once ServerContext is converted to use a slog.Logger
Logger: slog.Default(),
Emitter: child.srv,
EmitterContext: ctx,
Logger: child.Logger,
Emitter: child.srv,
EmitterContext: ctx,
}
for _, opt := range monitorOpts {
opt(&monitorConfig)
Expand Down Expand Up @@ -573,15 +563,15 @@ func (c *ServerContext) GetServer() Server {

// CreateOrJoinSession will look in the SessionRegistry for the session ID. If
// no session is found, a new one is created. If one is found, it is returned.
func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error {
func (c *ServerContext) CreateOrJoinSession(ctx context.Context, reg *SessionRegistry) error {
c.mu.Lock()
defer c.mu.Unlock()
// As SSH conversation progresses, at some point a session will be created and
// its ID will be added to the environment
ssid, found := c.getEnvLocked(sshutils.SessionEnvVar)
if !found {
c.sessionID = rsession.NewID()
c.Logger.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr())
c.Logger.DebugContext(ctx, "Will create new session for SSH connection")
return nil
}

Expand All @@ -595,7 +585,7 @@ func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error {
if sess, found := reg.findSession(*id); found {
c.sessionID = *id
c.session = sess
c.Logger.Debugf("Will join session %v for SSH connection %v.", c.session.id, c.ServerConn.RemoteAddr())
c.Logger.DebugContext(ctx, "Joining active SSH session", "session_id", c.session.id)
} else {
// TODO(capnspacehook): DELETE IN 17.0.0 - by then all supported
// clients should only set TELEPORT_SESSION when they want to
Expand All @@ -605,7 +595,7 @@ func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error {
// to prevent the user from controlling the session ID, generate
// a new one
c.sessionID = rsession.NewID()
c.Logger.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr())
c.Logger.DebugContext(ctx, "Creating new SSH session")
}

return nil
Expand Down Expand Up @@ -676,18 +666,18 @@ func (c *ServerContext) getEnvLocked(key string) (string, bool) {
}

// setSession sets the context's session
func (c *ServerContext) setSession(sess *session, ch ssh.Channel) {
func (c *ServerContext) setSession(ctx context.Context, sess *session, ch ssh.Channel) {
c.mu.Lock()
defer c.mu.Unlock()
c.session = sess

// inform the client of the session ID that is being used in a new
// goroutine to reduce latency
go func() {
c.Logger.Debug("Sending current session ID.")
c.Logger.DebugContext(ctx, "Sending current session ID")
_, err := ch.SendRequest(teleport.CurrentSessionIDRequest, false, []byte(sess.ID()))
if err != nil {
c.Logger.WithError(err).Debug("Failed to send the current session ID.")
c.Logger.DebugContext(ctx, "Failed to send the current session ID", "error", err)
}
}()
}
Expand Down Expand Up @@ -754,7 +744,7 @@ func (c *ServerContext) CheckSFTPAllowed(registry *SessionRegistry) error {
}

// OpenXServerListener opens a new XServer unix listener.
func (c *ServerContext) HandleX11Listener(l net.Listener, singleConnection bool) error {
func (c *ServerContext) HandleX11Listener(ctx context.Context, l net.Listener, singleConnection bool) error {
display, err := x11.ParseDisplayFromUnixSocket(l.Addr().String())
if err != nil {
return trace.Wrap(err)
Expand All @@ -780,7 +770,7 @@ func (c *ServerContext) HandleX11Listener(l net.Listener, singleConnection bool)
xconn, err := l.Accept()
if err != nil {
if !utils.IsOKNetworkError(err) {
c.Logger.WithError(err).Debug("Encountered error accepting XServer connection")
c.Logger.DebugContext(ctx, "Encountered error accepting XServer connection", "error", err)
}
return
}
Expand All @@ -790,7 +780,7 @@ func (c *ServerContext) HandleX11Listener(l net.Listener, singleConnection bool)

xchan, sin, err := c.ServerConn.OpenChannel(x11.ChannelRequest, x11ChannelReqPayload)
if err != nil {
c.Logger.WithError(err).Debug("Failed to open a new X11 channel")
c.Logger.DebugContext(ctx, "Failed to open a new X11 channel", "error", err)
return
}
defer xchan.Close()
Expand All @@ -802,12 +792,12 @@ func (c *ServerContext) HandleX11Listener(l net.Listener, singleConnection bool)
go func() {
err := sshutils.ForwardRequests(ctx, sin, c.RemoteSession)
if err != nil {
c.Logger.WithError(err).Debug("Failed to forward ssh request from server during X11 forwarding")
c.Logger.DebugContext(ctx, "Failed to forward ssh request from server during X11 forwarding", "error", err)
}
}()

if err := utils.ProxyConn(ctx, xconn, xchan); err != nil {
c.Logger.WithError(err).Debug("Encountered error during X11 forwarding")
c.Logger.DebugContext(ctx, "Encountered error during X11 forwarding", "error", err)
}
}()

Expand Down Expand Up @@ -884,7 +874,7 @@ func (c *ServerContext) reportStats(conn utils.Stater) {
sessionDataEvent.ConnectionMetadata.LocalAddr = c.ServerConn.LocalAddr().String()
}
if err := c.GetServer().EmitAuditEvent(c.GetServer().Context(), sessionDataEvent); err != nil {
c.WithError(err).Warn("Failed to emit session data event.")
c.Logger.WarnContext(c.GetServer().Context(), "Failed to emit session data event", "error", err)
}

// Emit TX and RX bytes to their respective Prometheus counters.
Expand Down Expand Up @@ -926,21 +916,21 @@ func (c *ServerContext) CancelFunc() context.CancelFunc {

// SendExecResult sends the result of execution of the "exec" command over the
// ExecResultCh.
func (c *ServerContext) SendExecResult(r ExecResult) {
func (c *ServerContext) SendExecResult(ctx context.Context, r ExecResult) {
select {
case c.ExecResultCh <- r:
default:
c.Infof("Blocked on sending exec result %v.", r)
c.Logger.InfoContext(ctx, "Blocked on sending exec result", "code", r.Code, "command", r.Command)
}
}

// SendSubsystemResult sends the result of running the subsystem over the
// SubsystemResultCh.
func (c *ServerContext) SendSubsystemResult(r SubsystemResult) {
func (c *ServerContext) SendSubsystemResult(ctx context.Context, r SubsystemResult) {
select {
case c.SubsystemResultCh <- r:
default:
c.Info("Blocked on sending subsystem result.")
c.Logger.InfoContext(ctx, "Blocked on sending subsystem result")
}
}

Expand Down Expand Up @@ -1005,7 +995,11 @@ func getPAMConfig(c *ServerContext) (*PAMConfig, error) {
// If the trait isn't passed by the IdP due to misconfiguration
// we fallback to setting a value which will indicate this.
if trace.IsNotFound(err) {
c.Logger.WithError(err).Warnf("Attempted to interpolate custom PAM environment with external trait but received SAML response does not contain claim")
c.Logger.WarnContext(
c.CancelContext(),
"Attempted to interpolate custom PAM environment with external trait but received SAML response does not contain claim",
"error", err,
)
continue
}

Expand Down Expand Up @@ -1120,11 +1114,11 @@ func buildEnvironment(ctx *ServerContext) []string {
// SSH_CONNECTION environment variables.
remoteHost, remotePort, err := net.SplitHostPort(ctx.ServerConn.RemoteAddr().String())
if err != nil {
ctx.Logger.Debugf("Failed to split remote address: %v.", err)
ctx.Logger.DebugContext(ctx.CancelContext(), "Failed to split remote address", "error", err)
} else {
localHost, localPort, err := net.SplitHostPort(ctx.ServerConn.LocalAddr().String())
if err != nil {
ctx.Logger.Debugf("Failed to split local address: %v.", err)
ctx.Logger.DebugContext(ctx.CancelContext(), "Failed to split local address", "error", err)
} else {
env.AddTrusted("SSH_CLIENT", fmt.Sprintf("%s %s %s", remoteHost, remotePort, localPort))
env.AddTrusted("SSH_CONNECTION", fmt.Sprintf("%s %s %s %s", remoteHost, remotePort, localHost, localPort))
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func TestCreateOrJoinSession(t *testing.T) {
ctx.SetEnv(sshutils.SessionEnvVar, tt.sessionID)
}

err = ctx.CreateOrJoinSession(registry)
err = ctx.CreateOrJoinSession(context.Background(), registry)
require.NoError(t, err)
require.False(t, ctx.sessionID.IsZero())
if tt.wantSameSessionID {
Expand Down
Loading

0 comments on commit a191cc1

Please sign in to comment.