Skip to content

Commit

Permalink
fix: rename struct and deprecated the old name (#13)
Browse files Browse the repository at this point in the history
* fix: rename struct and deprecated the old name

`SSHKeyPair` is way too long and unnecessary

* Update keygen.go

* refactor: remove the rest of the SSHKeyPair mentions

---------

Co-authored-by: bashbunni <[email protected]>
  • Loading branch information
aymanbagabas and bashbunni authored Aug 21, 2023
1 parent 85d0702 commit 54993c5
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 44 deletions.
76 changes: 40 additions & 36 deletions keygen.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ type SSHKeysAlreadyExistErr struct {
}

// SSHKeyPair holds a pair of SSH keys and associated methods.
type SSHKeyPair struct {
// Deprecated: Use KeyPair instead.
type SSHKeyPair = KeyPair

// KeyPair holds a pair of SSH keys and associated methods.
type KeyPair struct {
path string // private key filename path; public key will have .pub appended
writeKeys bool
passphrase []byte
Expand All @@ -95,43 +99,43 @@ type SSHKeyPair struct {
privateKey crypto.PrivateKey
}

func (s SSHKeyPair) privateKeyPath() string {
func (s KeyPair) privateKeyPath() string {
return s.path
}

func (s SSHKeyPair) publicKeyPath() string {
func (s KeyPair) publicKeyPath() string {
return s.privateKeyPath() + ".pub"
}

// Option is a functional option for SSHKeyPair.
type Option func(*SSHKeyPair)
// Option is a functional option for KeyPair.
type Option func(*KeyPair)

// WithPassphrase sets the passphrase for the private key.
func WithPassphrase(passphrase string) Option {
return func(s *SSHKeyPair) {
return func(s *KeyPair) {
s.passphrase = []byte(passphrase)
}
}

// WithKeyType sets the key type for the key pair.
// Available key types are RSA, Ed25519, and ECDSA.
func WithKeyType(keyType KeyType) Option {
return func(s *SSHKeyPair) {
return func(s *KeyPair) {
s.keyType = keyType
}
}

// WithBitSize sets the key size for the RSA key pair.
// This option is ignored for other key types.
func WithBitSize(bits int) Option {
return func(s *SSHKeyPair) {
return func(s *KeyPair) {
s.rsaBitSize = bits
}
}

// WithWrite writes the key pair to disk if it doesn't exist.
func WithWrite() Option {
return func(s *SSHKeyPair) {
return func(s *KeyPair) {
s.writeKeys = true
}
}
Expand All @@ -141,19 +145,19 @@ func WithWrite() Option {
// The default curve is P-384.
// This option is ignored for other key types.
func WithEllipticCurve(curve elliptic.Curve) Option {
return func(s *SSHKeyPair) {
return func(s *KeyPair) {
s.ec = curve
}
}

// New generates an SSHKeyPair, which contains a pair of SSH keys.
// New generates a KeyPair, which contains a pair of SSH keys.
//
// If the key pair already exists, it will be loaded from disk, otherwise, a
// new SSH key pair is generated.
// If no key type is specified, Ed25519 will be used.
func New(path string, opts ...Option) (*SSHKeyPair, error) {
func New(path string, opts ...Option) (*KeyPair, error) {
var err error
s := &SSHKeyPair{
s := &KeyPair{
path: expandPath(path),
rsaBitSize: rsaDefaultBits,
ec: elliptic.P384(),
Expand Down Expand Up @@ -228,7 +232,7 @@ func New(path string, opts ...Option) (*SSHKeyPair, error) {
}

// PrivateKey returns the unencrypted crypto.PrivateKey.
func (s *SSHKeyPair) PrivateKey() crypto.PrivateKey {
func (s *KeyPair) PrivateKey() crypto.PrivateKey {
switch s.keyType {
case RSA, Ed25519, ECDSA:
return s.privateKey
Expand All @@ -237,7 +241,7 @@ func (s *SSHKeyPair) PrivateKey() crypto.PrivateKey {
}
}

// Ensure that SSHKeyPair implements crypto.Signer.
// Ensure that KeyPair implements crypto.Signer.
// This is used to ensure that the private key is a valid crypto.Signer to be
// passed to ssh.NewSignerFromKey.
var (
Expand All @@ -247,18 +251,18 @@ var (
)

// Signer returns an ssh.Signer for the key pair.
func (s *SSHKeyPair) Signer() ssh.Signer {
func (s *KeyPair) Signer() ssh.Signer {
sk, _ := ssh.NewSignerFromKey(s.PrivateKey())
return sk
}

// PublicKey returns the ssh.PublicKey for the key pair.
func (s *SSHKeyPair) PublicKey() ssh.PublicKey {
func (s *KeyPair) PublicKey() ssh.PublicKey {
p, _ := ssh.NewPublicKey(s.cryptoPublicKey())
return p
}

func (s *SSHKeyPair) cryptoPublicKey() crypto.PublicKey {
func (s *KeyPair) cryptoPublicKey() crypto.PublicKey {
switch s.keyType {
case RSA:
key, ok := s.privateKey.(*rsa.PrivateKey)
Expand All @@ -284,13 +288,13 @@ func (s *SSHKeyPair) cryptoPublicKey() crypto.PublicKey {
}

// CryptoPublicKey returns the crypto.PublicKey of the SSH key pair.
func (s *SSHKeyPair) CryptoPublicKey() crypto.PublicKey {
func (s *KeyPair) CryptoPublicKey() crypto.PublicKey {
return s.cryptoPublicKey()
}

// RawAuthorizedKey returns the underlying SSH public key (RFC 4253) in OpenSSH
// authorized_keys format.
func (s *SSHKeyPair) RawAuthorizedKey() []byte {
func (s *KeyPair) RawAuthorizedKey() []byte {
bts, err := os.ReadFile(s.publicKeyPath())
if err != nil {
return []byte(s.AuthorizedKey())
Expand All @@ -313,7 +317,7 @@ func (s *SSHKeyPair) RawAuthorizedKey() []byte {
return []byte(ak)
}

func (s *SSHKeyPair) authorizedKey(pk ssh.PublicKey) string {
func (s *KeyPair) authorizedKey(pk ssh.PublicKey) string {
if pk == nil {
return ""
}
Expand All @@ -324,22 +328,22 @@ func (s *SSHKeyPair) authorizedKey(pk ssh.PublicKey) string {

// AuthorizedKey returns the SSH public key (RFC 4253) in OpenSSH authorized_keys
// format. The returned string is trimmed of sshd options and comments.
func (s *SSHKeyPair) AuthorizedKey() string {
func (s *KeyPair) AuthorizedKey() string {
return s.authorizedKey(s.PublicKey())
}

// RawPrivateKey returns the raw unencrypted private key bytes in PEM format.
func (s *SSHKeyPair) RawPrivateKey() []byte {
func (s *KeyPair) RawPrivateKey() []byte {
return s.rawPrivateKey(nil)
}

// RawProtectedPrivateKey returns the raw password protected private key bytes
// in PEM format.
func (s *SSHKeyPair) RawProtectedPrivateKey() []byte {
func (s *KeyPair) RawProtectedPrivateKey() []byte {
return s.rawPrivateKey(s.passphrase)
}

func (s *SSHKeyPair) rawPrivateKey(pass []byte) []byte {
func (s *KeyPair) rawPrivateKey(pass []byte) []byte {
block, err := s.pemBlock(pass)
if err != nil {
return nil
Expand All @@ -348,7 +352,7 @@ func (s *SSHKeyPair) rawPrivateKey(pass []byte) []byte {
return pem.EncodeToMemory(block)
}

func (s *SSHKeyPair) pemBlock(passphrase []byte) (*pem.Block, error) {
func (s *KeyPair) pemBlock(passphrase []byte) (*pem.Block, error) {
key := s.PrivateKey()

Check failure on line 356 in keygen.go

View workflow job for this annotation

GitHub Actions / lint

variable 'key' is only used in the if-statement (keygen.go:357:2); consider using short syntax (ifshort)
if key == nil {
return nil, ErrMissingSSHKeys
Expand All @@ -365,7 +369,7 @@ func (s *SSHKeyPair) pemBlock(passphrase []byte) (*pem.Block, error) {
}

// generateEd25519Keys creates a pair of EdD25519 keys for SSH auth.
func (s *SSHKeyPair) generateEd25519Keys() error {
func (s *KeyPair) generateEd25519Keys() error {
// Generate keys
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
Expand All @@ -377,7 +381,7 @@ func (s *SSHKeyPair) generateEd25519Keys() error {
}

// generateEd25519Keys creates a pair of EdD25519 keys for SSH auth.
func (s *SSHKeyPair) generateECDSAKeys(curve elliptic.Curve) error {
func (s *KeyPair) generateECDSAKeys(curve elliptic.Curve) error {
// Generate keys
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
Expand All @@ -388,7 +392,7 @@ func (s *SSHKeyPair) generateECDSAKeys(curve elliptic.Curve) error {
}

// generateRSAKeys creates a pair for RSA keys for SSH auth.
func (s *SSHKeyPair) generateRSAKeys(bitSize int) error {
func (s *KeyPair) generateRSAKeys(bitSize int) error {
// Generate private key
privateKey, err := rsa.GenerateKey(rand.Reader, bitSize)
if err != nil {
Expand All @@ -408,7 +412,7 @@ func (s *SSHKeyPair) generateRSAKeys(bitSize int) error {
// the SSH directory we're going to write our keys to (for example, ~/.ssh) as
// well as make sure that no files exist at the location in which we're going
// to write out keys.
func (s *SSHKeyPair) prepFilesystem() error {
func (s *KeyPair) prepFilesystem() error {
var err error

keyDir := filepath.Dir(s.path)
Expand All @@ -421,7 +425,7 @@ func (s *SSHKeyPair) prepFilesystem() error {
info, err := os.Stat(keyDir)
if os.IsNotExist(err) {
// Directory doesn't exist: create it
return os.MkdirAll(keyDir, 0700)
return os.MkdirAll(keyDir, 0o700)
}
if err != nil {
// There was another error statting the directory; something is awry
Expand All @@ -431,9 +435,9 @@ func (s *SSHKeyPair) prepFilesystem() error {
// It exists but it's not a directory
return FilesystemErr{Err: fmt.Errorf("%s is not a directory", keyDir)}
}
if info.Mode().Perm() != 0700 {
if info.Mode().Perm() != 0o700 {
// Permissions are wrong: fix 'em
if err := os.Chmod(keyDir, 0700); err != nil {
if err := os.Chmod(keyDir, 0o700); err != nil {
return FilesystemErr{Err: err}
}
}
Expand All @@ -452,7 +456,7 @@ func (s *SSHKeyPair) prepFilesystem() error {
}

// WriteKeys writes the SSH key pair to disk.
func (s *SSHKeyPair) WriteKeys() error {
func (s *KeyPair) WriteKeys() error {
var err error
priv := s.RawProtectedPrivateKey()
if priv == nil {
Expand All @@ -479,13 +483,13 @@ func (s *SSHKeyPair) WriteKeys() error {
}

// KeyPairExists checks if the SSH key pair exists on disk.
func (s *SSHKeyPair) KeyPairExists() bool {
func (s *KeyPair) KeyPairExists() bool {
return fileExists(s.privateKeyPath())
}

func writeKeyToFile(keyBytes []byte, path string) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
return ioutil.WriteFile(path, keyBytes, 0600)
return ioutil.WriteFile(path, keyBytes, 0o600)
}
return FilesystemErr{Err: fmt.Errorf("file %s already exists", path)}
}
Expand Down
16 changes: 8 additions & 8 deletions keygen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"testing"
)

func TestNewSSHKeyPair(t *testing.T) {
func TestNewKeyPair(t *testing.T) {
kp, err := New("")
if err != nil {
t.Errorf("error creating SSH key pair: %v", err)
Expand All @@ -20,7 +20,7 @@ func TestNewSSHKeyPair(t *testing.T) {
}
}

func nilTest(t testing.TB, kp *SSHKeyPair) {
func nilTest(t testing.TB, kp *KeyPair) {
t.Helper()
if kp == nil {
t.Error("expected key pair to be non-nil")
Expand All @@ -45,7 +45,7 @@ func nilTest(t testing.TB, kp *SSHKeyPair) {
}
}

func TestNilSSHKeyPair(t *testing.T) {
func TestNilKeyPair(t *testing.T) {
for _, kt := range []KeyType{RSA, Ed25519, ECDSA} {
t.Run(fmt.Sprintf("test nil key pair for %s", kt), func(t *testing.T) {
kp, err := New("", WithKeyType(kt))
Expand All @@ -57,7 +57,7 @@ func TestNilSSHKeyPair(t *testing.T) {
}
}

func TestNilSSHKeyPairWithPassphrase(t *testing.T) {
func TestNilKeyPairWithPassphrase(t *testing.T) {
for _, kt := range []KeyType{RSA, Ed25519, ECDSA} {
t.Run(fmt.Sprintf("test nil key pair for %s", kt), func(t *testing.T) {
kp, err := New("", WithKeyType(kt), WithPassphrase("test"))
Expand All @@ -69,7 +69,7 @@ func TestNilSSHKeyPairWithPassphrase(t *testing.T) {
}
}

func TestNilSSHKeyPairTestdata(t *testing.T) {
func TestNilKeyPairTestdata(t *testing.T) {
for _, kt := range []KeyType{RSA, Ed25519, ECDSA} {
t.Run(fmt.Sprintf("test nil key pair for %s", kt), func(t *testing.T) {
kp, err := New(filepath.Join("testdata", "test_"+kt.String()), WithPassphrase("test"), WithKeyType(kt))
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestGenerateEd25519Keys(t *testing.T) {
dir := t.TempDir()
filename := "test"

k := &SSHKeyPair{
k := &KeyPair{
path: filepath.Join(dir, filename),
keyType: Ed25519,
}
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestGenerateECDSAKeys(t *testing.T) {
dir := t.TempDir()
filename := "test"

k := &SSHKeyPair{
k := &KeyPair{
path: filepath.Join(dir, filename),
keyType: ECDSA,
ec: elliptic.P384(),
Expand Down Expand Up @@ -228,7 +228,7 @@ func TestGenerateECDSAKeys(t *testing.T) {
// touchTestFile is a utility function we're using in testing.
func createEmptyFile(t *testing.T, path string) (ok bool) {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0700); err != nil {
if err := os.MkdirAll(dir, 0o700); err != nil {
t.Errorf("could not create directory %s: %v", dir, err)
return false
}
Expand Down

0 comments on commit 54993c5

Please sign in to comment.