Skip to content

Commit

Permalink
feat: replcae vault key with secret
Browse files Browse the repository at this point in the history
  • Loading branch information
Mahanmmi committed Apr 10, 2024
1 parent 0f004fc commit 0d2e02e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 21 deletions.
4 changes: 2 additions & 2 deletions pkg/vault/azure_vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func NewAzureVaultClient(ctx context.Context, logger *zap.Logger, config AzureVa
return &sc, nil
}

func (sc *AzureVaultSourceConfig) Encrypt(ctx context.Context, cred map[string]any, _, _ string) (string, error) {
func (sc *AzureVaultSourceConfig) Encrypt(ctx context.Context, cred map[string]any) (string, error) {
bytes, err := json.Marshal(cred)
if err != nil {
sc.logger.Error("failed to marshal the credential", zap.Error(err))
Expand Down Expand Up @@ -90,7 +90,7 @@ func (sc *AzureVaultSourceConfig) Encrypt(ctx context.Context, cred map[string]a
return base64.StdEncoding.EncodeToString(cipherText), nil
}

func (sc *AzureVaultSourceConfig) Decrypt(ctx context.Context, cypherText string, _ string) (map[string]any, error) {
func (sc *AzureVaultSourceConfig) Decrypt(ctx context.Context, cypherText string) (map[string]any, error) {
aesCipher, err := aes.NewCipher(sc.AesKey)
if err != nil {
sc.logger.Error("failed to create cipher", zap.Error(err))
Expand Down
25 changes: 17 additions & 8 deletions pkg/vault/kms_vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ import (
"github.com/aws/aws-sdk-go-v2/service/sts"
)

type AwsVaultConfig struct {
Region string `yaml:"region" json:"region" koanf:"region"`
RoleArn string `yaml:"role_arn" json:"role_arn" koanf:"role_arn"`
AccessKey string `yaml:"access_key" json:"access_key" koanf:"access_key"`
SecretKey string `yaml:"secret_key" json:"secret_key" koanf:"secret_key"`
}

func getAWSConfig(ctx context.Context, awsAccessKey, awsSecretKey, awsSessionToken, assumeRoleArn string) (aws.Config, error) {
opts := make([]func(*config.LoadOptions) error, 0)

Expand All @@ -39,35 +46,37 @@ func getAWSConfig(ctx context.Context, awsAccessKey, awsSecretKey, awsSessionTok

type KMSVaultSourceConfig struct {
kmsClient *kms.Client
keyArn string
}

func NewKMSVaultSourceConfig(ctx context.Context, accessKey, secretKey, region string) (*KMSVaultSourceConfig, error) {
func NewKMSVaultSourceConfig(ctx context.Context, awsConfig AwsVaultConfig, keyArn string) (*KMSVaultSourceConfig, error) {
var err error
cfg, err := config.LoadDefaultConfig(ctx)
// if the keys are not provided, the default credentials from service account will be used
if accessKey != "" && secretKey != "" {
cfg, err = getAWSConfig(ctx, accessKey, secretKey, "", "")
if awsConfig.AccessKey != "" && awsConfig.SecretKey != "" {
cfg, err = getAWSConfig(ctx, awsConfig.AccessKey, awsConfig.SecretKey, "", "")
}
if err != nil {
return nil, fmt.Errorf("failed to load SDK configuration: %v", err)
}
cfg.Region = region
cfg.Region = awsConfig.Region
// Create KMS client with loaded configuration
svc := kms.NewFromConfig(cfg)

return &KMSVaultSourceConfig{
kmsClient: svc,
keyArn: keyArn,
}, nil
}

func (v *KMSVaultSourceConfig) Encrypt(ctx context.Context, cred map[string]any, keyARN string, _ string) (string, error) {
func (v *KMSVaultSourceConfig) Encrypt(ctx context.Context, cred map[string]any) (string, error) {
bytes, err := json.Marshal(cred)
if err != nil {
return "", err
}

result, err := v.kmsClient.Encrypt(ctx, &kms.EncryptInput{
KeyId: &keyARN,
KeyId: &v.keyArn,
Plaintext: bytes,
EncryptionAlgorithm: types.EncryptionAlgorithmSpecSymmetricDefault,
EncryptionContext: nil, //TODO-Saleh use workspaceID
Expand All @@ -80,7 +89,7 @@ func (v *KMSVaultSourceConfig) Encrypt(ctx context.Context, cred map[string]any,
return encoded, nil
}

func (v *KMSVaultSourceConfig) Decrypt(ctx context.Context, cypherText string, keyARN string) (map[string]any, error) {
func (v *KMSVaultSourceConfig) Decrypt(ctx context.Context, cypherText string) (map[string]any, error) {
bytes, err := base64.StdEncoding.DecodeString(cypherText)
if err != nil {
return nil, fmt.Errorf("failed to decode ciphertext: %v", err)
Expand All @@ -89,7 +98,7 @@ func (v *KMSVaultSourceConfig) Decrypt(ctx context.Context, cypherText string, k
result, err := v.kmsClient.Decrypt(ctx, &kms.DecryptInput{
CiphertextBlob: bytes,
EncryptionAlgorithm: types.EncryptionAlgorithmSpecSymmetricDefault,
KeyId: &keyARN,
KeyId: &v.keyArn,
EncryptionContext: nil, //TODO-Saleh use workspaceID
})
if err != nil {
Expand Down
17 changes: 6 additions & 11 deletions pkg/vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,13 @@ const (
)

type Config struct {
Provider Provider `yaml:"provider" json:"provider" koanf:"provider"`
Aws struct {
Region string `yaml:"region" json:"region" koanf:"region"`
RoleArn string `yaml:"role_arn" json:"role_arn" koanf:"role_arn"`
AccessKey string `yaml:"access_key" json:"access_key" koanf:"access_key"`
SecretKey string `yaml:"secret_key" json:"secret_key" koanf:"secret_key"`
} `yaml:"aws" json:"aws" koanf:"aws"`
Azure AzureVaultConfig `yaml:"azure" json:"azure" koanf:"azure"`
KeyId string `yaml:"key_id" json:"key_id" koanf:"key_id"`
Provider Provider `yaml:"provider" json:"provider" koanf:"provider"`
Aws AwsVaultConfig `yaml:"aws" json:"aws" koanf:"aws"`
Azure AzureVaultConfig `yaml:"azure" json:"azure" koanf:"azure"`
KeyId string `yaml:"key_id" json:"key_id" koanf:"key_id"`
}

type VaultSourceConfig interface {
Encrypt(ctx context.Context, data map[string]any, keyId string) (string, error)
Decrypt(ctx context.Context, cypherText string, keyId string) (map[string]any, error)
Encrypt(ctx context.Context, data map[string]any) (string, error)
Decrypt(ctx context.Context, cypherText string) (map[string]any, error)
}

0 comments on commit 0d2e02e

Please sign in to comment.