Skip to content

Commit

Permalink
adding directive: unmarshling callback (#12)
Browse files Browse the repository at this point in the history
* WIP: added a check for wellformness of structs in unmarshling

* added directive to include callbacks at end of unmarshling

* added use of callbacks

* fix: comment

* tidy gomod
  • Loading branch information
algonathan authored Oct 19, 2021
1 parent a756cc9 commit 9359fd1
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 15 deletions.
40 changes: 32 additions & 8 deletions gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ var primitives = map[string]Primitive{
"interface{}": Intf,
"time.Time": Time,
"msgp.Extension": Ext,
"error": Error,
"error": Error,
}

// types built into the library
Expand All @@ -136,19 +136,37 @@ var builtins = map[string]struct{}{
"msgp.Number": struct{}{},
}

// Callback represents a function that can is expected to be printed into the generated code.
// for example, at the end of a successful unmarshalling.
type Callback struct {
Fname string
CallbackType CallbackType
}

type CallbackType uint64

// UnmarshalCallBack represents a type callback that should run over the generated code.
const UnmarshalCallBack CallbackType = 1

func (c Callback) IsUnmarshallCallback() bool { return c.CallbackType == UnmarshalCallBack }
func (c Callback) GetName() string { return c.Fname }

// common data/methods for every Elem
type common struct {
vname, alias string
allocbound string
callbacks []Callback
}

func (c *common) SetVarname(s string) { c.vname = s }
func (c *common) Varname() string { return c.vname }
func (c *common) Alias(typ string) { c.alias = typ }
func (c *common) SortInterface() string { return "" }
func (c *common) SetAllocBound(s string) { c.allocbound = s }
func (c *common) AllocBound() string { return c.allocbound }
func (c *common) hidden() {}
func (c *common) SetVarname(s string) { c.vname = s }
func (c *common) Varname() string { return c.vname }
func (c *common) Alias(typ string) { c.alias = typ }
func (c *common) SortInterface() string { return "" }
func (c *common) SetAllocBound(s string) { c.allocbound = s }
func (c *common) AllocBound() string { return c.allocbound }
func (c *common) GetCallbacks() []Callback { return c.callbacks }
func (c *common) AddCallback(cb Callback) { c.callbacks = append(c.callbacks, cb) }
func (c *common) hidden() {}

func IsDangling(e Elem) bool {
if be, ok := e.(*BaseElem); ok && be.Dangling() {
Expand Down Expand Up @@ -221,6 +239,12 @@ type Elem interface {
// when decoding this type. Meaningful for slices and maps.
AllocBound() string

// AddCallback adds to the elem a Callback it should call at the end of marshaling
AddCallback(Callback)

// GetCallbacks fetches all callbacks this Elem stored.
GetCallbacks() []Callback

hidden()
}

Expand Down
11 changes: 11 additions & 0 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) {
u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", c, methodRecv)
next(u, p)
u.p.print("\no = bts")

// right before the return: attempt to inspect well formed:
for _, callback := range p.GetCallbacks() {
if !callback.IsUnmarshallCallback() {
continue
}

u.p.printf("\nif err = %s.%s(); err != nil {", c, callback.GetName())
u.p.printf("\n return")
u.p.printf("\n}")
}
u.p.nakedReturn()

u.p.printf("\nfunc (_ %[2]s) CanUnmarshalMsg(%[1]s interface{}) bool {", c, methodRecv)
Expand Down
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module github.com/algorand/msgp
go 1.12

require (
github.com/philhofer/fwd v1.0.0
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31
golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d
golang.org/x/tools v0.0.0-20200423205358-59e73619c742
)
18 changes: 14 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
github.com/philhofer/fwd v1.0.0 h1:UbZqGr5Y38ApvM/V/jEljVxwocdweyH+vmYvRPBnbqQ=
github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31 h1:OXcKh35JaYsGMRzpvFkLv/MEyPuL49CThT1pZ8aSml4=
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31/go.mod h1:onvgF043R+lC5RZ8IT9rBXDaEDnpnw/Cl+HFiw+v/7Q=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/mod v0.2.0 h1:KU7oHjnv3XNWfa5COkzUifxZmxp1TyI7ImMXqFxLwvQ=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d h1:/iIZNFGxc/a7C3yWjGcnboV+Tkc7mxr+p6fDztwoxuM=
golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200423205358-59e73619c742 h1:9OGWpORUXvk8AsaBJlpzzDx7Srv/rSK6rvjcsJq4rJo=
golang.org/x/tools v0.0.0-20200423205358-59e73619c742/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
33 changes: 33 additions & 0 deletions parse/directives.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package parse

import (
"errors"
"fmt"
"go/ast"
"strings"
Expand All @@ -26,6 +27,38 @@ var directives = map[string]directive{
"tuple": astuple,
"sort": sortintf,
"allocbound": allocbound,
// _postunmarshalcheck is used to add callbacks to the end of unmarshling that are tied to a specific Element.
_postunmarshalcheck: postunmarshalcheck,
}

const _postunmarshalcheck = "postunmarshalcheck"

var errNotEnoughArguments = errors.New("postunmarshalcheck did not receive enough arguments. expected at least 3")

//msgp:postunmarshalcheck {Type} {funcName} {funcName} ...
// the functions should have no params, and output zero.
func postunmarshalcheck(text []string, f *FileSet) error {
if len(text) < 3 {
return errNotEnoughArguments
}
// not error but doesn't do anything
if text[0] != _postunmarshalcheck {
return nil
}
text = text[1:]

elemType := text[0]
elem, ok := f.Identities[elemType]
if !ok {
return errors.New(fmt.Sprintf("postunmarshalcheck error: type %v does not exist", elemType))
}
for _, fName := range text[1:] {
elem.AddCallback(gen.Callback{
Fname: fName,
CallbackType: gen.UnmarshalCallBack,
})
}
return nil
}

var passDirectives = map[string]passDirective{
Expand Down
53 changes: 53 additions & 0 deletions parse/directives_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package parse

import (
"testing"

"github.com/algorand/msgp/gen"
)

const (
testStructName = "TestStruct"
testFuncName = "callback"
)

func TestPostunmarshalcheck(t *testing.T) {
st := gen.Struct{
Fields: nil,
AsTuple: false,
}

fl := FileSet{
Identities: map[string]gen.Elem{testStructName: &st},
Directives: []string{"postunmarshalcheck"}, // raw preprocessor directives
}
if err := postunmarshalcheck([]string{"postunmarshalcheck", testStructName, testFuncName}, &fl); err != nil {
t.Fatal()
}
if testFuncName != st.GetCallbacks()[0].GetName() {
t.Fatal()
}
if !st.GetCallbacks()[0].IsUnmarshallCallback() {
t.Fatal()
}
}

func TestPostunmarshalcheckFailures(t *testing.T) {

st := gen.Struct{
Fields: nil,
AsTuple: false,
}

fl := FileSet{
Identities: map[string]gen.Elem{testStructName: &st},
Directives: []string{"postunmarshalcheck"}, // raw preprocessor directives
}
if err := postunmarshalcheck([]string{"postunmarshalcheck", testFuncName}, &fl); err == nil {
t.Fatal()
}

if err := postunmarshalcheck([]string{"postunmarshalcheck", "non-existing-type", testFuncName}, &fl); err == nil {
t.Fatal()
}
}
2 changes: 1 addition & 1 deletion parse/getast.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ func (fs *FileSet) getFieldsFromEmbeddedStruct(importPrefix string, f ast.Expr)
return nil
}

return pkgfs.getFieldsFromEmbeddedStruct(pkgid.Name + ".", f.Sel)
return pkgfs.getFieldsFromEmbeddedStruct(pkgid.Name+".", f.Sel)
default:
// other possibilities are disallowed
return nil
Expand Down

0 comments on commit 9359fd1

Please sign in to comment.