Skip to content

Commit

Permalink
Merge pull request #7 from Versent/feature/add_versioning_support
Browse files Browse the repository at this point in the history
Added missing support for credential version and filtering.
  • Loading branch information
wolfeidau committed Mar 12, 2016
2 parents 8254c4f + fb58c66 commit 08ffab1
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 31 deletions.
3 changes: 0 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ VERSION=1.0.5
GO15VENDOREXPERIMENT := 1
ITERATION := 1

vendor:
godep save -d -t

build:
rm -rf build && mkdir build
mkdir -p build/Linux && GOOS=linux go build -ldflags "-X main.Version=$(VERSION)" -o build/Linux/$(NAME) ./cmd/unicreds
Expand Down
36 changes: 21 additions & 15 deletions cmd/unicreds/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"io/ioutil"
"os"

"github.com/apex/log"
"github.com/apex/log/handlers/cli"

"github.com/alecthomas/kingpin"
"github.com/aws/aws-sdk-go/aws"
"github.com/versent/unicreds"
Expand All @@ -25,9 +28,10 @@ var (
cmdGet = app.Command("get", "Get a credential from the store.")
cmdGetName = cmdGet.Arg("credential", "The name of the credential to get.").Required().String()

cmdGetAll = app.Command("getall", "Get all credentials from the store.")
cmdGetAll = app.Command("getall", "Get latest credentials from the store.")

cmdList = app.Command("list", "List all credentials names and version.")
cmdList = app.Command("list", "List latest credentials with names and version.")
cmdListAll = cmdList.Flag("all", "List all versions").Bool()

cmdPut = app.Command("put", "Put a credential into the store.")
cmdPutName = cmdPut.Arg("credential", "The name of the credential to store.").Required().String()
Expand All @@ -48,6 +52,7 @@ var (

func main() {
app.Version(Version)
log.SetHandler(cli.Default)

command := kingpin.MustParse(app.Parse(os.Args[1:]))

Expand All @@ -68,21 +73,22 @@ func main() {
if err != nil {
printFatalError(err)
}
fmt.Printf("%+v\n", cred.Secret)
fmt.Println(cred.Secret)
case cmdPut.FullCommand():
var version string
if *cmdPutVersion != 0 {
version = fmt.Sprintf("%d", *cmdPutVersion)
version, err := unicreds.ResolveVersion(*cmdPutName, *cmdPutVersion)
if err != nil {
printFatalError(err)
}
err := unicreds.PutSecret(*alias, *cmdPutName, *cmdPutSecret, version)

err = unicreds.PutSecret(*alias, *cmdPutName, *cmdPutSecret, version)
if err != nil {
printFatalError(err)
}
fmt.Printf("%s has been stored\n", *cmdPutName)
log.WithFields(log.Fields{"name": *cmdPutName, "version": version}).Info("stored")
case cmdPutFile.FullCommand():
var version string
if *cmdPutFileVersion != 0 {
version = fmt.Sprintf("%d", *cmdPutFileVersion)
version, err := unicreds.ResolveVersion(*cmdPutFileName, *cmdPutFileVersion)
if err != nil {
printFatalError(err)
}

data, err := ioutil.ReadFile(*cmdPutFileSecretPath)
Expand All @@ -94,9 +100,9 @@ func main() {
if err != nil {
printFatalError(err)
}
fmt.Printf("%s has been stored\n", *cmdPutFileName)
log.WithFields(log.Fields{"name": *cmdPutName, "version": version}).Info("stored")
case cmdList.FullCommand():
creds, err := unicreds.ListSecrets()
creds, err := unicreds.ListSecrets(*cmdListAll)
if err != nil {
printFatalError(err)
}
Expand All @@ -113,7 +119,7 @@ func main() {
}
table.Render()
case cmdGetAll.FullCommand():
creds, err := unicreds.ListSecrets()
creds, err := unicreds.GetAllSecrets(true)
if err != nil {
printFatalError(err)
}
Expand All @@ -138,6 +144,6 @@ func main() {
}

func printFatalError(err error) {
fmt.Fprintf(os.Stderr, "error occured: %v\n", err)
log.WithError(err).Error("failed")
os.Exit(1)
}
171 changes: 158 additions & 13 deletions ds.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package unicreds
import (
"encoding/base64"
"errors"
"fmt"
"sort"
"strconv"
"time"

"github.com/apex/log"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
Expand Down Expand Up @@ -73,10 +76,23 @@ type DecryptedCredential struct {
Secret string
}

// ByVersion sort helper for credentials
type ByVersion []*Credential

func (a ByVersion) Len() int { return len(a) }
func (a ByVersion) Swap(i, j int) { a[i], a[j] = a[j], a[i] }

func (a ByVersion) Less(i, j int) bool {
aiv, _ := strconv.Atoi(a[i].Version)
ajv, _ := strconv.Atoi(a[j].Version)

return aiv < ajv
}

// Setup create the table which stores credentials
func Setup() (err error) {

res, err := dynamoSvc.CreateTable(&dynamodb.CreateTableInput{
_, err = dynamoSvc.CreateTable(&dynamodb.CreateTableInput{
AttributeDefinitions: []*dynamodb.AttributeDefinition{
{
AttributeName: aws.String("name"),
Expand Down Expand Up @@ -108,7 +124,7 @@ func Setup() (err error) {
return
}

fmt.Printf("res = %+v\n", res)
log.Info("created")

err = waitForTable()

Expand Down Expand Up @@ -153,8 +169,72 @@ func GetSecret(name string) (*DecryptedCredential, error) {
return decryptCredential(cred)
}

// ListSecrets return a list of secrets
func ListSecrets() ([]*DecryptedCredential, error) {
// GetHighestVersion look up the highest version for a given name
func GetHighestVersion(name string) (string, error) {

res, err := dynamoSvc.Query(&dynamodb.QueryInput{
TableName: aws.String(Table),
ExpressionAttributeNames: map[string]*string{
"#N": aws.String("name"),
},
ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
":name": &dynamodb.AttributeValue{
S: aws.String(name),
},
},
KeyConditionExpression: aws.String("#N = :name"),
Limit: aws.Int64(1),
ConsistentRead: aws.Bool(true),
ScanIndexForward: aws.Bool(false), // descending order
ProjectionExpression: aws.String("version"),
})

if err != nil {
return "", err
}

if len(res.Items) == 0 {
return "", ErrSecretNotFound
}

v := res.Items[0]["version"]

if v == nil {
return "", ErrSecretNotFound
}

return aws.StringValue(v.S), nil
}

// ListSecrets returns a list of all secrets
func ListSecrets(all bool) ([]*Credential, error) {

res, err := dynamoSvc.Scan(&dynamodb.ScanInput{
TableName: aws.String(Table),
ExpressionAttributeNames: map[string]*string{
"#N": aws.String("name"),
},
ProjectionExpression: aws.String("#N, version, created_at"),
ConsistentRead: aws.Bool(true),
})
if err != nil {
return nil, err
}

if all {
return decodeCredential(res.Items)
}

creds, err := decodeCredential(res.Items)
if err != nil {
return nil, err
}

return filterLatest(creds)
}

// GetAllSecrets returns a list of all secrets
func GetAllSecrets(all bool) ([]*DecryptedCredential, error) {

res, err := dynamoSvc.Scan(&dynamodb.ScanInput{
TableName: aws.String(Table),
Expand All @@ -168,20 +248,18 @@ func ListSecrets() ([]*DecryptedCredential, error) {
},
ConsistentRead: aws.Bool(true),
})
if err != nil {
return nil, err
}

creds, err := decodeCredential(res.Items)
if err != nil {
return nil, err
}

var results []*DecryptedCredential

for _, item := range res.Items {
cred := new(Credential)

err = Decode(item, cred)
if err != nil {
return nil, err
}
for _, cred := range creds {

dcred, err := decryptCredential(cred)
if err != nil {
Expand Down Expand Up @@ -240,6 +318,10 @@ func PutSecret(alias, name, secret, version string) error {
_, err = dynamoSvc.PutItem(&dynamodb.PutItemInput{
TableName: aws.String(Table),
Item: data,
ExpressionAttributeNames: map[string]*string{
"#N": aws.String("name"),
},
ConditionExpression: aws.String("attribute_not_exists(#N)"),
})

return err
Expand Down Expand Up @@ -275,7 +357,7 @@ func DeleteSecret(name string) error {
return err
}

fmt.Printf("deleting name=%s version=%s\n", cred.Name, cred.Version)
log.WithFields(log.Fields{"name": cred.Name, "version": cred.Version}).Info("deleting")

_, err = dynamoSvc.DeleteItem(&dynamodb.DeleteItemInput{
TableName: aws.String(Table),
Expand All @@ -297,6 +379,30 @@ func DeleteSecret(name string) error {
return nil
}

// ResolveVersion calculate the version given a name and version
func ResolveVersion(name string, version int) (string, error) {

if version != 0 {
return strconv.Itoa(version), nil
}

ver, err := GetHighestVersion(name)
if err != nil {
if err == ErrSecretNotFound {
return "1", nil
}
return "", err
}

if version, err = strconv.Atoi(ver); err != nil {
return "", err
}

version++

return strconv.Itoa(version), nil
}

func decryptCredential(cred *Credential) (*DecryptedCredential, error) {

wrappedKey, err := base64.StdEncoding.DecodeString(cred.Key)
Expand Down Expand Up @@ -336,6 +442,45 @@ func decryptCredential(cred *Credential) (*DecryptedCredential, error) {
return &DecryptedCredential{Credential: cred, Secret: plainText}, nil
}

func decodeCredential(items []map[string]*dynamodb.AttributeValue) ([]*Credential, error) {

results := make([]*Credential, 0, len(items))

for _, item := range items {
cred := new(Credential)

err := Decode(item, cred)
if err != nil {
return nil, err
}

results = append(results, cred)
}
return results, nil
}

func filterLatest(creds []*Credential) ([]*Credential, error) {

sort.Sort(ByVersion(creds))

names := map[string]*Credential{}

for _, cred := range creds {
names[cred.Name] = cred
}

results := make([]*Credential, 0, len(names))

for _, val := range names {
results = append(results, val)
}

// because maps key order is randomised in golang
sort.Sort(ByVersion(results))

return results, nil
}

func waitForTable() error {

timeout := make(chan bool, 1)
Expand Down

0 comments on commit 08ffab1

Please sign in to comment.