-
Notifications
You must be signed in to change notification settings - Fork 2
/
sshtransport.go
310 lines (286 loc) · 9.71 KB
/
sshtransport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
package main
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)
// makePubkeyAuth builds the SSH authentication methods for the private key
// stored in keyFile.
//
// The file is read completely and parsed with ssh.ParsePrivateKey. On success
// a single-element slice holding a public-key auth method for that key is
// returned.
//
// Read and parse failures are reported with the file name and wrap the
// underlying cause, so callers can tell apart e.g. a missing file from an
// unparsable key via errors.Is / errors.As.
func makePubkeyAuth(keyFile string) ([]ssh.AuthMethod, error) {
	key, err := ioutil.ReadFile(keyFile)
	if err != nil {
		return nil, fmt.Errorf("unable to read private key file %s: %w", keyFile, err)
	}
	signer, err := ssh.ParsePrivateKey(key)
	if err != nil {
		return nil, fmt.Errorf("unable to parse private key file %s: %w", keyFile, err)
	}
	return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil
}
// sshTransport provides HTTP round trippers whose connections are dialed
// through SSH clients instead of directly over the network.
type sshTransport struct {
	// port is the SSH port used when connecting to a target host.
	port int
	// user is the SSH user name placed into the client configuration.
	user string
	// auth holds the SSH auth methods; it is populated by LoadFiles from keyFile.
	auth []ssh.AuthMethod
	// sshClientPool caches established SSH clients per host.
	sshClientPool *sshClientPool
	// TransportRegular is the HTTP transport that dials via dialContext.
	TransportRegular http.RoundTripper
	// TransportTLSSkipVerify is a clone of TransportRegular with TLS
	// certificate verification disabled.
	TransportTLSSkipVerify http.RoundTripper
	// keyFile is the path of the private key file read by LoadFiles.
	keyFile string
	// knownHostsFile is the path of the known hosts file read by LoadFiles.
	knownHostsFile string
	// knownHostsCallback is the host key callback built from knownHostsFile.
	knownHostsCallback ssh.HostKeyCallback
	// nextProxyAddr, when non-empty, replaces the requested address in
	// dialContext.
	nextProxyAddr string
}
// trackingSshClient wraps an ssh.Client and tracks
// all connections opened via DialContext and closed via conn.Close().
// This allows it to implement a safe CloseWhenFinished method,
// which can be used to delay closing of the SSH client until the last
// contained connection has been closed properly.
// This avoids crashing when the SSH connection aborts
// while there's still an inflight HTTP connection over an SSH
// channel.
//
// The struct contains a mutex and must therefore be used via pointer.
type trackingSshClient struct {
	// Client is the wrapped SSH client; its methods are promoted.
	*ssh.Client
	// mtx guards inflightConns and shouldClose.
	mtx sync.Mutex
	// inflightConns counts connections handed out by DialContext that have
	// not been released through connCloseCallback yet.
	inflightConns int64
	// shouldClose is set by CloseWhenFinished to record that closing of the
	// client has been requested.
	shouldClose bool
}
// trackingSshConn is a wrapper for net.Conn, which is used by
// trackingSshClient to ensure that closed connections are properly
// tracked in the client.
type trackingSshConn struct {
	// Conn is the wrapped connection; its methods are promoted.
	net.Conn
	// closeFunc is invoked by Close after the wrapped connection was closed.
	closeFunc func()
}
// Close closes the wrapped connection first and then notifies the owner via
// closeFunc. The error reported by the wrapped connection's Close is returned
// unchanged.
func (conn trackingSshConn) Close() error {
	// The deferred call runs after the return expression has been evaluated,
	// i.e. after the underlying Close has finished.
	defer conn.closeFunc()
	return conn.Conn.Close()
}
// DialContext opens a connection to addr on network n through the wrapped SSH
// client and registers it as inflight.
//
// The inflight counter is incremented before dialing. If dialing fails, the
// counter is released again via connCloseCallback and the dial result is
// returned as is. On success the connection is wrapped in a trackingSshConn
// so that closing it releases the counter.
func (c *trackingSshClient) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
	c.mtx.Lock()
	c.inflightConns++
	c.mtx.Unlock()

	rawConn, dialErr := c.Client.DialContext(ctx, n, addr)
	if dialErr == nil {
		return trackingSshConn{Conn: rawConn, closeFunc: c.connCloseCallback}, nil
	}

	// The connection never came to life, so give its slot back right away.
	c.connCloseCallback()
	return rawConn, dialErr
}
// connCloseCallback releases one inflight connection slot. It is called when
// a tracked connection is closed and when dialing a connection failed.
//
// If closing of the client was requested via CloseWhenFinished while
// connections were still inflight, the release of the last connection
// performs the postponed close. Without this, a client whose close was
// delayed would never be closed at all.
func (c *trackingSshClient) connCloseCallback() {
	c.mtx.Lock()
	defer c.mtx.Unlock()

	c.inflightConns--
	if !c.shouldClose || c.inflightConns > 0 {
		return
	}

	log.Trace("closing ssh transport connection after last inflight connection finished")
	// Best effort: there is no caller left that could act on the error, and
	// the client may already have been closed (e.g. by a broken connection).
	_ = c.Close()
}
// CloseWhenFinished records that the client should be closed and closes it
// immediately if no tracked connection is inflight.
//
// While connections are still inflight the client is left open, only the
// request is recorded in shouldClose, and nil is returned. Otherwise the
// result of closing the client is returned.
func (c *trackingSshClient) CloseWhenFinished() error {
	c.mtx.Lock()
	defer c.mtx.Unlock()

	c.shouldClose = true

	if c.inflightConns > 0 {
		log.WithFields(log.Fields{"inflightConns": c.inflightConns}).Trace("delaying closing of ssh transport connection due to active connections")
		return nil
	}

	log.Trace("closing ssh transport connection")
	return c.Close()
}
// NewSSHTransport creates an sshTransport for the given SSH user and port.
//
// keyFile and knownHostsFile are loaded right away through LoadFiles; if
// loading fails, that error is returned and no transport is created.
// nextProxyAddr is stored as is (see dialContext for how it is used).
// On success the HTTP transports are set up before the value is returned.
func NewSSHTransport(user, keyFile, knownHostsFile string, port int, nextProxyAddr string) (*sshTransport, error) {
	transport := &sshTransport{
		user:           user,
		port:           port,
		keyFile:        keyFile,
		knownHostsFile: knownHostsFile,
		nextProxyAddr:  nextProxyAddr,
		sshClientPool:  newSSHClientPool(),
	}

	if loadErr := transport.LoadFiles(); loadErr != nil {
		return nil, loadErr
	}

	transport.createTransports()
	return transport, nil
}
// LoadFiles (re)loads the private key from t.keyFile and the known hosts
// database from t.knownHostsFile.
//
// Both files are loaded before anything is stored, so a failure leaves the
// previously loaded auth methods and host key callback untouched instead of
// replacing only one of them. Returned errors wrap the underlying cause.
func (t *sshTransport) LoadFiles() error {
	auth, err := makePubkeyAuth(t.keyFile)
	if err != nil {
		return fmt.Errorf("failed to load private key file: %w", err)
	}
	knownHostsCallback, err := knownhosts.New(t.knownHostsFile)
	if err != nil {
		return fmt.Errorf("failed to load known hosts: %w", err)
	}

	t.auth = auth
	t.knownHostsCallback = knownHostsCallback
	return nil
}
// createTransports builds the two HTTP transports offered by sshTransport.
//
// Both dial through t.dialContext and share the same limits and timeouts;
// the second one is a clone of the first with TLS certificate verification
// switched off.
func (t *sshTransport) createTransports() {
	verifying := &http.Transport{
		// No HTTP proxy is configured on the transport itself.
		Proxy:                 nil,
		DialContext:           t.dialContext,
		MaxIdleConns:          100,
		IdleConnTimeout:       2 * timeoutDurationSeconds,
		ResponseHeaderTimeout: timeoutDurationSeconds,
		ExpectContinueTimeout: 1 * time.Second,
	}

	nonVerifying := verifying.Clone()
	nonVerifying.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}

	t.TransportRegular, t.TransportTLSSkipVerify = verifying, nonVerifying
}
// checkHostKey has the shape of an SSH host key callback but performs no
// verification: it logs a warning with the hostname, remote address and key
// and accepts every key by returning nil.
func (t *sshTransport) checkHostKey(hostname string, remote net.Addr, key ssh.PublicKey) error {
	details := log.Fields{"hostname": hostname, "remote": remote, "key": key}
	log.WithFields(details).Warn("blindly accepting host key")
	return nil
}
// dialContext is the DialContext hook of the HTTP transports. It opens a TCP
// connection to the requested port on 127.0.0.1 of the target host, tunnelled
// through an SSH connection to that host.
//
// Only network "tcp" is supported. If t.nextProxyAddr is set, it replaces
// addr before the address is split into host and port.
//
// If dialing through the SSH client fails, a keepalive request is sent to
// find out whether the SSH connection itself is broken:
//   - keepalive succeeds: the SSH connection is fine and the dial error is
//     returned as is;
//   - keepalive fails or times out: the client is removed from the pool and
//     scheduled for closing, so that the next call establishes a fresh SSH
//     connection. The dial error is returned.
//
// A nil connection is never returned together with a nil error.
func (t *sshTransport) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	if network != "tcp" {
		log.WithFields(log.Fields{"network": network, "addr": addr}).Error("network type not supported")
		return nil, fmt.Errorf("network type %s is not supported", network)
	}
	if t.nextProxyAddr != "" {
		addr = t.nextProxyAddr
	}
	targetHost, targetPort, err := splitAddr(addr)
	if err != nil {
		metricErrorsByType.WithLabelValues("address_parsing").Inc()
		return nil, errors.New("failed to parse address")
	}

	client, err := t.getSSHClient(targetHost)
	if err != nil {
		metricErrorsByType.WithLabelValues("ssh_connection").Inc()
		return nil, fmt.Errorf("failed to obtain ssh connection: %w", err)
	}

	log.WithFields(log.Fields{"port": targetPort}).Trace("connecting")
	conn, dialErr := client.DialContext(ctx, "tcp4", fmt.Sprintf("%s:%d", "127.0.0.1", targetPort))
	log.WithFields(log.Fields{"port": targetPort, "err": dialErr}).Trace("done")
	if dialErr == nil {
		return conn, nil
	}

	log.WithFields(log.Fields{"host": targetHost, "err": dialErr}).Debug("connection failed, sending keepalive")

	// Buffered so the goroutine can always deliver its result and exit, even
	// when nobody is listening anymore because the timeout fired first.
	errChan := make(chan error, 1)
	go func() {
		_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
		errChan <- err
	}()

	var keepaliveTimeout time.Duration = timeoutDurationSeconds / 2
	var keepAliveErr error
	select {
	case keepAliveErr = <-errChan:
		if keepAliveErr == nil {
			log.WithFields(log.Fields{"host": targetHost}).Debug("keepalive worked, this is not an ssh conn problem")
			return nil, dialErr
		}
		metricErrorsByType.WithLabelValues("ssh_keepalive_failure").Inc()
	case <-time.After(keepaliveTimeout):
		keepAliveErr = fmt.Errorf("failed to receive keepalive within %s, reconnecting", keepaliveTimeout)
		metricErrorsByType.WithLabelValues("ssh_keepalive_timeout").Inc()
	}

	log.WithFields(log.Fields{"host": targetHost, "err": keepAliveErr}).Debug("keepalive failed, reconnecting")
	// getSSHClient stores clients under the lower-cased host name, so the
	// same normalisation is needed to actually hit the pool entry.
	t.sshClientPool.delete(strings.ToLower(targetHost))
	// Don't close right away, there might still be inflight
	// requests which would otherwise crash as they reference
	// invalid memory:
	_ = client.CloseWhenFinished()
	metricSshKeepaliveFailuresTotal.Inc()

	return nil, dialErr
}
// getHostkeyAlgosFor returns the key types that the known hosts database
// lists for hostport.
//
// The known hosts callback is invoked with a placeholder address and a key
// that is never valid. The resulting error, if it is a *knownhosts.KeyError,
// carries the keys that were expected instead; their types are collected.
// If no key type could be collected, an error is returned together with an
// empty slice.
func (t *sshTransport) getHostkeyAlgosFor(hostport string) ([]string, error) {
	var (
		probeKey invalidPublicKey
		keyErr   *knownhosts.KeyError
		algos    []string
	)
	probeAddr := &net.TCPAddr{IP: []byte{0, 0, 0, 0}}

	lookupErr := t.knownHostsCallback(hostport, probeAddr, &probeKey)
	if errors.As(lookupErr, &keyErr) {
		for i := range keyErr.Want {
			algos = append(algos, keyErr.Want[i].Key.Type())
		}
	}

	if len(algos) == 0 {
		metricErrorsByType.WithLabelValues("ssh_host_key_unknown").Inc()
		return []string{}, fmt.Errorf("no matching known hosts entry for %s", hostport)
	}
	return algos, nil
}
// getSSHClient returns an SSH client for host, reusing a pooled client when
// one exists and establishing (and pooling) a new one otherwise.
//
// The host name is lower-cased before it is used as pool key and in the SSH
// address. Host key algorithms offered to the server are derived from the
// known hosts entries for the host; a host without such an entry is rejected.
//
// On any error the returned client is nil.
func (t *sshTransport) getSSHClient(host string) (*trackingSshClient, error) {
	host = strings.ToLower(host)
	if client, cached := t.sshClientPool.get(host); cached {
		log.WithFields(log.Fields{"host": host}).Trace("using cached ssh connection")
		return client, nil
	}

	sshAddr := fmt.Sprintf("%s:%d", host, t.port)
	knownHostAlgos, err := t.getHostkeyAlgosFor(sshAddr)
	if err != nil {
		return nil, err
	}
	upgradedHostKeyAlgos := upgradeHostKeyAlgos(knownHostAlgos)
	log.WithFields(log.Fields{"host": host, "HostKeyAlgorithms": upgradedHostKeyAlgos}).Trace("building ssh connection")
	clientConfig := &ssh.ClientConfig{
		User:              t.user,
		Auth:              t.auth,
		HostKeyCallback:   t.knownHostsCallback,
		HostKeyAlgorithms: upgradedHostKeyAlgos,
		Timeout:           timeoutDurationSeconds,
	}

	// TODO: This should use DialContext once this PR is merged:
	// https://github.com/golang/go/issues/64686
	plainClient, err := ssh.Dial("tcp", sshAddr, clientConfig)
	if err != nil {
		// Do not hand out a wrapper around a nil *ssh.Client.
		return nil, err
	}
	client := &trackingSshClient{Client: plainClient}

	log.WithFields(log.Fields{"host": host}).Trace("caching successful ssh connection")
	cachedClient, cached := t.sshClientPool.setOrGetCached(host, client)
	if cached {
		// We already checked above and did not have a cached client.
		// However, due to concurrent requests, another goroutine has stored
		// one in the meantime. Drop the newly created client and use the
		// cached one instead. The new client was never handed out, so the
		// close error is of no interest.
		_ = client.Close()
		return cachedClient, nil
	}
	return client, nil
}
// upgradeHostKeyAlgos translates key types as found in known_hosts files
// into the host key algorithms to advertise to an SSH server.
//
// known_hosts files contain key types such as ssh-rsa. Advertising ssh-rsa
// as an algorithm is undesirable because it is insecure and deprecated; the
// newer rsa-sha2-* methods work with the same key type. Each ssh-rsa entry
// is therefore replaced by rsa-sha2-512 followed by rsa-sha2-256. All other
// entries are passed through unchanged and in order.
//
// The result is always a non-nil slice.
func upgradeHostKeyAlgos(algos []string) []string {
	result := make([]string, 0, len(algos))
	for i := range algos {
		switch name := algos[i]; name {
		case "ssh-rsa":
			result = append(result, "rsa-sha2-512", "rsa-sha2-256")
		default:
			result = append(result, name)
		}
	}
	return result
}
// invalidPublicKeyName is what invalidPublicKey reports both as its key type
// and as its marshalled form.
const invalidPublicKeyName = "invalid public key"

// invalidPublicKey is a placeholder key that never verifies anything. It is
// passed to the known hosts callback by getHostkeyAlgosFor as the key to
// match against.
type invalidPublicKey struct{}

// Marshal returns a fixed byte sequence that is not a real key encoding.
func (invalidPublicKey) Marshal() []byte {
	return []byte(invalidPublicKeyName)
}

// Type returns a fixed name that is not a real key type.
func (invalidPublicKey) Type() string {
	return invalidPublicKeyName
}

// Verify rejects every signature.
func (invalidPublicKey) Verify([]byte, *ssh.Signature) error {
	return errors.New("this key is never valid")
}