Skip to content

Commit

Permalink
fix(reflect): zero len slice case
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaost committed Jul 27, 2024
1 parent 29e9968 commit dd36d25
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 16 deletions.
34 changes: 28 additions & 6 deletions internal/reflect/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ import (
"github.com/cloudwego/frugal/internal/binary/defs"
)

var (
// for slice, Data should points to zerobase var in `runtime`
// so that it can represent as []type{} instead of []type(nil)
zeroLenSlice sliceHeader

// for string, all fields should be zero
zeroLenStr stringHeader
)

func init() {
b := make([]byte, 0)
zeroLenSlice = *(*sliceHeader)(unsafe.Pointer(&b))
s := ""
zeroLenStr = *(*stringHeader)(unsafe.Pointer(&s))
}

const maxDepthLimit = 1023

var decoderPool = sync.Pool{
Expand Down Expand Up @@ -169,17 +185,22 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
l := int(binary.BigEndian.Uint32(b))
i += 4
if l == 0 {
if t.Tag == defs.T_binary {
*(*sliceHeader)(p) = zeroLenSlice
} else {
*(*stringHeader)(p) = zeroLenStr
}
return i, nil
}
x := d.Malloc(l, 1, 0)
if t.Tag == defs.T_binary {
h := (*sliceHeader)(p)
h.Data = x
h.Data = uintptr(x)
h.Len = l
h.Cap = l
} else { // convert to str
h := (*stringHeader)(p)
h.Data = x
h.Data = uintptr(x)
h.Len = l
}
copyn(x, b[i:], l)
Expand All @@ -193,7 +214,7 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
// check types
kt := t.K
vt := t.V
if t0 != kt.T || t1 != vt.T {
if t0 != kt.WT || t1 != vt.WT {
return 0, errors.New("type mismatch")
}

Expand Down Expand Up @@ -271,20 +292,21 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int

// check types
et := t.V
if et.T != tp {
if et.WT != tp {
return 0, errors.New("type mismatch")
}

// decode list
h := (*sliceHeader)(p) // update the slice field
h.Data = unsafe.Pointer(nil)
h.Data = 0
h.Len = l
h.Cap = l
if l <= 0 {
*(*sliceHeader)(p) = zeroLenSlice
return i, nil
}
x := d.Malloc(l*et.Size, et.Align, et.MallocAbiType) // malloc for slice. make([]Type, l, l)
h.Data = x
h.Data = uintptr(x)

// pre-allocate space for elements if they're pointers
// like
Expand Down
27 changes: 24 additions & 3 deletions internal/reflect/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,17 @@ func TestDecode(t *testing.T) {
test: func(t *testing.T, p1 *TestTypes) { assert.Equal(t, vFloat64, p1.Double) },
},
{
name: "case_string",
update: func(p0 *TestTypes) { p0.String_ = "str" },
test: func(t *testing.T, p1 *TestTypes) { assert.Equal(t, "str", p1.String_) },
name: "case_string_binary",
update: func(p0 *TestTypes) { p0.String_ = "str"; p0.Binary = []byte{1} },
test: func(t *testing.T, p1 *TestTypes) {
assert.Equal(t, "str", p1.String_)
assert.Equal(t, []byte{1}, p1.Binary)
},
},
{
name: "case_zero_len_binary",
update: func(p0 *TestTypes) { p0.Binary = []byte{} },
test: func(t *testing.T, p1 *TestTypes) { assert.Equal(t, []byte{}, p1.Binary) },
},
{
name: "case_enum",
Expand Down Expand Up @@ -136,6 +144,19 @@ func TestDecode(t *testing.T) {
assert.Equal(t, []*Msg{{}, {Type: vInt32}}, p1.L2)
},
},
{
name: "case_zero_len_list",
update: func(p0 *TestTypes) {
p0.L0 = []int32{}
p0.L1 = []string{}
p0.L2 = []*Msg{}
},
test: func(t *testing.T, p1 *TestTypes) {
assert.Equal(t, []int32{}, p1.L0)
assert.Equal(t, []string{}, p1.L1)
assert.Equal(t, []*Msg{}, p1.L2)
},
},
{
name: "case_set",
update: func(p0 *TestTypes) {
Expand Down
2 changes: 1 addition & 1 deletion internal/reflect/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (e *tEncoder) encodeContainerType(t *tType, b []byte, p unsafe.Pointer) (in
}
binary.BigEndian.PutUint32(b[1:], uint32(h.Len))
i := listHeaderLen
vp := h.Data
vp := unsafe.Pointer(h.Data) // unsafe if h.Data points to non-heap. when?
// list elements
for j := 0; j < h.Len; j++ {
if j != 0 {
Expand Down
8 changes: 4 additions & 4 deletions internal/reflect/hack.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,15 @@ func rtTypePtr(rt reflect.Type) uintptr {
return (*iface)(unsafe.Pointer(&rt)).data
}

// same as reflect.StringHeader with Data type is unsafe.Pointer
// same as reflect.StringHeader
type stringHeader struct {
Data unsafe.Pointer
Data uintptr
Len int
}

// same as reflect.SliceHeader with Data type is unsafe.Pointer
// same as reflect.SliceHeader
type sliceHeader struct {
Data unsafe.Pointer
Data uintptr
Len int
Cap int
}
Expand Down
36 changes: 36 additions & 0 deletions internal/reflect/ttype.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,42 @@ const (
tENUM ttype = 0xfe // XXX: kitex issue, int64, but encode as int32 ...
)

func (t ttype) String() string {
switch t {
case tSTOP:
return "STOP"
case tVOID:
return "VOID"
case tBOOL:
return "BOOL"
case tI08:
return "I08"
case tDOUBLE:
return "DOUBLE"
case tI16:
return "I16"
case tI32:
return "I32"
case tI64:
return "I64"
case tSTRING:
return "STRING"
case tSTRUCT:
return "STRUCT"
case tMAP:
return "MAP"
case tSET:
return "SET"
case tLIST:
return "LIST"
case tUTF8:
return "UTF8"
case tUTF16:
return "UTF16"
}
return "UNKNOWN"
}

var simpleTypes = [256]bool{
tBOOL: true,
tBYTE: true,
Expand Down
8 changes: 7 additions & 1 deletion internal/reflect/unknownfields.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package reflect

import (
"runtime"
"sync"
"unsafe"
)
Expand Down Expand Up @@ -47,16 +48,21 @@ func (p *unknownFields) Size() int {

func (p *unknownFields) Copy(b []byte) []byte {
sz := p.Size()
data := mallocgc(uintptr(sz), nil, false) // without zeroing

ret := []byte{}
h := (*sliceHeader)(unsafe.Pointer(&ret))
h.Data = mallocgc(uintptr(sz), nil, false) // without zeroing
h.Data = uintptr(data)
h.Len = sz
h.Cap = sz
off := 0
for _, x := range p.offs {
copy(ret[off:], b[x.off:x.off+x.sz])
off += x.sz
}
// is this needed?
// just make sure the livecycle of data is after updates of ret
runtime.KeepAlive(data)
return ret
}

Expand Down
4 changes: 3 additions & 1 deletion internal/reflect/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package reflect
import (
"fmt"
"reflect"
"runtime"
"sync"
"unsafe"
)
Expand All @@ -28,10 +29,11 @@ import (
func copyn(dst unsafe.Pointer, src []byte, n int) {
var b []byte
hdr := (*sliceHeader)(unsafe.Pointer(&b))
hdr.Data = dst
hdr.Data = uintptr(dst)
hdr.Cap = n
hdr.Len = n
copy(b, src)
runtime.KeepAlive(dst)
}

// only be used when NewRequiredFieldNotSetException
Expand Down

0 comments on commit dd36d25

Please sign in to comment.