From 8a9d527e8ca4700be052e5fb74e9fc5300769665 Mon Sep 17 00:00:00 2001 From: Roy Razon Date: Sun, 11 Feb 2024 15:08:44 +0200 Subject: [PATCH] tunnel server: more configuration --- tunnel-server/index.ts | 175 ++++++++++++++------------- tunnel-server/src/files.ts | 13 ++ tunnel-server/src/metrics.ts | 11 +- tunnel-server/src/ssh/base-server.ts | 6 +- tunnel-server/src/tls-server.ts | 29 +++-- 5 files changed, 130 insertions(+), 104 deletions(-) create mode 100644 tunnel-server/src/files.ts diff --git a/tunnel-server/index.ts b/tunnel-server/index.ts index 51277d33..89acc98e 100644 --- a/tunnel-server/index.ts +++ b/tunnel-server/index.ts @@ -1,13 +1,13 @@ import { promisify } from 'util' import pino from 'pino' -import fs from 'fs' import { KeyObject, createPublicKey } from 'crypto' +import { ListenOptions } from 'net' import { createApp } from './src/app/index.js' import { activeTunnelStoreKey, inMemoryActiveTunnelStore } from './src/tunnel-store/index.js' import { getSSHKeys } from './src/ssh-keys.js' import { proxy } from './src/proxy/index.js' import { appLoggerFromEnv } from './src/logging.js' -import { tunnelsGauge, runMetricsServer, sshConnectionsGauge } from './src/metrics.js' +import { tunnelsGauge, metricsServer as createMetricsServer, sshConnectionsGauge } from './src/metrics.js' import { numberFromEnv, requiredEnv } from './src/env.js' import { editUrl } from './src/url.js' import { cookieSessionStore } from './src/session.js' @@ -15,6 +15,27 @@ import { IdentityProvider, claimsSchema, cliIdentityProvider, jwtAuthenticator, import { createSshServer } from './src/ssh/index.js' import { calcLoginUrl } from './src/app/urls.js' import { createTlsServer } from './src/tls-server.js' +import { readFileSyncOrUndefined } from './src/files.js' + +type HasListen = { + listen: (opts: ListenOptions, callback: (err?: unknown) => void) => void +} + +const LISTEN_HOST = '0.0.0.0' +const listen = async ({ log, server, port }: { + server: T + log: pino.Logger + port: number +}) => { + try { + await promisify(server.listen).call(server, { port, host: LISTEN_HOST }) + log.info('Listening on port %d', port) + } catch (e) { + log.error(new Error(`Error listening on port ${port}`, { cause: e })) + process.exit(1) + } + return server +} const log = pino.default(appLoggerFromEnv()) @@ -25,7 +46,6 @@ const { sshPrivateKey } = await getSSHKeys({ const PORT = numberFromEnv('PORT') || 3000 const SSH_PORT = numberFromEnv('SSH_PORT') || 2222 -const LISTEN_HOST = '0.0.0.0' const BASE_URL = (() => { const result = new URL(requiredEnv('BASE_URL')) if (result.pathname !== '/' || result.search || result.username || result.password || result.hash) { @@ -36,23 +56,11 @@ const BASE_URL = (() => { log.info('base URL: %s', BASE_URL) -const isNotFoundError = (e: unknown) => (e as { code?: unknown })?.code === 'ENOENT' -const readFileSyncOrUndefined = (filename: string) => { - try { - return fs.readFileSync(filename, { encoding: 'utf8' }) - } catch (e) { - if (isNotFoundError(e)) { - return undefined - } - throw e - } -} - const tlsConfig = (() => { - const cert = readFileSyncOrUndefined('./tls/cert.pem') - const key = readFileSyncOrUndefined('./tls/key.pem') + const cert = process.env.TLS_CERT || readFileSyncOrUndefined(process.env.TLS_CERT_FILE || './tls/cert.pem') + const key = process.env.TLS_KEY || readFileSyncOrUndefined(process.env.TLS_KEY_FILE || './tls/key.pem') if (!cert || !key) { - log.info('No TLS cert or key found, TLS will be disabled') + log.warn('No TLS cert or key found, TLS will be disabled') return undefined } log.info('TLS will be enabled') @@ -84,23 +92,29 @@ const authFactory = ( baseIdentityProviders.concat(cliIdentityProvider(publicKey, publicKeyThumbprint)), ) -const activeTunnelStore = inMemoryActiveTunnelStore({ log }) +const activeTunnelStore = inMemoryActiveTunnelStore({ log: log.child({ name: 'tunnel_store' }) }) const sessionStore = cookieSessionStore({ domain: BASE_URL.hostname, schema: claimsSchema, keys: process.env.COOKIE_SECRETS?.split(' ') }) -const app = await createApp({ - sessionStore, - activeTunnelStore, - baseUrl: BASE_URL, - proxy: proxy({ - activeTunnelStore, - log, + +const appLog = log.child({ name: 'app' }) +const app = await listen({ + server: await createApp({ sessionStore, - baseHostname: BASE_URL.hostname, + activeTunnelStore, + baseUrl: BASE_URL, + proxy: proxy({ + activeTunnelStore, + log, + sessionStore, + baseHostname: BASE_URL.hostname, + authFactory, + loginUrl: ({ env, returnPath }) => calcLoginUrl({ baseUrl: BASE_URL, env, returnPath }), + }), + log: appLog, authFactory, - loginUrl: ({ env, returnPath }) => calcLoginUrl({ baseUrl: BASE_URL, env, returnPath }), + saasBaseUrl: saasIdp ? new URL(requiredEnv('SAAS_BASE_URL')) : undefined, }), - log, - authFactory, - saasBaseUrl: saasIdp ? new URL(requiredEnv('SAAS_BASE_URL')) : undefined, + log: appLog, + port: PORT, }) const tunnelUrl = ( @@ -109,62 +123,59 @@ const tunnelUrl = ( tunnel: string, ) => editUrl(rootUrl, { hostname: `${activeTunnelStoreKey(clientId, tunnel)}.${rootUrl.hostname}` }).toString() -const sshServer = createSshServer({ - log: log.child({ name: 'ssh_server' }), - sshPrivateKey, - socketDir: '/tmp', // TODO - activeTunnelStore, - helloBaseResponse: { - // TODO: backwards compat, remove when we drop support for CLI v0.0.35 - baseUrl: { hostname: BASE_URL.hostname, port: BASE_URL.port, protocol: BASE_URL.protocol }, - rootUrl: BASE_URL.toString(), - }, - tunnelsGauge, - sshConnectionsGauge, - tunnelUrl: (clientId, remotePath) => tunnelUrl(BASE_URL, clientId, remotePath), -}) - .listen(SSH_PORT, LISTEN_HOST, () => { - app.log.debug('ssh server listening on port %j', SSH_PORT) - }) - -app.listen({ host: LISTEN_HOST, port: PORT }).catch(err => { - app.log.error(err) - process.exit(1) +const sshServerLog = log.child({ name: 'ssh_server' }) +const sshServer = await listen({ + server: createSshServer({ + log: sshServerLog, + sshPrivateKey, + socketDir: '/tmp', // TODO + activeTunnelStore, + helloBaseResponse: { + // TODO: backwards compat, remove when we drop support for CLI v0.0.35 + baseUrl: { hostname: BASE_URL.hostname, port: BASE_URL.port, protocol: BASE_URL.protocol }, + rootUrl: BASE_URL.toString(), + }, + tunnelsGauge, + sshConnectionsGauge, + tunnelUrl: (clientId, remotePath) => tunnelUrl(BASE_URL, clientId, remotePath), + }), + log: sshServerLog, + port: SSH_PORT, }) const TLS_PORT = numberFromEnv('TLS_PORT') ?? 8443 const tlsLog = log.child({ name: 'tls_server' }) const tlsServer = tlsConfig - ? createTlsServer({ + ? await listen({ + server: createTlsServer({ + log: tlsLog, + tlsConfig, + sshServer, + httpServer: + app.server, + sshHostnames: process.env.SSH_HOSTNAMES ? process.env.SSH_HOSTNAMES.split(',') : [BASE_URL.hostname], + }), + port: TLS_PORT, log: tlsLog, - tlsConfig, - sshServer, - httpServer: - app.server, - sshHostnames: new Set([BASE_URL.hostname]), - }) - : undefined - -tlsServer?.listen({ host: LISTEN_HOST, port: TLS_PORT }, () => { tlsLog.info('TLS server listening on port %j', TLS_PORT) }) - -runMetricsServer(8888).catch(err => { - app.log.error(err) -}); - -['SIGTERM', 'SIGINT'].forEach(signal => { - process.once(signal, () => { - app.log.info(`shutting down on ${signal}`) - Promise.all([ - promisify(sshServer.close).call(sshServer), - app.close(), - tlsServer ? promisify(tlsServer.close).call(tlsServer) : undefined, - ]) - .catch(err => { - app.log.error(err) - process.exit(1) - }) - .finally(() => { - process.exit(0) - }) + }) : undefined + +const metricsLerverLog = log.child({ name: 'metrics_server' }) +const metricsServer = await listen({ + server: createMetricsServer({ log: metricsLerverLog }), + port: 8888, + log: metricsLerverLog, +}) + +const exitSignals = ['SIGTERM', 'SIGINT'] as const +const servers = [app, sshServer, metricsServer, ...tlsServer ? [tlsServer] : []] as const + +exitSignals.forEach(signal => { + process.once(signal, async () => { + log.info(`Shutting down on ${signal}`) + await Promise.all(servers.map(server => promisify(server.close).call(server))).catch(err => { + log.error(err) + process.exit(1) + }) + process.exit(0) }) }) diff --git a/tunnel-server/src/files.ts b/tunnel-server/src/files.ts new file mode 100644 index 00000000..a4750d23 --- /dev/null +++ b/tunnel-server/src/files.ts @@ -0,0 +1,13 @@ +import fs from 'fs' + +const isNotFoundError = (e: unknown) => (e as { code?: unknown })?.code === 'ENOENT' +export const readFileSyncOrUndefined = (filename: string) => { + try { + return fs.readFileSync(filename, { encoding: 'utf8' }) + } catch (e) { + if (isNotFoundError(e)) { + return undefined + } + throw e + } +} diff --git a/tunnel-server/src/metrics.ts b/tunnel-server/src/metrics.ts index ad8d05e5..1345918d 100644 --- a/tunnel-server/src/metrics.ts +++ b/tunnel-server/src/metrics.ts @@ -1,4 +1,5 @@ import fastify from 'fastify' +import pino from 'pino' import { Gauge, Counter, register } from 'prom-client' export const sshConnectionsGauge = new Gauge({ @@ -21,16 +22,14 @@ export const requestsCounter = new Counter({ register.setDefaultLabels({ serviceName: 'preevy-tunnel-server' }) -export function runMetricsServer(port: number) { - const app = fastify() +export const metricsServer = ({ log }: { log: pino.Logger }) => { + const app = fastify({ logger: log }) app.get('/metrics', async (_request, reply) => { // TODO: changing the "void" below to await hangs, find out why and fix void reply.header('Content-Type', register.contentType) void reply.send(await register.metrics()) }) - return app.listen({ - host: '0.0.0.0', - port, - }) + + return app } diff --git a/tunnel-server/src/ssh/base-server.ts b/tunnel-server/src/ssh/base-server.ts index 4a7b7534..293e7af5 100644 --- a/tunnel-server/src/ssh/base-server.ts +++ b/tunnel-server/src/ssh/base-server.ts @@ -69,10 +69,7 @@ type BaseSshServerEvents = { error: (err: Error) => void } -export interface BaseSshServer extends IEventEmitter { - close: ssh2.Server['close'] - listen: ssh2.Server['listen'] -} +export type BaseSshServer = IEventEmitter & Pick export const baseSshServer = ( { @@ -304,5 +301,6 @@ export const baseSshServer = ( return Object.assign(serverEmitter, { close: server.close.bind(server), listen: server.listen.bind(server), + injectSocket: server.injectSocket.bind(server), }) } diff --git a/tunnel-server/src/tls-server.ts b/tunnel-server/src/tls-server.ts index ba4cce33..84d621ab 100644 --- a/tunnel-server/src/tls-server.ts +++ b/tunnel-server/src/tls-server.ts @@ -8,15 +8,20 @@ export const createTlsServer = ({ log, httpServer, sshServer, tlsConfig, sshHost httpServer: Pick sshServer: Pick tlsConfig: tls.TlsOptions - sshHostnames: Set -}) => tls.createServer(tlsConfig) - .on('error', err => { log.error(err) }) - .on('secureConnection', socket => { - const { servername } = (socket as { servername?: string }) - log.debug('TLS connection: %j', servername) - if (servername && sshHostnames.has(servername)) { - sshServer.injectSocket(socket) - } else { - httpServer.emit('connection', socket) - } - }) + sshHostnames: string[] +}) => { + log.info('SSH hostnames: %j', sshHostnames) + const sshHostnamesSet = new Set(sshHostnames) + + return tls.createServer(tlsConfig) + .on('error', err => { log.error(err) }) + .on('secureConnection', socket => { + const { servername } = (socket as { servername?: string }) + log.debug('TLS connection: %j', servername) + if (servername && sshHostnamesSet.has(servername)) { + sshServer.injectSocket(socket) + } else { + httpServer.emit('connection', socket) + } + }) +}