diff --git a/communicator/config.go b/communicator/config.go index 331104b3e..16b3eb498 100644 --- a/communicator/config.go +++ b/communicator/config.go @@ -496,6 +496,16 @@ func (c *Config) prepareSSH(ctx *interpolate.Context) []error { c.SSHKeepAliveInterval = 5 * time.Second } + // Validation + var errs []error + if c.SSHPrivateKeyFile == "" && c.SSHCertificateFile != "" { + errs = append(errs, fmt.Errorf("ssh_private_key_file must be specified if ssh_certificate_file is specified")) + } + + if c.SSHBastionPrivateKeyFile == "" && c.SSHBastionCertificateFile != "" { + errs = append(errs, fmt.Errorf("ssh_bastion_private_key_file must be specified if ssh_bastion_certificate_file is specified")) + } + if c.SSHBastionHost != "" { if c.SSHBastionPort == 0 { c.SSHBastionPort = 22 @@ -503,12 +513,8 @@ func (c *Config) prepareSSH(ctx *interpolate.Context) []error { if c.SSHBastionPrivateKeyFile == "" && c.SSHPrivateKeyFile != "" { c.SSHBastionPrivateKeyFile = c.SSHPrivateKeyFile - } - - if c.SSHBastionCertificateFile == "" && c.SSHCertificateFile != "" { c.SSHBastionCertificateFile = c.SSHCertificateFile } - } if c.SSHProxyHost != "" { @@ -526,8 +532,6 @@ func (c *Config) prepareSSH(ctx *interpolate.Context) []error { c.SSHTimeout = c.SSHWaitTimeout } - // Validation - var errs []error if c.SSHUsername == "" { errs = append(errs, errors.New("An ssh_username must be specified\n Note: some builders used to default ssh_username to \"root\".")) } diff --git a/communicator/config_test.go b/communicator/config_test.go index a6371592e..654515e14 100644 --- a/communicator/config_test.go +++ b/communicator/config_test.go @@ -4,12 +4,21 @@ package communicator import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "os" "reflect" "testing" + "time" + "github.com/google/go-cmp/cmp" "github.com/hashicorp/packer-plugin-sdk/multistep" "github.com/hashicorp/packer-plugin-sdk/template/interpolate" "github.com/masterzen/winrm" + "golang.org/x/crypto/ssh" ) func testConfig() *Config { @@ -142,28 +151,355 @@ func TestConfig_winrm_use_ntlm(t *testing.T) { } -func TestSSHBastion(t *testing.T) { - c := &Config{ - Type: "ssh", - SSH: SSH{ - SSHUsername: "root", - SSHBastionHost: "mybastionhost.company.com", - SSHBastionPassword: "test", - }, +// generateSSHPrivateKey generates a new RSA SSH private key for use in tests +// +// It returns the path in which the key was created. +// Removing the key after testing is the caller's responsibility. +func generateSSHPrivateKey() (path string, signer ssh.Signer, err error) { + pk, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + err = fmt.Errorf("failed to generate key: %s", err) + return } - if err := c.Prepare(testContext(t)); len(err) > 0 { - t.Fatalf("bad: %#v", err) + sshKeyFile, err := os.CreateTemp("", "") + if err != nil { + err = fmt.Errorf("failed to open a temp file: %s", err) + return } - if c.SSHBastionCertificateFile != "" { - t.Fatalf("Identity certificate somehow set") + defer sshKeyFile.Close() + + path = sshKeyFile.Name() + + rawPkey := x509.MarshalPKCS1PrivateKey(pk) + + err = pem.Encode(sshKeyFile, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: rawPkey, + }) + if err != nil { + err = fmt.Errorf("failed to encode to PEM: %s", err) + return + } + + signer, err = ssh.NewSignerFromKey(pk) + if err != nil { + err = fmt.Errorf("failed to create SSH signer: %s", err) + return + } + + return +} + +// generateSSHKeys generates a new SSH key, the CA key and a cert linked to the SSH key for use in tests +// +// It returns the paths in which the keys and cert were created. +// Removing the keys and certs after testing is the caller's responsibility. +func generateSSHKeys() ( + privKeyPath string, + certKeyPath string, + certPath string, + err error, +) { + var sshPrivKey, certSSHKey ssh.Signer + + privKeyPath, sshPrivKey, err = generateSSHPrivateKey() + if err != nil { + err = fmt.Errorf("failed to generate private key: %s", err) + return + } + + certKeyPath, certSSHKey, err = generateSSHPrivateKey() + if err != nil { + err = fmt.Errorf("failed to generate CA private key: %s", err) + return + } + + cert := &ssh.Certificate{ + CertType: ssh.HostCert, + Key: sshPrivKey.PublicKey(), + ValidAfter: 0, + ValidBefore: ssh.CertTimeInfinity, + KeyId: "TestSSHCert", + ValidPrincipals: []string{"authority.example.com"}, } - if c.SSHPrivateKeyFile != "" { - t.Fatalf("Private key file somehow set") + certFile, err := os.CreateTemp("", "") + if err != nil { + err = fmt.Errorf("failed to create cert file: %s", err) + return } + defer certFile.Close() + certPath = certFile.Name() + + err = cert.SignCert(rand.Reader, certSSHKey) + if err != nil { + err = fmt.Errorf("failed to sign cert: %s", err) + return + } + + rawCert := ssh.MarshalAuthorizedKey(cert) + + _, err = certFile.Write(rawCert) + if err != nil { + err = fmt.Errorf("failed to write marshalled certificate: %s", err) + } + + return +} + +func TestSSHBastion(t *testing.T) { + privKeyPath, certKeyPath, certPath, err := generateSSHKeys() + if err != nil { + t.Fatalf("failed to generate SSH keys and certificates: %s", err) + } + + defer func() { + os.Remove(privKeyPath) + os.Remove(certKeyPath) + os.Remove(certPath) + }() + + t.Logf("generated private key (%q), CA key (%q), certificate (%q)", privKeyPath, certKeyPath, certPath) + + bastionPrivKeyPath, bastionCertKeyPath, bastionCertPath, err := generateSSHKeys() + if err != nil { + t.Fatalf("failed to generate bastion SSH keys and certificates: %s", err) + } + + defer func() { + os.Remove(bastionPrivKeyPath) + os.Remove(bastionCertKeyPath) + os.Remove(bastionCertPath) + }() + + t.Logf("generated bastion private key (%q), CA key (%q), certificate (%q)", bastionPrivKeyPath, bastionCertKeyPath, bastionCertPath) + + testcases := []struct { + name string + config *Config + expectedConfig *Config + expectError bool + }{ + { + "OK - with host and password", + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "mybastionhost.company.com", + SSHBastionPassword: "test", + }, + }, + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "mybastionhost.company.com", + SSHBastionPassword: "test", + SSHPort: 22, + SSHTimeout: time.Minute * 5, + SSHFileTransferMethod: "scp", + SSHKeepAliveInterval: time.Second * 5, + SSHHandshakeAttempts: 10, + SSHBastionPort: 22, + }, + }, + false, + }, + { + "OK - bastion config with bastion SSH private key", + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "my.bastion", + SSHBastionPrivateKeyFile: bastionPrivKeyPath, + }, + }, + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "my.bastion", + SSHBastionPrivateKeyFile: bastionPrivKeyPath, + SSHPort: 22, + SSHTimeout: time.Minute * 5, + SSHFileTransferMethod: "scp", + SSHKeepAliveInterval: time.Second * 5, + SSHHandshakeAttempts: 10, + SSHBastionPort: 22, + }, + }, + false, + }, + { + "OK - bastion config with SSH private key, bastion key should be the same as SSH key", + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "my.bastion", + SSHPrivateKeyFile: privKeyPath, + }, + }, + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "my.bastion", + SSHBastionPrivateKeyFile: privKeyPath, + SSHPort: 22, + SSHTimeout: time.Minute * 5, + SSHFileTransferMethod: "scp", + SSHKeepAliveInterval: time.Second * 5, + SSHHandshakeAttempts: 10, + SSHBastionPort: 22, + SSHPrivateKeyFile: privKeyPath, + }, + }, + false, + }, + { + "OK - bastion config with SSH private key and cert, bastion should have both set", + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "my.bastion", + SSHPrivateKeyFile: privKeyPath, + SSHCertificateFile: certPath, + }, + }, + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "my.bastion", + SSHBastionPrivateKeyFile: privKeyPath, + SSHBastionCertificateFile: certPath, + SSHPort: 22, + SSHTimeout: time.Minute * 5, + SSHFileTransferMethod: "scp", + SSHKeepAliveInterval: time.Second * 5, + SSHHandshakeAttempts: 10, + SSHBastionPort: 22, + SSHPrivateKeyFile: privKeyPath, + SSHCertificateFile: certPath, + }, + }, + false, + }, + { + "OK - bastion config with SSH private key and cert, and a bastion private key, bastion cert should not be set", + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "my.bastion", + SSHBastionPrivateKeyFile: bastionPrivKeyPath, + SSHPrivateKeyFile: privKeyPath, + SSHCertificateFile: certPath, + }, + }, + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "my.bastion", + SSHBastionPrivateKeyFile: bastionPrivKeyPath, + SSHPort: 22, + SSHTimeout: time.Minute * 5, + SSHFileTransferMethod: "scp", + SSHKeepAliveInterval: time.Second * 5, + SSHHandshakeAttempts: 10, + SSHBastionPort: 22, + SSHPrivateKeyFile: privKeyPath, + SSHCertificateFile: certPath, + }, + }, + false, + }, + { + "OK - bastion config with SSH private key and cert, and a bastion private key and cert", + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "my.bastion", + SSHBastionPrivateKeyFile: bastionPrivKeyPath, + SSHBastionCertificateFile: bastionCertPath, + SSHPrivateKeyFile: privKeyPath, + SSHCertificateFile: certPath, + }, + }, + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionHost: "my.bastion", + SSHBastionPrivateKeyFile: bastionPrivKeyPath, + SSHBastionCertificateFile: bastionCertPath, + SSHPort: 22, + SSHTimeout: time.Minute * 5, + SSHFileTransferMethod: "scp", + SSHKeepAliveInterval: time.Second * 5, + SSHHandshakeAttempts: 10, + SSHBastionPort: 22, + SSHPrivateKeyFile: privKeyPath, + SSHCertificateFile: certPath, + }, + }, + false, + }, + { + "Fail - ssh certificate file specified without an ssh private key file", + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHCertificateFile: certPath, + }, + }, + nil, + true, + }, + { + "Fail - ssh bastion certificate file specified without an ssh bastion private key file", + &Config{ + Type: "ssh", + SSH: SSH{ + SSHUsername: "root", + SSHBastionCertificateFile: certPath, + }, + }, + nil, + true, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + errs := tt.config.Prepare(testContext(t)) + + for _, err := range errs { + t.Logf("%s", err) + } + if (len(errs) != 0) != tt.expectError { + t.Fatalf("Expected %t error, got %d", tt.expectError, len(errs)) + } + if tt.expectError { + return + } + + diff := cmp.Diff(tt.config, tt.expectedConfig) + if diff != "" { + t.Errorf(diff) + } + }) + } } func TestSSHConfigFunc_ciphers(t *testing.T) {