Skip to content

Commit

Permalink
Fix flakey tests which use a test ssh server
Browse files Browse the repository at this point in the history
  • Loading branch information
wallyworld committed Feb 8, 2024
1 parent e867977 commit a9bc4d8
Showing 1 changed file with 95 additions and 22 deletions.
117 changes: 95 additions & 22 deletions ssh/ssh_gocrypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"path/filepath"
"regexp"
"sync"
"time"

"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
Expand All @@ -38,26 +39,42 @@ type sshServer struct {
client *cryptossh.Client
}

func (s *sshServer) run(c *gc.C) {
func (s *sshServer) run(errorCh chan error, done chan bool) {
netconn, err := s.listener.Accept()
c.Assert(err, jc.ErrorIsNil)
if err != nil {
errorCh <- fmt.Errorf("accepting connection: %w", err)
return
}
defer netconn.Close()

conn, chans, reqs, err := cryptossh.NewServerConn(netconn, s.cfg)
c.Assert(err, jc.ErrorIsNil)
if err != nil {
errorCh <- fmt.Errorf("getting ssh server connection: %w", err)
return
}
s.client = cryptossh.NewClient(conn, chans, reqs)

var wg sync.WaitGroup
defer wg.Wait()
defer func() {
wg.Wait()
close(errorCh)
}()

sessionChannels := s.client.HandleChannelOpen("session")
c.Assert(sessionChannels, gc.NotNil)
for newChannel := range sessionChannels {
c.Assert(newChannel.ChannelType(), gc.Equals, "session")
select {
case <-done:
return
case newChannel := <-sessionChannels:
if sCh := newChannel.ChannelType(); sCh != "session" {
errorCh <- fmt.Errorf("unexpected session channel %q", sCh)
return
}

channel, reqs, err := newChannel.Accept()

c.Assert(err, jc.ErrorIsNil)
if err != nil {
errorCh <- fmt.Errorf("accepting session connection: %w", err)
return
}
wg.Add(1)

go func() {
Expand All @@ -67,18 +84,30 @@ func (s *sshServer) run(c *gc.C) {
for req := range reqs {
switch req.Type {
case "exec":
c.Assert(req.WantReply, jc.IsTrue)
if !req.WantReply {
errorCh <- fmt.Errorf("no reply wanted for request %+v", req)
return
}
n := binary.BigEndian.Uint32(req.Payload[:4])
command := string(req.Payload[4 : n+4])
c.Assert(command, gc.Equals, testCommandFlat)
req.Reply(true, nil)
if command != testCommandFlat {
errorCh <- fmt.Errorf("unexpected request command: %q", command)
return
}
err = req.Reply(true, nil)
if err != nil {
errorCh <- fmt.Errorf("error sending reply: %w", err)
return
}
channel.Write([]byte("abc value\n"))
_, err := channel.SendRequest("exit-status", false, cryptossh.Marshal(&struct{ n uint32 }{0}))
c.Check(err, jc.ErrorIsNil)
if err != nil {
errorCh <- fmt.Errorf("error sending request: %w", err)
}
return

default:
c.Errorf("Unexpected request type: %v", req.Type)
errorCh <- fmt.Errorf("unexpected request type: %q", req.Type)
return
}
}
Expand Down Expand Up @@ -206,6 +235,16 @@ func (s *SSHGoCryptoCommandSuite) TestClientNoKeys(c *gc.C) {
c.Assert(err, gc.ErrorMatches, "ssh.Dial failed")
}

func waitForServer(c *gc.C, errorCh chan error) error {
select {
case err, _ := <-errorCh:
return err
case <-time.After(testing.LongWait):
c.Fatal("timed out waiting for ssh server")
return nil
}
}

func (s *SSHGoCryptoCommandSuite) TestCommand(c *gc.C) {
client, clientKey := newClient(c)
server, serverKey := s.newServer(c, cryptossh.ServerConfig{})
Expand All @@ -220,7 +259,11 @@ func (s *SSHGoCryptoCommandSuite) TestCommand(c *gc.C) {
checkedKey = true
return nil, nil
}
go server.run(c)
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)

out, err := cmd.Output()
c.Assert(err, jc.ErrorIsNil)
c.Assert(string(out), gc.Equals, "abc value\n")
Expand All @@ -233,6 +276,7 @@ func (s *SSHGoCryptoCommandSuite) TestCommand(c *gc.C) {
serverPort,
cryptossh.MarshalAuthorizedKey(serverKey)),
)
c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil)
}

func (s *SSHGoCryptoCommandSuite) TestCopy(c *gc.C) {
Expand Down Expand Up @@ -262,20 +306,28 @@ func (s *SSHGoCryptoCommandSuite) TestProxyCommand(c *gc.C) {
server.cfg.PublicKeyCallback = func(_ cryptossh.ConnMetadata, pubkey cryptossh.PublicKey) (*cryptossh.Permissions, error) {
return nil, nil
}
go server.run(c)
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)

out, err := cmd.Output()
c.Assert(err, jc.ErrorIsNil)
c.Assert(string(out), gc.Equals, "abc value\n")
// Ensure the proxy command was executed with the appropriate arguments.
data, err := ioutil.ReadFile(netcat + ".args")
c.Assert(err, jc.ErrorIsNil)
c.Assert(string(data), gc.Equals, fmt.Sprintf("%s -q0 127.0.0.1 %v\n", netcat, port))
c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil)
}

func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksYes(c *gc.C) {
server, _ := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
go server.run(c)
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)

var opts ssh.Options
opts.SetPort(serverPort)
Expand All @@ -289,12 +341,16 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksYes(c *gc.C) {
))
_, err = os.Stat(s.knownHostsFile)
c.Assert(err, jc.Satisfies, os.IsNotExist)
_ = waitForServer(c, errorCh)
}

