Skip to content

Commit

Permalink
Add UnmarshalMsgWithState to Unmarshaler interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ohill committed Oct 3, 2023
1 parent ab5758d commit a679500
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 5 deletions.
19 changes: 17 additions & 2 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,17 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) {
u.p.printf("\n return ((*(%s))(%s)).UnmarshalMsg(bts)", baseType, c)
u.p.printf("\n}")

u.p.printf("\nfunc (%s %s) UnmarshalMsgWithState(bts []byte, st msgp.UnmarshalState) ([]byte, error) {", c, methodRecv)
u.p.printf("\n return ((*(%s))(%s)).UnmarshalMsgWithState(bts, st)", baseType, c)
u.p.printf("\n}")

u.p.printf("\nfunc (_ %[2]s) CanUnmarshalMsg(%[1]s interface{}) bool {", c, methodRecv)
u.p.printf("\n _, ok := (%s).(%s)", c, methodRecv)
u.p.printf("\n return ok")
u.p.printf("\n}")

u.topics.Add(methodRecv, "UnmarshalMsg")
u.topics.Add(methodRecv, "UnmarshalMsgWithState")
u.topics.Add(methodRecv, "CanUnmarshalMsg")

return u.msgs, u.p.err
Expand All @@ -75,7 +80,12 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) {
c := p.Varname()
methodRecv := methodReceiver(p)

u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", c, methodRecv)
u.p.printf("\nfunc (%s %s) UnmarshalMsgWithState(bts []byte, st msgp.UnmarshalState) (o []byte, err error) {", c, methodRecv)
u.p.printf("\n if st.Depth == 0 {")
u.p.printf("\n err = msgp.ErrMaxDepthExceeded{}")
u.p.printf("\n return")
u.p.printf("\n }")
u.p.printf("\n st.Depth--")
next(u, p)
u.p.print("\no = bts")

Expand All @@ -91,12 +101,17 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) {
}
u.p.nakedReturn()

u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", c, methodRecv)
u.p.printf("\n return %s.UnmarshalMsgWithState(bts, msgp.DefaultUnmarshalState)", c)
u.p.printf("\n}")

u.p.printf("\nfunc (_ %[2]s) CanUnmarshalMsg(%[1]s interface{}) bool {", c, methodRecv)
u.p.printf("\n _, ok := (%s).(%s)", c, methodRecv)
u.p.printf("\n return ok")
u.p.printf("\n}")

u.topics.Add(methodRecv, "UnmarshalMsg")
u.topics.Add(methodRecv, "UnmarshalMsgWithState")
u.topics.Add(methodRecv, "CanUnmarshalMsg")

return u.msgs, u.p.err
Expand Down Expand Up @@ -236,7 +251,7 @@ func (u *unmarshalGen) gBase(b *BaseElem) {
case Ext:
u.p.printf("\nbts, err = msgp.ReadExtensionBytes(bts, %s)", lowered)
case IDENT:
u.p.printf("\nbts, err = %s.UnmarshalMsg(bts)", lowered)
u.p.printf("\nbts, err = %s.UnmarshalMsgWithState(bts, st)", lowered)
case String:
if b.common.AllocBound() != "" {
sz := randIdent()
Expand Down
12 changes: 10 additions & 2 deletions msgp/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,9 @@ func Resumable(e error) bool {
//
// ErrShortBytes is not wrapped with any context due to backward compatibility
// issues with the public API.
//
func WrapError(err error, ctx ...interface{}) error {
switch e := err.(type) {
case errShort:
case errShort, ErrMaxDepthExceeded:
return e
case contextError:
return e.withContext(ctxString(ctx))
Expand Down Expand Up @@ -344,3 +343,12 @@ func (e *ErrUnsupportedType) withContext(ctx string) error {
o.ctx = addCtx(o.ctx, ctx)
return &o
}

// ErrMaxDepthExceeded is returned if the maximum traversal depth is exceeded.
type ErrMaxDepthExceeded struct{}

// Error implements error
func (e ErrMaxDepthExceeded) Error() string { return "Max depth exceeded" }

// Resumable implements Error
func (e ErrMaxDepthExceeded) Resumable() bool { return false }
9 changes: 9 additions & 0 deletions msgp/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,14 @@ func (t Type) String() string {
// field in a struct rather than unmarshaling the entire struct.
type Unmarshaler interface {
UnmarshalMsg([]byte) ([]byte, error)
UnmarshalMsgWithState([]byte, UnmarshalState) ([]byte, error)
CanUnmarshalMsg(o interface{}) bool
}

// UnmarshalState holds state while running UnmarshalMsg.
type UnmarshalState struct {
AllowableDepth uint64
}

// DefaultUnmarshalState defines the default state.
var DefaultUnmarshalState = UnmarshalState{AllowableDepth: 10000}
12 changes: 11 additions & 1 deletion msgp/read_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ func (*Raw) CanUnmarshalMsg(z interface{}) bool {
// It sets the contents of *Raw to be the next
// object in the provided byte slice.
func (r *Raw) UnmarshalMsg(b []byte) ([]byte, error) {
return r.UnmarshalMsgWithState(b, DefaultUnmarshalState)
}

// UnmarshalMsg implements msgp.Unmarshaler.
// It sets the contents of *Raw to be the next
// object in the provided byte slice.
func (r *Raw) UnmarshalMsgWithState(b []byte, st UnmarshalState) ([]byte, error) {
if st.AllowableDepth == 0 {
return nil, ErrMaxDepthExceeded{}
}
l := len(b)
out, err := Skip(b)
if err != nil {
Expand Down Expand Up @@ -1185,7 +1195,7 @@ func ReadStringBytes(b []byte) (string, []byte, error) {
// into a slice of bytes. 'v' is the value of
// the 'str' object, which may reside in memory
// pointed to by 'scratch.' 'o' is the remaining bytes
// in 'b.''
// in 'b'.
// Possible errors:
// - ErrShortBytes (b not long enough)
// - TypeError{} (not 'str' type)
Expand Down

0 comments on commit a679500

Please sign in to comment.