diff --git a/client.go b/client.go index 5f87184..97379ac 100644 --- a/client.go +++ b/client.go @@ -14,6 +14,7 @@ import ( type client struct { key clientKey + sshCert *ssh.Certificate sshConfig ssh.ClientConfig sshClient *ssh.Client httpClient *http.Client @@ -42,17 +43,16 @@ func (key *clientKey) String() string { // establishes the SSH connection and sets up the HTTP client func (client *client) connect() error { - sshClient, err := ssh.Dial("tcp", client.key.hostPort(), &client.sshConfig) if err != nil { metrics.connections.failed++ - log.Printf("SSH connection to %v failed: %v", client.key.String(), err) + log.Printf("SSH connection to %s failed: %v", client.key.String(), err) return err } client.sshClient = sshClient metrics.connections.established++ - log.Printf("SSH connection to %v established", client.key.String()) + log.Printf("SSH connection to %s established", client.key.String()) return nil } @@ -87,10 +87,10 @@ retry: if err == nil { metrics.forwardings.established++ - log.Printf("TCP forwarding via %v to %s established", client.key.String(), address) + log.Printf("TCP forwarding via %s to %s established", client.key.String(), address) } else { metrics.forwardings.failed++ - log.Printf("TCP forwarding via %v to %s failed: %s", client.key.String(), address, err) + log.Printf("TCP forwarding via %s to %s failed: %s", client.key.String(), address, err) } return conn, err diff --git a/client_test.go b/client_test.go index c35c003..a3c1945 100644 --- a/client_test.go +++ b/client_test.go @@ -139,6 +139,7 @@ func TestClientDialHTTPS(t *testing.T) { response.Body.Close() } } + func TestInvalidRequestURI(t *testing.T) { assert := assert.New(t) diff --git a/main.go b/main.go index 6c44efc..2e05da0 100644 --- a/main.go +++ b/main.go @@ -24,10 +24,9 @@ var home = func() string { return os.Getenv("HOME") }() -var sshKeyDir = or(os.Getenv("HOS_KEY_DIR"), filepath.Join(home, ".ssh")) - var ( - sshKeys = []string{ + sshKeyDir = envStr("HOS_KEY_DIR", filepath.Join(home, ".ssh")) + sshKeys = []string{ filepath.Join(sshKeyDir, "id_rsa"), filepath.Join(sshKeyDir, "id_ed25519"), } @@ -37,18 +36,10 @@ var ( // command line flags var ( - listen = or(os.Getenv("HOS_LISTEN"), "[::1]:8080") - enableMetrics = os.Getenv("HOS_METRICS") != "0" - sshUser = or(os.Getenv("HOS_USER"), "root") - sshTimeout = func() time.Duration { - dur := os.Getenv("HOS_TIMEOUT") - if dur != "" { - if d, err := time.ParseDuration(dur); err != nil { - return d - } - } - return 10 * time.Second - }() + listen = envStr("HOS_LISTEN", "[::1]:8080") + enableMetrics = envStr("HOS_METRICS", "1") != "0" + sshUser = envStr("HOS_USER", "root") + sshTimeout = envDur("HOS_TIMEOUT", 10*time.Second) ) // build flags @@ -97,9 +88,16 @@ func main() { log.Fatal(http.ListenAndServe(listen, nil)) } -func or(s, alt string) string { - if s != "" { +func envStr(name, fallback string) string { + if s := os.Getenv(name); s != "" { return s } - return alt + return fallback +} + +func envDur(name string, fallback time.Duration) time.Duration { + if dur, err := time.ParseDuration(os.Getenv(name)); err == nil { + return dur + } + return fallback } diff --git a/metrics.go b/metrics.go index f2179de..236363c 100644 --- a/metrics.go +++ b/metrics.go @@ -1,6 +1,8 @@ package main -import "github.com/prometheus/client_golang/prometheus" +import ( + "github.com/prometheus/client_golang/prometheus" +) type connectionStats struct { established uint @@ -8,40 +10,60 @@ type connectionStats struct { } type prometheusExporter struct { + certTTL *prometheus.Desc + connUp *prometheus.Desc + conns *prometheus.Desc + fwds *prometheus.Desc + connections connectionStats forwardings connectionStats } var ( - metrics = prometheusExporter{} - - variableLabels = []string{"state"} - sshConnectionUpDesc = prometheus.NewDesc("sshproxy_connection_up", "SSH connection up", []string{"host"}, nil) - sshConnectionsDesc = prometheus.NewDesc("sshproxy_connections_total", "SSH connections", variableLabels, nil) - sshForwardingsDesc = prometheus.NewDesc("sshproxy_forwardings_total", "TCP forwardings", variableLabels, nil) + connLabels = []string{"state"} + hostLabel = []string{"host"} ) +var metrics = prometheusExporter{ + certTTL: prometheus.NewDesc("sshproxy_certificate_ttl", "TTL until SSH certificate expires", hostLabel, nil), + connUp: prometheus.NewDesc("sshproxy_connection_up", "SSH connection up", hostLabel, nil), + conns: prometheus.NewDesc("sshproxy_connections_total", "SSH connections", connLabels, nil), + fwds: prometheus.NewDesc("sshproxy_forwardings_total", "TCP forwardings", connLabels, nil), +} + // Describe implements (part of the) prometheus.Collector interface. func (e *prometheusExporter) Describe(c chan<- *prometheus.Desc) { - c <- sshConnectionsDesc - c <- sshForwardingsDesc + c <- metrics.certTTL + c <- metrics.connUp + c <- metrics.conns + c <- metrics.fwds } // Collect implements (part of the) prometheus.Collector interface. func (e prometheusExporter) Collect(c chan<- prometheus.Metric) { const C = prometheus.CounterValue - c <- prometheus.MustNewConstMetric(sshConnectionsDesc, C, float64(e.connections.established), "established") - c <- prometheus.MustNewConstMetric(sshConnectionsDesc, C, float64(e.connections.failed), "failed") - c <- prometheus.MustNewConstMetric(sshForwardingsDesc, C, float64(e.forwardings.established), "established") - c <- prometheus.MustNewConstMetric(sshForwardingsDesc, C, float64(e.forwardings.failed), "failed") + const G = prometheus.GaugeValue + met := prometheus.MustNewConstMetric + + c <- met(metrics.conns, C, float64(e.connections.established), "established") + c <- met(metrics.conns, C, float64(e.connections.failed), "failed") + c <- met(metrics.fwds, C, float64(e.forwardings.established), "established") + c <- met(metrics.fwds, C, float64(e.forwardings.failed), "failed") proxy.mtx.Lock() for key, client := range proxy.clients { + host := key.String() + var up float64 if client.sshClient != nil { up = 1 } - c <- prometheus.MustNewConstMetric(sshConnectionUpDesc, prometheus.GaugeValue, up, key.String()) + c <- met(metrics.connUp, G, up, host) + + if cert := client.sshCert; cert != nil { + ttl := float64(cert.ValidBefore) + c <- met(metrics.certTTL, G, ttl, host) + } } proxy.mtx.Unlock() } diff --git a/proxy.go b/proxy.go index 7a811e0..7e8f0ea 100644 --- a/proxy.go +++ b/proxy.go @@ -38,6 +38,16 @@ func (proxy *Proxy) getClient(key clientKey) *client { key: key, sshConfig: proxy.sshConfig, // make copy } + pClient.sshConfig.HostKeyCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error { + if err := proxy.sshConfig.HostKeyCallback(hostname, remote, key); err != nil { + return err + } + if cert, ok := key.(*ssh.Certificate); ok && cert != nil { + pClient.sshCert = cert + } + return nil + } + if key.username != "" { pClient.sshConfig.User = key.username }