Ensure the proxy service shuts down properly in case of partial init (#50710)

* Ensure the proxy service shuts down properly in case of partial init

* Apply suggestions from code review

Align logging punctuation with the new log style

Co-authored-by: rosstimothy <[email protected]>

---------

Co-authored-by: rosstimothy <[email protected]>
hugoShaka and rosstimothy authored Jan 7, 2025
1 parent a8d4362 commit 0ebacac
Showing 1 changed file with 158 additions and 135 deletions.
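The pattern behind this change, in brief: declare every proxy sub-service variable before starting anything, register the "proxy.shutdown" cleanup handler first, and nil-check each resource inside the handler so cleanup is safe to run even when initProxyEndpoint returns early with an error. Below is a minimal, self-contained Go sketch of that idea under assumed names (startProxy and registerOnExit are illustrative, not Teleport APIs):

package main

import (
	"fmt"
	"net"
	"net/http"
)

// startProxy mimics initProxyEndpoint: it may fail at any point, but the
// cleanup callback it registers must still release whatever was created.
func startProxy(registerOnExit func(func())) error {
	var (
		listener  net.Listener
		webServer *http.Server
	)

	// Register cleanup before starting the services, so it runs even if we
	// return early below. Every resource is nil-checked because any of them
	// may not have been initialized yet.
	registerOnExit(func() {
		if webServer != nil {
			_ = webServer.Close()
		}
		if listener != nil {
			_ = listener.Close()
		}
	})

	var err error
	listener, err = net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return err // partial init: the nil checks above keep cleanup safe
	}

	webServer = &http.Server{Handler: http.NewServeMux()}
	go func() { _ = webServer.Serve(listener) }()
	return nil
}

func main() {
	var onExit []func()
	register := func(f func()) { onExit = append(onExit, f) }

	err := startProxy(register)
	fmt.Println("init error:", err)

	// Run the registered cleanup regardless of how far initialization got.
	for _, f := range onExit {
		f()
	}
}

In the diff itself the registration is wrapped in a defer so the handler is installed even on an early error return, and the guarded resources are the real ones: the web servers, SSH proxy, gRPC servers, kube server, ALPN proxies, and peer-proxy servers hoisted into the var block at the top of the hunk.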
293 changes: 158 additions & 135 deletions lib/service/service.go
@@ -4385,11 +4385,162 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return trace.Wrap(err)
}

// We register the shutdown function before starting the services because we want to run it even if we encounter an
// error and return early. Some of the registered services don't watch if the context is Done (e.g. proxy.web).
// In case of error, if we don't run "proxy.shutdown", those registered services will run ad vitam aeternam and
// Supervisor.Wait() won't return.
var (
tsrv reversetunnelclient.Server
peerClient *peer.Client
peerQUICTransport *quic.Transport
rcWatcher *reversetunnel.RemoteClusterTunnelManager
peerServer *peer.Server
peerQUICServer *peerquic.Server
webServer *web.Server
minimalWebServer *web.Server
sshProxy *regular.Server
sshGRPCServer *grpc.Server
kubeServer *kubeproxy.TLSServer
grpcServerPublic *grpc.Server
grpcServerMTLS *grpc.Server
alpnServer *alpnproxy.Proxy
reverseTunnelALPNServer *alpnproxy.Proxy
clientTLSConfigGenerator *auth.ClientTLSConfigGenerator
)

defer func() {
// execute this when process is asked to exit:
process.OnExit("proxy.shutdown", func(payload interface{}) {
// Close the listeners at the beginning of shutdown, because we are not
// really guaranteed to be able to serve new requests if we're
// halfway through a shutdown, and double closing a listener is fine.
listeners.Close()
if payload == nil {
logger.InfoContext(process.ExitContext(), "Shutting down immediately")
if tsrv != nil {
warnOnErr(process.ExitContext(), tsrv.Close(), logger)
}
if rcWatcher != nil {
warnOnErr(process.ExitContext(), rcWatcher.Close(), logger)
}
if peerServer != nil {
warnOnErr(process.ExitContext(), peerServer.Close(), logger)
}
if peerQUICServer != nil {
warnOnErr(process.ExitContext(), peerQUICServer.Close(), logger)
}
if webServer != nil {
warnOnErr(process.ExitContext(), webServer.Close(), logger)
}
if minimalWebServer != nil {
warnOnErr(process.ExitContext(), minimalWebServer.Close(), logger)
}
if peerClient != nil {
warnOnErr(process.ExitContext(), peerClient.Stop(), logger)
}
if sshProxy != nil {
warnOnErr(process.ExitContext(), sshProxy.Close(), logger)
}
if sshGRPCServer != nil {
sshGRPCServer.Stop()
}
if kubeServer != nil {
warnOnErr(process.ExitContext(), kubeServer.Close(), logger)
}
if grpcServerPublic != nil {
grpcServerPublic.Stop()
}
if grpcServerMTLS != nil {
grpcServerMTLS.Stop()
}
if alpnServer != nil {
warnOnErr(process.ExitContext(), alpnServer.Close(), logger)
}
if reverseTunnelALPNServer != nil {
warnOnErr(process.ExitContext(), reverseTunnelALPNServer.Close(), logger)
}

if clientTLSConfigGenerator != nil {
clientTLSConfigGenerator.Close()
}
} else {
logger.InfoContext(process.ExitContext(), "Shutting down gracefully")
ctx := payloadContext(payload)
if tsrv != nil {
warnOnErr(ctx, tsrv.DrainConnections(ctx), logger)
}
if sshProxy != nil {
warnOnErr(ctx, sshProxy.Shutdown(ctx), logger)
}
if sshGRPCServer != nil {
sshGRPCServer.GracefulStop()
}
if webServer != nil {
warnOnErr(ctx, webServer.Shutdown(ctx), logger)
}
if minimalWebServer != nil {
warnOnErr(ctx, minimalWebServer.Shutdown(ctx), logger)
}
if tsrv != nil {
warnOnErr(ctx, tsrv.Shutdown(ctx), logger)
}
if rcWatcher != nil {
warnOnErr(ctx, rcWatcher.Close(), logger)
}
if peerServer != nil {
warnOnErr(ctx, peerServer.Shutdown(), logger)
}
if peerQUICServer != nil {
warnOnErr(ctx, peerQUICServer.Shutdown(ctx), logger)
}
if peerClient != nil {
peerClient.Shutdown(ctx)
}
if kubeServer != nil {
warnOnErr(ctx, kubeServer.Shutdown(ctx), logger)
}
if grpcServerPublic != nil {
grpcServerPublic.GracefulStop()
}
if grpcServerMTLS != nil {
grpcServerMTLS.GracefulStop()
}
if alpnServer != nil {
warnOnErr(ctx, alpnServer.Close(), logger)
}
if reverseTunnelALPNServer != nil {
warnOnErr(ctx, reverseTunnelALPNServer.Close(), logger)
}

// Explicitly deleting proxy heartbeats helps the behavior of
// reverse tunnel agents during rollouts, as otherwise they'll keep
// trying to reach proxies until the heartbeats expire.
if services.ShouldDeleteServerHeartbeatsOnShutdown(ctx) {
if err := conn.Client.DeleteProxy(ctx, process.Config.HostUUID); err != nil {
if !trace.IsNotFound(err) {
logger.WarnContext(ctx, "Failed to delete heartbeat", "error", err)
} else {
logger.DebugContext(ctx, "Failed to delete heartbeat", "error", err)
}
}
}

if clientTLSConfigGenerator != nil {
clientTLSConfigGenerator.Close()
}
}
if peerQUICTransport != nil {
_ = peerQUICTransport.Close()
_ = peerQUICTransport.Conn.Close()
}
warnOnErr(process.ExitContext(), asyncEmitter.Close(), logger)
warnOnErr(process.ExitContext(), conn.Close(), logger)
logger.InfoContext(process.ExitContext(), "Exited")
})
}()

// register SSH reverse tunnel server that accepts connections
// from remote teleport nodes
var tsrv reversetunnelclient.Server
var peerClient *peer.Client
var peerQUICTransport *quic.Transport
if !process.Config.Proxy.DisableReverseTunnel {
if listeners.proxyPeer != nil {
if process.Config.Proxy.QUICProxyPeering {
@@ -4530,8 +4681,6 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {

// Register web proxy server
alpnHandlerForWeb := &alpnproxy.ConnectionHandlerWrapper{}
var webServer *web.Server
var minimalWebServer *web.Server

if !process.Config.Proxy.DisableWebService {
var fs http.FileSystem
@@ -4775,8 +4924,6 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}

var peerAddrString string
var peerServer *peer.Server
var peerQUICServer *peerquic.Server
if !process.Config.Proxy.DisableReverseTunnel && listeners.proxyPeer != nil {
peerAddr, err := process.Config.Proxy.PublicPeerAddr()
if err != nil {
@@ -4869,7 +5016,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
logger.InfoContext(process.ExitContext(), "advertising proxy peering QUIC support")
}

sshProxy, err := regular.New(
sshProxy, err = regular.New(
process.ExitContext(),
cfg.SSH.Addr,
cfg.Hostname,
@@ -4932,7 +5079,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}

// clientTLSConfigGenerator pre-generates specialized per-cluster client TLS config values
clientTLSConfigGenerator, err := auth.NewClientTLSConfigGenerator(auth.ClientTLSConfigGeneratorConfig{
clientTLSConfigGenerator, err = auth.NewClientTLSConfigGenerator(auth.ClientTLSConfigGeneratorConfig{
TLS: sshGRPCTLSConfig,
ClusterName: clusterName,
PermitRemoteClusters: true,
@@ -4953,7 +5100,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return trace.Wrap(err)
}

sshGRPCServer := grpc.NewServer(
sshGRPCServer = grpc.NewServer(
grpc.ChainUnaryInterceptor(
interceptors.GRPCServerUnaryErrorInterceptor,
//nolint:staticcheck // SA1019. There is a data race in the stats.Handler that is replacing
@@ -5037,7 +5184,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
rcWatchLog := process.logger.With(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelAgent, process.id))

// Create and register reverse tunnel AgentPool.
rcWatcher, err := reversetunnel.NewRemoteClusterTunnelManager(reversetunnel.RemoteClusterTunnelManagerConfig{
rcWatcher, err = reversetunnel.NewRemoteClusterTunnelManager(reversetunnel.RemoteClusterTunnelManagerConfig{
HostUUID: conn.HostID(),
AuthClient: conn.Client,
AccessPoint: accessPoint,
@@ -5066,7 +5213,6 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return nil
})

var kubeServer *kubeproxy.TLSServer
if listeners.kube != nil && !process.Config.Proxy.DisableReverseTunnel {
authorizer, err := authz.NewAuthorizer(authz.AuthorizerOpts{
ClusterName: clusterName,
@@ -5284,10 +5430,6 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}
}

var (
grpcServerPublic *grpc.Server
grpcServerMTLS *grpc.Server
)
if alpnRouter != nil {
grpcServerPublic, err = process.initPublicGRPCServer(proxyLimiter, conn, listeners.grpcPublic)
if err != nil {
@@ -5314,8 +5456,6 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}
}

var alpnServer *alpnproxy.Proxy
var reverseTunnelALPNServer *alpnproxy.Proxy
if !cfg.Proxy.DisableTLS && !cfg.Proxy.DisableALPNSNIListener && listeners.web != nil {
authDialerService := alpnproxyauth.NewAuthProxyDialerService(
tsrv,
@@ -5378,123 +5518,6 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}
}

// execute this when process is asked to exit:
process.OnExit("proxy.shutdown", func(payload interface{}) {
// Close the listeners at the beginning of shutdown, because we are not
// really guaranteed to be capable to serve new requests if we're
// halfway through a shutdown, and double closing a listener is fine.
listeners.Close()
if payload == nil {
logger.InfoContext(process.ExitContext(), "Shutting down immediately.")
if tsrv != nil {
warnOnErr(process.ExitContext(), tsrv.Close(), logger)
}
warnOnErr(process.ExitContext(), rcWatcher.Close(), logger)
if peerServer != nil {
warnOnErr(process.ExitContext(), peerServer.Close(), logger)
}
if peerQUICServer != nil {
warnOnErr(process.ExitContext(), peerQUICServer.Close(), logger)
}
if webServer != nil {
warnOnErr(process.ExitContext(), webServer.Close(), logger)
}
if minimalWebServer != nil {
warnOnErr(process.ExitContext(), minimalWebServer.Close(), logger)
}
if peerClient != nil {
warnOnErr(process.ExitContext(), peerClient.Stop(), logger)
}
warnOnErr(process.ExitContext(), sshProxy.Close(), logger)
sshGRPCServer.Stop()
if kubeServer != nil {
warnOnErr(process.ExitContext(), kubeServer.Close(), logger)
}
if grpcServerPublic != nil {
grpcServerPublic.Stop()
}
if grpcServerMTLS != nil {
grpcServerMTLS.Stop()
}
if alpnServer != nil {
warnOnErr(process.ExitContext(), alpnServer.Close(), logger)
}
if reverseTunnelALPNServer != nil {
warnOnErr(process.ExitContext(), reverseTunnelALPNServer.Close(), logger)
}

if clientTLSConfigGenerator != nil {
clientTLSConfigGenerator.Close()
}
} else {
logger.InfoContext(process.ExitContext(), "Shutting down gracefully.")
ctx := payloadContext(payload)
if tsrv != nil {
warnOnErr(ctx, tsrv.DrainConnections(ctx), logger)
}
warnOnErr(ctx, sshProxy.Shutdown(ctx), logger)
sshGRPCServer.GracefulStop()
if webServer != nil {
warnOnErr(ctx, webServer.Shutdown(ctx), logger)
}
if minimalWebServer != nil {
warnOnErr(ctx, minimalWebServer.Shutdown(ctx), logger)
}
if tsrv != nil {
warnOnErr(ctx, tsrv.Shutdown(ctx), logger)
}
warnOnErr(ctx, rcWatcher.Close(), logger)
if peerServer != nil {
warnOnErr(ctx, peerServer.Shutdown(), logger)
}
if peerQUICServer != nil {
warnOnErr(ctx, peerQUICServer.Shutdown(ctx), logger)
}
if peerClient != nil {
peerClient.Shutdown(ctx)
}
if kubeServer != nil {
warnOnErr(ctx, kubeServer.Shutdown(ctx), logger)
}
if grpcServerPublic != nil {
grpcServerPublic.GracefulStop()
}
if grpcServerMTLS != nil {
grpcServerMTLS.GracefulStop()
}
if alpnServer != nil {
warnOnErr(ctx, alpnServer.Close(), logger)
}
if reverseTunnelALPNServer != nil {
warnOnErr(ctx, reverseTunnelALPNServer.Close(), logger)
}

// Explicitly deleting proxy heartbeats helps the behavior of
// reverse tunnel agents during rollouts, as otherwise they'll keep
// trying to reach proxies until the heartbeats expire.
if services.ShouldDeleteServerHeartbeatsOnShutdown(ctx) {
if err := conn.Client.DeleteProxy(ctx, process.Config.HostUUID); err != nil {
if !trace.IsNotFound(err) {
logger.WarnContext(ctx, "Failed to delete heartbeat.", "error", err)
} else {
logger.DebugContext(ctx, "Failed to delete heartbeat.", "error", err)
}
}
}

if clientTLSConfigGenerator != nil {
clientTLSConfigGenerator.Close()
}
}
if peerQUICTransport != nil {
_ = peerQUICTransport.Close()
_ = peerQUICTransport.Conn.Close()
}
warnOnErr(process.ExitContext(), asyncEmitter.Close(), logger)
warnOnErr(process.ExitContext(), conn.Close(), logger)
logger.InfoContext(process.ExitContext(), "Exited.")
})

return nil
}
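For context on the two branches inside the handler: a nil payload means an immediate exit (Close/Stop on every initialized service), while a non-nil payload carries a context and asks for a graceful drain (Shutdown/GracefulStop). A minimal sketch of that branching with the same nil guards, using assumed names and the grpc-go dependency only to show Stop versus GracefulStop (not Teleport code):

package main

import (
	"context"
	"net/http"
	"time"

	"google.golang.org/grpc"
)

// shutdownProxy sketches the payload branching used by the handler above.
func shutdownProxy(payload interface{}, webServer *http.Server, grpcServer *grpc.Server) {
	if payload == nil {
		// Immediate shutdown: close everything now, dropping in-flight work.
		if webServer != nil {
			_ = webServer.Close()
		}
		if grpcServer != nil {
			grpcServer.Stop()
		}
		return
	}

	// Graceful shutdown: derive a bounded context and drain connections.
	ctx, ok := payload.(context.Context)
	if !ok {
		ctx = context.Background()
	}
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	if webServer != nil {
		_ = webServer.Shutdown(ctx)
	}
	if grpcServer != nil {
		grpcServer.GracefulStop()
	}
}

func main() {
	// With no payload and no initialized services this is a no-op, which is
	// exactly the partial-init case the commit guards against.
	shutdownProxy(nil, nil, nil)
}

The nil-payload convention mirrors the existing OnExit contract in service.go (payloadContext(payload) in the graceful branch); the 30-second timeout in the sketch is an illustrative bound, not something this commit sets.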

