Skip to content

Commit

Permalink
Merge pull request #128 from bmc-toolbox/fix-race-condition-creating-…
Browse files Browse the repository at this point in the history
…sshclient

PSM-2322 Fix a race condition in creating SSHClient
  • Loading branch information
atrubachev authored Apr 21, 2020
2 parents 03a62f1 + a612a03 commit c377a45
Show file tree
Hide file tree
Showing 23 changed files with 600 additions and 765 deletions.
146 changes: 82 additions & 64 deletions internal/sshclient/sshclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,101 +3,119 @@ package sshclient
import (
"fmt"
"net"
"strings"
"sync"
"time"
"unicode"

"golang.org/x/crypto/ssh"
)

const (
// PowerOff defines the action of powering off a device
PowerOff = "poweroff"
// PowerOn defines the action of powering on a device
PowerOn = "poweron"
// PowerCycle defines the action of power cycle a device
PowerCycle = "powercycle"
// HardReset defines the action of hard reset a device
HardReset = "hardreset"
// Reseat defines the action of power reseat a device
Reseat = "reseat"
// IsOn defines the current power status of a device
IsOn = "ison"
// PowerCycleBmc the action of power cycle the bmc of a device
PowerCycleBmc = "powercyclebmc"
// PxeOnce the action of pxe once a device
PxeOnce = "pxeonce"
clientTimeout = 15 * time.Second
sshPort = "22"
)

// SSHClient implements out commom abstraction for ssh
// SSHClient implements out common abstraction for SSH
type SSHClient struct {
addr string
config *ssh.ClientConfig
client *ssh.Client
lock *sync.Mutex
}

// Sleep transforms a sleep statement in a sleep-able time
func Sleep(sleep string) (err error) {
sleep = strings.Replace(sleep, "sleep ", "", 1)
s, err := time.ParseDuration(sleep)
// New creates a new SSH client
func New(addr string, username string, password string) (*SSHClient, error) {
cfg := &ssh.ClientConfig{
User: username,
Auth: []ssh.AuthMethod{
ssh.Password(password),
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
return []string{}, nil
}),
},
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
Timeout: clientTimeout,
}

addr, err := checkAndBuildAddr(addr)
if err != nil {
return fmt.Errorf("error sleeping: %v", err)
return nil, err
}
time.Sleep(s)

return err
return &SSHClient{addr: addr, config: cfg, lock: new(sync.Mutex)}, nil
}

// Run execute the given command and returns a string with the output
func (s *SSHClient) Run(command string) (result string, err error) {
// Run executes the given command and returns the output as a string
func (s *SSHClient) Run(command string) (string, error) {
if err := s.createClient(); err != nil {
return "", err
}

return s.run(command)
}

func (s *SSHClient) run(command string) (string, error) {
session, err := s.client.NewSession()
if err != nil {
return result, err
return "", err
}
defer session.Close()

output, err := session.CombinedOutput(command)
if err != nil {
return string(output), err
}

return string(output), err
}

// IsntLetterOrNumber check if the give rune is not a letter nor a number
func IsntLetterOrNumber(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c)
// Close sends "exit" command and closes the SSH connection
func (s *SSHClient) Close() error {
s.lock.Lock()
defer s.lock.Unlock()

if s.client == nil {
return nil
}
defer func() {
s.client.Close()
s.client = nil
}()

// some vendors have issues with the bmc if you don't do it
if _, err := s.run("exit"); err != nil {
return err
}

return nil
}

// New returns a new configured ssh client
func New(host string, username string, password string) (connection *SSHClient, err error) {
if !strings.Contains(host, ":") {
host = fmt.Sprintf("%s:22", host)
func (s *SSHClient) createClient() error {
s.lock.Lock()
defer s.lock.Unlock()
if s.client != nil {
return nil // TODO: check client is alive
}
c, err := ssh.Dial(
"tcp",
host,
&ssh.ClientConfig{
User: username,
Auth: []ssh.AuthMethod{
ssh.Password(password),
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
return []string{}, nil
}),
},
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
Timeout: 15 * time.Second,
},
)

c, err := ssh.Dial("tcp", s.addr, s.config)
if err != nil {
return connection, fmt.Errorf("unable to connect to bmc: %v", err)
return fmt.Errorf("unable to connect to bmc: %w", err)
}
return &SSHClient{c}, err
s.client = c

return nil
}

// Close closed the ssh connection and ensure to always exit, some vendors will have issues with the bmc if you dont do it
func (s *SSHClient) Close() (err error) {
defer s.client.Close()
_, err = s.Run("exit")
return err
func checkAndBuildAddr(addr string) (string, error) {
if addr == "" {
return "", fmt.Errorf("address is empty")
}

if _, _, err := net.SplitHostPort(addr); err == nil {
return addr, nil
}

addrWithPort := net.JoinHostPort(addr, sshPort)
if _, _, err := net.SplitHostPort(addrWithPort); err == nil {
return addrWithPort, nil
}

return "", fmt.Errorf("failed to parse address %q", addr)
}
85 changes: 85 additions & 0 deletions internal/sshclient/sshclient_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package sshclient

import "testing"

func Test_checkAndBuildAddr(t *testing.T) {
type args struct {
addr string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "OK: only IPv4 address",
args: args{
"127.0.0.1",
},
want: "127.0.0.1:22",
wantErr: false,
},
{
name: "OK: only IPv6 address",
args: args{
"fe80::1",
},
want: "[fe80::1]:22",
wantErr: false,
},
{
name: "OK: only host",
args: args{
"localhost",
},
want: "localhost:22",
wantErr: false,
},
{
name: "OK: IPv4 address with port",
args: args{
"127.0.0.1:2222",
},
want: "127.0.0.1:2222",
wantErr: false,
},
{
name: "OK: IPv6 address with port",
args: args{
"[fe80::1]:2222",
},
want: "[fe80::1]:2222",
wantErr: false,
},
{
name: "OK: host with port",
args: args{
"localhost:2222",
},
want: "localhost:2222",
wantErr: false,
},
{
name: "Not OK: empty addr",
args: args{
"",
},
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := checkAndBuildAddr(tt.args.addr)
t.Log(got)
if (err != nil) != tt.wantErr {
t.Errorf("checkAndBuildAddr() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("checkAndBuildAddr() got = %v, want %v", got, tt.want)
}
})
}
}
8 changes: 8 additions & 0 deletions internal/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package internal

import "unicode"

// IsntLetterOrNumber check if the give rune is not a letter nor a number
func IsntLetterOrNumber(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c)
}
Loading

0 comments on commit c377a45

Please sign in to comment.