From 2659f3f6a4b6f473b0218ce7234d5ca446d7c3dc Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Sun, 28 Jul 2024 00:31:18 +0800 Subject: [PATCH] fix(reflect): zero len slice case --- internal/reflect/decoder.go | 34 +++++++++++++++++++++++++------ internal/reflect/decoder_test.go | 27 +++++++++++++++++++++--- internal/reflect/encoder.go | 2 +- internal/reflect/hack.go | 20 ++++++++++++++---- internal/reflect/ttype.go | 5 ++++- internal/reflect/unknownfields.go | 8 +++++++- internal/reflect/utils.go | 4 +++- 7 files changed, 83 insertions(+), 17 deletions(-) diff --git a/internal/reflect/decoder.go b/internal/reflect/decoder.go index 7cbda99..bb7f0f0 100644 --- a/internal/reflect/decoder.go +++ b/internal/reflect/decoder.go @@ -27,6 +27,22 @@ import ( "github.com/cloudwego/frugal/internal/binary/defs" ) +var ( + // for slice, Data should points to zerobase var in `runtime` + // so that it can represent as []type{} instead of []type(nil) + zeroLenSlice sliceHeader + + // for string, all fields should be zero + zeroLenStr stringHeader +) + +func init() { + b := make([]byte, 0) + zeroLenSlice = *(*sliceHeader)(unsafe.Pointer(&b)) + s := "" + zeroLenStr = *(*stringHeader)(unsafe.Pointer(&s)) +} + const maxDepthLimit = 1023 var decoderPool = sync.Pool{ @@ -169,17 +185,22 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int l := int(binary.BigEndian.Uint32(b)) i += 4 if l == 0 { + if t.Tag == defs.T_binary { + *(*sliceHeader)(p) = zeroLenSlice + } else { + *(*stringHeader)(p) = zeroLenStr + } return i, nil } x := d.Malloc(l, 1, 0) if t.Tag == defs.T_binary { h := (*sliceHeader)(p) - h.Data = x + h.Data = uintptr(x) h.Len = l h.Cap = l } else { // convert to str h := (*stringHeader)(p) - h.Data = x + h.Data = uintptr(x) h.Len = l } copyn(x, b[i:], l) @@ -193,7 +214,7 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int // check types kt := t.K vt := t.V - if t0 != kt.T || t1 != vt.T { + if t0 != kt.WT || t1 != 
vt.WT { return 0, errors.New("type mismatch") } @@ -271,20 +292,21 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int // check types et := t.V - if et.T != tp { + if et.WT != tp { return 0, errors.New("type mismatch") } // decode list h := (*sliceHeader)(p) // update the slice field - h.Data = unsafe.Pointer(nil) + h.Data = 0 h.Len = l h.Cap = l if l <= 0 { + *(*sliceHeader)(p) = zeroLenSlice return i, nil } x := d.Malloc(l*et.Size, et.Align, et.MallocAbiType) // malloc for slice. make([]Type, l, l) - h.Data = x + h.Data = uintptr(x) // pre-allocate space for elements if they're pointers // like diff --git a/internal/reflect/decoder_test.go b/internal/reflect/decoder_test.go index 8ae9dd2..631ebb2 100644 --- a/internal/reflect/decoder_test.go +++ b/internal/reflect/decoder_test.go @@ -89,9 +89,17 @@ func TestDecode(t *testing.T) { test: func(t *testing.T, p1 *TestTypes) { assert.Equal(t, vFloat64, p1.Double) }, }, { - name: "case_string", - update: func(p0 *TestTypes) { p0.String_ = "str" }, - test: func(t *testing.T, p1 *TestTypes) { assert.Equal(t, "str", p1.String_) }, + name: "case_string_binary", + update: func(p0 *TestTypes) { p0.String_ = "str"; p0.Binary = []byte{1} }, + test: func(t *testing.T, p1 *TestTypes) { + assert.Equal(t, "str", p1.String_) + assert.Equal(t, []byte{1}, p1.Binary) + }, + }, + { + name: "case_zero_len_binary", + update: func(p0 *TestTypes) { p0.Binary = []byte{} }, + test: func(t *testing.T, p1 *TestTypes) { assert.Equal(t, []byte{}, p1.Binary) }, }, { name: "case_enum", @@ -136,6 +144,19 @@ func TestDecode(t *testing.T) { assert.Equal(t, []*Msg{{}, {Type: vInt32}}, p1.L2) }, }, + { + name: "case_zero_len_list", + update: func(p0 *TestTypes) { + p0.L0 = []int32{} + p0.L1 = []string{} + p0.L2 = []*Msg{} + }, + test: func(t *testing.T, p1 *TestTypes) { + assert.Equal(t, []int32{}, p1.L0) + assert.Equal(t, []string{}, p1.L1) + assert.Equal(t, []*Msg{}, p1.L2) + }, + }, { name: "case_set", update: func(p0 
*TestTypes) { diff --git a/internal/reflect/encoder.go b/internal/reflect/encoder.go index 4471f35..b202fc7 100644 --- a/internal/reflect/encoder.go +++ b/internal/reflect/encoder.go @@ -145,7 +145,7 @@ func (e *tEncoder) encodeContainerType(t *tType, b []byte, p unsafe.Pointer) (in } binary.BigEndian.PutUint32(b[1:], uint32(h.Len)) i := listHeaderLen - vp := h.Data + vp := h.UnsafePointer() // list elements for j := 0; j < h.Len; j++ { if j != 0 { diff --git a/internal/reflect/hack.go b/internal/reflect/hack.go index 1e5d551..cb08617 100644 --- a/internal/reflect/hack.go +++ b/internal/reflect/hack.go @@ -221,19 +221,31 @@ func rtTypePtr(rt reflect.Type) uintptr { return (*iface)(unsafe.Pointer(&rt)).data } -// same as reflect.StringHeader with Data type is unsafe.Pointer +// same as reflect.StringHeader type stringHeader struct { - Data unsafe.Pointer + Data uintptr Len int } -// same as reflect.SliceHeader with Data type is unsafe.Pointer +// UnsafePointer ... for passing checkptr +// `p := unsafe.Pointer(h.Data)` is NOT allowed when testing with -race +func (h *stringHeader) UnsafePointer() unsafe.Pointer { + return *(*unsafe.Pointer)(unsafe.Pointer(h)) +} + +// same as reflect.SliceHeader type sliceHeader struct { - Data unsafe.Pointer + Data uintptr Len int Cap int } +// UnsafePointer ... 
for passing checkptr +// `p := unsafe.Pointer(h.Data)` is NOT allowed when testing with -race +func (h *sliceHeader) UnsafePointer() unsafe.Pointer { + return *(*unsafe.Pointer)(unsafe.Pointer(h)) +} + //go:linkname mallocgc runtime.mallocgc func mallocgc(size uintptr, typ unsafe.Pointer, needzero bool) unsafe.Pointer diff --git a/internal/reflect/ttype.go b/internal/reflect/ttype.go index a3a26a5..3a9f141 100644 --- a/internal/reflect/ttype.go +++ b/internal/reflect/ttype.go @@ -310,7 +310,10 @@ func (t *tType) encodedListSize(p unsafe.Pointer) (int, error) { return listHeaderLen + (h.Len * vt.FixedSize), nil } ret := listHeaderLen - vp := unsafe.Pointer(h.Data) + if h.Len == 0 { + return ret, nil + } + vp := h.UnsafePointer() for i := 0; i < h.Len; i++ { if i != 0 { vp = unsafe.Add(vp, vt.Size) // move to next element diff --git a/internal/reflect/unknownfields.go b/internal/reflect/unknownfields.go index 341f6c8..264bb6c 100644 --- a/internal/reflect/unknownfields.go +++ b/internal/reflect/unknownfields.go @@ -17,6 +17,7 @@ package reflect import ( + "runtime" "sync" "unsafe" ) @@ -47,9 +48,11 @@ func (p *unknownFields) Size() int { func (p *unknownFields) Copy(b []byte) []byte { sz := p.Size() + data := mallocgc(uintptr(sz), nil, false) // without zeroing + ret := []byte{} h := (*sliceHeader)(unsafe.Pointer(&ret)) - h.Data = mallocgc(uintptr(sz), nil, false) // without zeroing + h.Data = uintptr(data) h.Len = sz h.Cap = sz off := 0 @@ -57,6 +60,9 @@ func (p *unknownFields) Copy(b []byte) []byte { copy(ret[off:], b[x.off:x.off+x.sz]) off += x.sz } + // is this needed? 
+ // just make sure the lifecycle of data is after updates of ret + runtime.KeepAlive(data) return ret } diff --git a/internal/reflect/utils.go b/internal/reflect/utils.go index 655c140..c7a1bbc 100644 --- a/internal/reflect/utils.go +++ b/internal/reflect/utils.go @@ -19,6 +19,7 @@ package reflect import ( "fmt" "reflect" + "runtime" "sync" "unsafe" ) @@ -28,10 +29,11 @@ import ( func copyn(dst unsafe.Pointer, src []byte, n int) { var b []byte hdr := (*sliceHeader)(unsafe.Pointer(&b)) - hdr.Data = dst + hdr.Data = uintptr(dst) hdr.Cap = n hdr.Len = n copy(b, src) + runtime.KeepAlive(dst) } // only be used when NewRequiredFieldNotSetException