Skip to content

Commit

Permalink
Patch ASA-0024-0012 and 0013 on v7.x compatible branch. (#62)
Browse files Browse the repository at this point in the history
Merge commit from fork

* Limit recursion depth for unknown field detection

(cherry picked from commit f038dc731c55be1e1c526e67695acc358631afd6)

* Limit unpack any

(cherry picked from commit 1a2bff56fb7391f9ce87d4fbe9e0367ae991c0b2)

* Update Changelog

* Another limit recursion depth for unknown field detection

* Update changelog

Co-authored-by: Alexander Peters <[email protected]>
  • Loading branch information
vincentwschau and alpe authored Dec 19, 2024
1 parent 3865004 commit 93be52f
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 6 deletions.
57 changes: 55 additions & 2 deletions codec/types/interface_registry.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package types

import (
"errors"
"fmt"
"reflect"

Expand All @@ -12,6 +13,17 @@ import (
"cosmossdk.io/x/tx/signing"
)

var (

// MaxUnpackAnySubCalls extension point that defines the maximum number of sub-calls allowed during the unpacking
// process of protobuf Any messages.
MaxUnpackAnySubCalls = 100

// MaxUnpackAnyRecursionDepth extension point that defines the maximum allowed recursion depth during protobuf Any
// message unpacking.
MaxUnpackAnyRecursionDepth = 10
)

// AnyUnpacker is an interface which allows safely unpacking types packed
// in Any's against a whitelist of registered types
type AnyUnpacker interface {
Expand Down Expand Up @@ -270,6 +282,45 @@ func (registry *interfaceRegistry) ListImplementations(ifaceName string) []strin
}

func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error {
unpacker := &statefulUnpacker{
registry: registry,
maxDepth: MaxUnpackAnyRecursionDepth,
maxCalls: &sharedCounter{count: MaxUnpackAnySubCalls},
}
return unpacker.UnpackAny(any, iface)
}

// sharedCounter is a type that encapsulates a counter value
type sharedCounter struct {
count int
}

// statefulUnpacker is a struct that helps in deserializing and unpacking
// protobuf Any messages while maintaining certain stateful constraints.
type statefulUnpacker struct {
registry *interfaceRegistry
maxDepth int
maxCalls *sharedCounter
}

// cloneForRecursion returns a new statefulUnpacker instance with maxDepth reduced by one, preserving the registry and maxCalls.
func (r statefulUnpacker) cloneForRecursion() *statefulUnpacker {
return &statefulUnpacker{
registry: r.registry,
maxDepth: r.maxDepth - 1,
maxCalls: r.maxCalls,
}
}

// UnpackAny deserializes a protobuf Any message into the provided interface, ensuring the interface is a pointer.
// It applies stateful constraints such as max depth and call limits, and unpacks interfaces if required.
func (r *statefulUnpacker) UnpackAny(any *Any, iface interface{}) error {
if r.maxDepth == 0 {
return errors.New("max depth exceeded")
}
if r.maxCalls.count == 0 {
return errors.New("call limit exceeded")
}
// here we gracefully handle the case in which `any` itself is `nil`, which may occur in message decoding
if any == nil {
return nil
Expand All @@ -280,6 +331,8 @@ func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error
return nil
}

r.maxCalls.count--

rv := reflect.ValueOf(iface)
if rv.Kind() != reflect.Ptr {
return fmt.Errorf("UnpackAny expects a pointer")
Expand All @@ -295,7 +348,7 @@ func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error
}
}

imap, found := registry.interfaceImpls[rt]
imap, found := r.registry.interfaceImpls[rt]
if !found {
return fmt.Errorf("no registered implementations of type %+v", rt)
}
Expand All @@ -315,7 +368,7 @@ func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error
return err
}

err = UnpackInterfaces(msg, registry)
err = UnpackInterfaces(msg, r.cloneForRecursion())
if err != nil {
return err
}
Expand Down
18 changes: 16 additions & 2 deletions codec/unknownproto/unknown_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,23 @@ func RejectUnknownFieldsStrict(bz []byte, msg proto.Message, resolver jsonpb.Any
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
// An AnyResolver must be provided for traversing inside google.protobuf.Any's.
func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool, resolver jsonpb.AnyResolver) (hasUnknownNonCriticals bool, err error) {
// recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28
return doRejectUnknownFields(bz, msg, allowUnknownNonCriticals, resolver, 10_000)
}

func doRejectUnknownFields(
bz []byte,
msg proto.Message,
allowUnknownNonCriticals bool,
resolver jsonpb.AnyResolver,
recursionLimit int,
) (hasUnknownNonCriticals bool, err error) {
if len(bz) == 0 {
return hasUnknownNonCriticals, nil
}
if recursionLimit == 0 {
return false, errors.New("recursion limit reached")
}

desc, ok := msg.(descriptorIface)
if !ok {
Expand Down Expand Up @@ -130,7 +144,7 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals

if protoMessageName == ".google.protobuf.Any" {
// Firstly typecheck types.Any to ensure nothing snuck in.
hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver)
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver, recursionLimit-1)
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
Expand All @@ -153,7 +167,7 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals
}
}

hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver)
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver, recursionLimit-1)
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
Expand Down
18 changes: 16 additions & 2 deletions x/tx/decode/unknown.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,23 @@ func RejectUnknownFieldsStrict(bz []byte, msg protoreflect.MessageDescriptor, re
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
// An AnyResolver must be provided for traversing inside google.protobuf.Any's.
func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUnknownNonCriticals bool, resolver protodesc.Resolver) (hasUnknownNonCriticals bool, err error) {
// recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28
return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, 10_000)
}

func doRejectUnknownFields(
bz []byte,
desc protoreflect.MessageDescriptor,
allowUnknownNonCriticals bool,
resolver protodesc.Resolver,
recursionLimit int,
) (hasUnknownNonCriticals bool, err error) {
if len(bz) == 0 {
return hasUnknownNonCriticals, nil
}
if recursionLimit == 0 {
return false, errors.New("recursion limit reached")
}

fields := desc.Fields()

Expand Down Expand Up @@ -91,7 +105,7 @@ func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUn

if fieldMessage.FullName() == anyFullName {
// Firstly typecheck types.Any to ensure nothing snuck in.
hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver)
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver, recursionLimit-1)
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
Expand All @@ -111,7 +125,7 @@ func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUn
fieldBytes = a.Value
}

hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, fieldMessage, allowUnknownNonCriticals, resolver)
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, fieldMessage, allowUnknownNonCriticals, resolver, recursionLimit-1)
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
Expand Down

0 comments on commit 93be52f

Please sign in to comment.