From a9bc4d8ce859be800c365e4395e669eb623b8222 Mon Sep 17 00:00:00 2001 From: Ian Booth Date: Thu, 8 Feb 2024 16:38:53 +1000 Subject: [PATCH] Fix flakey tests which use a test ssh server --- ssh/ssh_gocrypto_test.go | 117 +++++++++++++++++++++++++++++++-------- 1 file changed, 95 insertions(+), 22 deletions(-) diff --git a/ssh/ssh_gocrypto_test.go b/ssh/ssh_gocrypto_test.go index 86f20e14..344a6d4e 100644 --- a/ssh/ssh_gocrypto_test.go +++ b/ssh/ssh_gocrypto_test.go @@ -17,6 +17,7 @@ import ( "path/filepath" "regexp" "sync" + "time" "github.com/juju/testing" jc "github.com/juju/testing/checkers" @@ -38,26 +39,42 @@ type sshServer struct { client *cryptossh.Client } -func (s *sshServer) run(c *gc.C) { +func (s *sshServer) run(errorCh chan error, done chan bool) { netconn, err := s.listener.Accept() - c.Assert(err, jc.ErrorIsNil) + if err != nil { + errorCh <- fmt.Errorf("accepting connection: %w", err) + return + } defer netconn.Close() conn, chans, reqs, err := cryptossh.NewServerConn(netconn, s.cfg) - c.Assert(err, jc.ErrorIsNil) + if err != nil { + errorCh <- fmt.Errorf("getting ssh server connection: %w", err) + return + } s.client = cryptossh.NewClient(conn, chans, reqs) var wg sync.WaitGroup - defer wg.Wait() + defer func() { + wg.Wait() + close(errorCh) + }() sessionChannels := s.client.HandleChannelOpen("session") - c.Assert(sessionChannels, gc.NotNil) - for newChannel := range sessionChannels { - c.Assert(newChannel.ChannelType(), gc.Equals, "session") + select { + case <-done: + return + case newChannel := <-sessionChannels: + if sCh := newChannel.ChannelType(); sCh != "session" { + errorCh <- fmt.Errorf("unexpected session channel %q", sCh) + return + } channel, reqs, err := newChannel.Accept() - - c.Assert(err, jc.ErrorIsNil) + if err != nil { + errorCh <- fmt.Errorf("accepting session connection: %w", err) + return + } wg.Add(1) go func() { @@ -67,18 +84,30 @@ func (s *sshServer) run(c *gc.C) { for req := range reqs { switch req.Type { case "exec": - c.Assert(req.WantReply, jc.IsTrue) + if !req.WantReply { + errorCh <- fmt.Errorf("no reply wanted for request %+v", req) + return + } n := binary.BigEndian.Uint32(req.Payload[:4]) command := string(req.Payload[4 : n+4]) - c.Assert(command, gc.Equals, testCommandFlat) - req.Reply(true, nil) + if command != testCommandFlat { + errorCh <- fmt.Errorf("unexpected request command: %q", command) + return + } + err = req.Reply(true, nil) + if err != nil { + errorCh <- fmt.Errorf("error sending reply: %w", err) + return + } channel.Write([]byte("abc value\n")) _, err := channel.SendRequest("exit-status", false, cryptossh.Marshal(&struct{ n uint32 }{0})) - c.Check(err, jc.ErrorIsNil) + if err != nil { + errorCh <- fmt.Errorf("error sending request: %w", err) + } return default: - c.Errorf("Unexpected request type: %v", req.Type) + errorCh <- fmt.Errorf("unexpected request type: %q", req.Type) return } } @@ -206,6 +235,16 @@ func (s *SSHGoCryptoCommandSuite) TestClientNoKeys(c *gc.C) { c.Assert(err, gc.ErrorMatches, "ssh.Dial failed") } +func waitForServer(c *gc.C, errorCh chan error) error { + select { + case err, _ := <-errorCh: + return err + case <-time.After(testing.LongWait): + c.Fatal("timed out waiting for ssh server") + return nil + } +} + func (s *SSHGoCryptoCommandSuite) TestCommand(c *gc.C) { client, clientKey := newClient(c) server, serverKey := s.newServer(c, cryptossh.ServerConfig{}) @@ -220,7 +259,11 @@ func (s *SSHGoCryptoCommandSuite) TestCommand(c *gc.C) { checkedKey = true return nil, nil } - go server.run(c) + errorCh := make(chan error, 1) + done := make(chan bool) + defer close(done) + go server.run(errorCh, done) + out, err := cmd.Output() c.Assert(err, jc.ErrorIsNil) c.Assert(string(out), gc.Equals, "abc value\n") @@ -233,6 +276,7 @@ func (s *SSHGoCryptoCommandSuite) TestCommand(c *gc.C) { serverPort, cryptossh.MarshalAuthorizedKey(serverKey)), ) + c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil) } func (s *SSHGoCryptoCommandSuite) TestCopy(c *gc.C) { @@ -262,7 +306,11 @@ func (s *SSHGoCryptoCommandSuite) TestProxyCommand(c *gc.C) { server.cfg.PublicKeyCallback = func(_ cryptossh.ConnMetadata, pubkey cryptossh.PublicKey) (*cryptossh.Permissions, error) { return nil, nil } - go server.run(c) + errorCh := make(chan error, 1) + done := make(chan bool) + defer close(done) + go server.run(errorCh, done) + out, err := cmd.Output() c.Assert(err, jc.ErrorIsNil) c.Assert(string(out), gc.Equals, "abc value\n") @@ -270,12 +318,16 @@ func (s *SSHGoCryptoCommandSuite) TestProxyCommand(c *gc.C) { data, err := ioutil.ReadFile(netcat + ".args") c.Assert(err, jc.ErrorIsNil) c.Assert(string(data), gc.Equals, fmt.Sprintf("%s -q0 127.0.0.1 %v\n", netcat, port)) + c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksYes(c *gc.C) { server, _ := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port - go server.run(c) + errorCh := make(chan error, 1) + done := make(chan bool) + defer close(done) + go server.run(errorCh, done) var opts ssh.Options opts.SetPort(serverPort) @@ -289,12 +341,16 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksYes(c *gc.C) { )) _, err = os.Stat(s.knownHostsFile) c.Assert(err, jc.Satisfies, os.IsNotExist) + _ = waitForServer(c, errorCh) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskNonTerminal(c *gc.C) { server, _ := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port - go server.run(c) + errorCh := make(chan error, 1) + done := make(chan bool) + defer close(done) + go server.run(errorCh, done) var opts ssh.Options opts.SetPort(serverPort) @@ -305,6 +361,7 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskNonTerminal(c *gc.C) { c.Assert(err, gc.ErrorMatches, "ssh: handshake failed: not running in a terminal, cannot prompt for verification") _, err = os.Stat(s.knownHostsFile) c.Assert(err, jc.Satisfies, os.IsNotExist) + _ = waitForServer(c, errorCh) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalYes(c *gc.C) { @@ -315,7 +372,10 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalYes(c *gc.C) { server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port - go server.run(c) + errorCh := make(chan error, 1) + done := make(chan bool) + defer close(done) + go server.run(errorCh, done) var opts ssh.Options opts.SetPort(serverPort) @@ -338,6 +398,7 @@ The authenticity of host '127.0.0.1:%[1]d (127.0.0.1:%[1]d)' can't be establishe ssh-ed25519 key fingerprint is %[2]s. Are you sure you want to continue connecting (yes/no)? Please type 'yes' or 'no': `[1:], serverPort, cryptossh.FingerprintSHA256(serverKey))) + c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalNo(c *gc.C) { @@ -347,7 +408,10 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalNo(c *gc.C) { server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port - go server.run(c) + errorCh := make(chan error, 1) + done := make(chan bool) + defer close(done) + go server.run(errorCh, done) var opts ssh.Options opts.SetPort(serverPort) @@ -365,6 +429,7 @@ The authenticity of host '127.0.0.1:%[1]d (127.0.0.1:%[1]d)' can't be establishe ssh-ed25519 key fingerprint is %[2]s. Are you sure you want to continue connecting (yes/no)? `[1:], serverPort, cryptossh.FingerprintSHA256(serverKey))) + _ = waitForServer(c, errorCh) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksNoMismatch(c *gc.C) { @@ -373,7 +438,10 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksNoMismatch(c *gc.C) { server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port - go server.run(c) + errorCh := make(chan error, 1) + done := make(chan bool) + defer close(done) + go server.run(errorCh, done) // Write a mismatching key to the known_hosts file. Even with // StrictHostChecksNo, we should be verifying against an existing @@ -410,6 +478,7 @@ Please contact your system administrator. Add correct host key in .*/known_hosts to get rid of this message. Offending ssh-ed25519 key in .*/known_hosts:1 `[1:], regexp.QuoteMeta(cryptossh.FingerprintSHA256(serverKey)))) + _ = waitForServer(c, errorCh) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksDifferentKeyTypes(c *gc.C) { @@ -418,7 +487,10 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksDifferentKeyTypes(c *gc.C) server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port - go server.run(c) + errorCh := make(chan error, 1) + done := make(chan bool) + defer close(done) + go server.run(errorCh, done) // Write a mismatching key to the known_hosts file with a different // key type. Even with StrictHostChecksNo, we should be verifying @@ -457,6 +529,7 @@ Add correct host key in .*/known_hosts to get rid of this message. Host was previously using different host key algorithms: - ssh-dss key in .*/known_hosts:1 `[1:], regexp.QuoteMeta(cryptossh.FingerprintSHA256(serverKey)))) + _ = waitForServer(c, errorCh) } type mockReadLineWriter struct {