func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskNonTerminal(c *gc.C) {
server, _ := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
go server.run(c)
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)

var opts ssh.Options
opts.SetPort(serverPort)
Expand All @@ -305,6 +361,7 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskNonTerminal(c *gc.C) {
c.Assert(err, gc.ErrorMatches, "ssh: handshake failed: not running in a terminal, cannot prompt for verification")
_, err = os.Stat(s.knownHostsFile)
c.Assert(err, jc.Satisfies, os.IsNotExist)
_ = waitForServer(c, errorCh)
}

func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalYes(c *gc.C) {
Expand All @@ -315,7 +372,10 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalYes(c *gc.C) {

server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
go server.run(c)
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)

var opts ssh.Options
opts.SetPort(serverPort)
Expand All @@ -338,6 +398,7 @@ The authenticity of host '127.0.0.1:%[1]d (127.0.0.1:%[1]d)' can't be establishe
ssh-ed25519 key fingerprint is %[2]s.
Are you sure you want to continue connecting (yes/no)? Please type 'yes' or 'no': `[1:],
serverPort, cryptossh.FingerprintSHA256(serverKey)))
c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil)
}

func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalNo(c *gc.C) {
Expand All @@ -347,7 +408,10 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalNo(c *gc.C) {

server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
go server.run(c)
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)

var opts ssh.Options
opts.SetPort(serverPort)
Expand All @@ -365,6 +429,7 @@ The authenticity of host '127.0.0.1:%[1]d (127.0.0.1:%[1]d)' can't be establishe
ssh-ed25519 key fingerprint is %[2]s.
Are you sure you want to continue connecting (yes/no)? `[1:],
serverPort, cryptossh.FingerprintSHA256(serverKey)))
_ = waitForServer(c, errorCh)
}

func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksNoMismatch(c *gc.C) {
Expand All @@ -373,7 +438,10 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksNoMismatch(c *gc.C) {

server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
go server.run(c)
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)

// Write a mismatching key to the known_hosts file. Even with
// StrictHostChecksNo, we should be verifying against an existing
Expand Down Expand Up @@ -410,6 +478,7 @@ Please contact your system administrator.
Add correct host key in .*/known_hosts to get rid of this message.
Offending ssh-ed25519 key in .*/known_hosts:1
`[1:], regexp.QuoteMeta(cryptossh.FingerprintSHA256(serverKey))))
_ = waitForServer(c, errorCh)
}

func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksDifferentKeyTypes(c *gc.C) {
Expand All @@ -418,7 +487,10 @@ func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksDifferentKeyTypes(c *gc.C)

server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
go server.run(c)
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)

// Write a mismatching key to the known_hosts file with a different
// key type. Even with StrictHostChecksNo, we should be verifying
Expand Down Expand Up @@ -457,6 +529,7 @@ Add correct host key in .*/known_hosts to get rid of this message.
Host was previously using different host key algorithms:
- ssh-dss key in .*/known_hosts:1
`[1:], regexp.QuoteMeta(cryptossh.FingerprintSHA256(serverKey))))
_ = waitForServer(c, errorCh)
}

type mockReadLineWriter struct {
Expand Down

0 comments on commit a9bc4d8

Please sign in to comment.