Skip to content

Commit

Permalink
Merge pull request #7 from morgabra/morgabra/DynamoVersion
Browse files Browse the repository at this point in the history
Generate an optional 'Version() (int64, error)' func
  • Loading branch information
pquerna authored Dec 9, 2020
2 parents 74221a3 + f64273f commit b734a4a
Show file tree
Hide file tree
Showing 14 changed files with 1,480 additions and 48 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.idea
examplepb/example.pb.dynamo.go
examplepb/example.pb.go
107 changes: 60 additions & 47 deletions dynamo/dynamo.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions dynamo/dynamo.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ message DynamoMessageOptions {
Key sort = 3;
// A list of compound values that can be set from other fields.
repeated Key compound_field = 4;
// A value that can be used as a item version.
Key version = 5;
}

message Key {
Expand Down
4 changes: 4 additions & 0 deletions examplepb/example.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ message Store {
option (dynamo.msg).partition = {name: "pk", prefix: "store", fields: ["id", "country", "foo"]};
option (dynamo.msg).sort = {name: "sk", const: "example"};

option (dynamo.msg).version = {fields: ["updated_at"]};

option (dynamo.msg).compound_field = {name: "gsi1pk", prefix: "store", fields: ["id", "country", "foo"]};
option (dynamo.msg).compound_field = {name: "gsi1sk", const: "dummyvalue"};

Expand All @@ -26,6 +28,8 @@ message Store {
repeated string best_employee_ids = 8 [(dynamo.field).type.set = true];
google.protobuf.Timestamp bin_date = 9;

google.protobuf.Timestamp updated_at = 10;

uint64 foo = 99;
repeated uint64 morefoo = 100 [(dynamo.field).type.set = true];
}
63 changes: 62 additions & 1 deletion internal/pgd/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package pgd

import (
"bytes"
"errors"
"fmt"

"github.com/dave/jennifer/jen"
"github.com/davecgh/go-spew/spew"
"github.com/lyft/protoc-gen-star"
pgs "github.com/lyft/protoc-gen-star"
pgsgo "github.com/lyft/protoc-gen-star/lang/go"
dynamopb "github.com/pquerna/protoc-gen-dynamo/dynamo"
)
Expand Down Expand Up @@ -66,6 +68,7 @@ const (
dynamoPkg = "github.com/aws/aws-sdk-go/service/dynamodb"
protoPkg = "github.com/golang/protobuf/proto"
awsPkg = "github.com/aws/aws-sdk-go/aws"
ptypesPkg = "github.com/golang/protobuf/ptypes"
strconvPkg = "strconv"
stringsPkg = "strings"
fmtPkg = "fmt"
Expand Down Expand Up @@ -174,6 +177,50 @@ type namedKey struct {
key *dynamopb.Key
}

func (m *Module) applyVersionFuncs(msg pgs.Message, key namedKey, f *jen.File) error {
if key.key.Const != "" {
return errors.New("version: constants not allowed")
}

if len(key.key.Fields) != 1 {
return errors.New("version: exactly 1 field is required")
}

structName := m.ctx.Name(msg)
fn := key.key.Fields[0]
field := fieldByName(msg, fn)
srcName := field.Name().UpperCamelCase().String()

var stmts []jen.Code
if field.Type().ProtoType().IsNumeric() {
// return int64(p.<fieldName>)
stmts = append(stmts, jen.Return(jen.List(jen.Int64().Parens(jen.Id("p").Dot(srcName)), jen.Nil())))
} else {
d := field.Descriptor().TypeName
if d != nil && *d == ".google.protobuf.Timestamp" {
// t, err := ptypes.Timestamp(p.<fieldName>)
// if err != nil { return 0, err }
// return t.UnixNano(), nil
f.ImportName(ptypesPkg, "ptypes")
stmts = append(stmts, jen.List(jen.Id("t"), jen.Id("err")).Op(":=").Qual(ptypesPkg, "Timestamp").Call(
jen.Id("p").Dot(srcName),
))
stmts = append(stmts, jen.If(jen.Err().Op("!=").Nil()).Block(jen.Return(jen.List(jen.Lit(0), jen.Err()))))
stmts = append(stmts, jen.Return(jen.List(jen.Id("t").Dot("UnixNano").Call(), jen.Nil())))
}
}

if len(stmts) == 0 {
return errors.New("version: numeric or timestamp type is required")
}

f.Func().Params(
jen.Id("p").Op("*").Id(structName.String()),
).Id(key.name).Params().Parens(jen.List(jen.Int64(), jen.Error())).Block(stmts...).Line()

return nil
}

func (m *Module) applyKeyFuncs(f *jen.File, in pgs.File) error {
const stringBuffer = "sb"
for _, msg := range in.AllMessages() {
Expand All @@ -198,6 +245,10 @@ func (m *Module) applyKeyFuncs(f *jen.File, in pgs.File) error {
key: mext.Sort,
name: "SortKey",
},
{
key: mext.Version,
name: "Version",
},
}

for _, ck := range mext.CompoundField {
Expand All @@ -212,6 +263,16 @@ func (m *Module) applyKeyFuncs(f *jen.File, in pgs.File) error {
if key.key == nil {
continue
}

if key.name == "Version" {
err := m.applyVersionFuncs(msg, key, f)
if err != nil {
m.Logf("Generating version funcs failed: %s", err)
m.Fail("code generation failed")
}
continue
}

stmts := []jen.Code{}
if key.key.Const != "" {
stmts = append(stmts,
Expand Down
Loading

0 comments on commit b734a4a

Please sign in to comment.