From dd36d2560e438d2c2ae6a64e0a74c298ed13fcee Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Sun, 28 Jul 2024 00:31:18 +0800 Subject: [PATCH] fix(reflect): zero len slice case --- internal/reflect/decoder.go | 34 +++++++++++++++++++++++------ internal/reflect/decoder_test.go | 27 ++++++++++++++++++++--- internal/reflect/encoder.go | 2 +- internal/reflect/hack.go | 8 +++---- internal/reflect/ttype.go | 36 +++++++++++++++++++++++++++++++ internal/reflect/unknownfields.go | 8 ++++++- internal/reflect/utils.go | 4 +++- 7 files changed, 103 insertions(+), 16 deletions(-) diff --git a/internal/reflect/decoder.go b/internal/reflect/decoder.go index 7cbda99..bb7f0f0 100644 --- a/internal/reflect/decoder.go +++ b/internal/reflect/decoder.go @@ -27,6 +27,22 @@ import ( "github.com/cloudwego/frugal/internal/binary/defs" ) +var ( + // for slice, Data should point to the zerobase var in `runtime` + // so that it is represented as []type{} instead of []type(nil) + zeroLenSlice sliceHeader + + // for string, all fields should be zero + zeroLenStr stringHeader +) + +func init() { + b := make([]byte, 0) + zeroLenSlice = *(*sliceHeader)(unsafe.Pointer(&b)) + s := "" + zeroLenStr = *(*stringHeader)(unsafe.Pointer(&s)) +} + const maxDepthLimit = 1023 var decoderPool = sync.Pool{ @@ -169,17 +185,22 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int l := int(binary.BigEndian.Uint32(b)) i += 4 if l == 0 { + if t.Tag == defs.T_binary { + *(*sliceHeader)(p) = zeroLenSlice + } else { + *(*stringHeader)(p) = zeroLenStr + } return i, nil } x := d.Malloc(l, 1, 0) if t.Tag == defs.T_binary { h := (*sliceHeader)(p) - h.Data = x + h.Data = uintptr(x) h.Len = l h.Cap = l } else { // convert to str h := (*stringHeader)(p) - h.Data = x + h.Data = uintptr(x) h.Len = l } copyn(x, b[i:], l) @@ -193,7 +214,7 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int // check types kt := t.K vt := t.V - if t0 != kt.T || t1 != vt.T { + if t0 != kt.WT 
|| t1 != vt.WT { return 0, errors.New("type mismatch") } @@ -271,20 +292,21 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int // check types et := t.V - if et.T != tp { + if et.WT != tp { return 0, errors.New("type mismatch") } // decode list h := (*sliceHeader)(p) // update the slice field - h.Data = unsafe.Pointer(nil) + h.Data = 0 h.Len = l h.Cap = l if l <= 0 { + *(*sliceHeader)(p) = zeroLenSlice return i, nil } x := d.Malloc(l*et.Size, et.Align, et.MallocAbiType) // malloc for slice. make([]Type, l, l) - h.Data = x + h.Data = uintptr(x) // pre-allocate space for elements if they're pointers // like diff --git a/internal/reflect/decoder_test.go b/internal/reflect/decoder_test.go index 8ae9dd2..631ebb2 100644 --- a/internal/reflect/decoder_test.go +++ b/internal/reflect/decoder_test.go @@ -89,9 +89,17 @@ func TestDecode(t *testing.T) { test: func(t *testing.T, p1 *TestTypes) { assert.Equal(t, vFloat64, p1.Double) }, }, { - name: "case_string", - update: func(p0 *TestTypes) { p0.String_ = "str" }, - test: func(t *testing.T, p1 *TestTypes) { assert.Equal(t, "str", p1.String_) }, + name: "case_string_binary", + update: func(p0 *TestTypes) { p0.String_ = "str"; p0.Binary = []byte{1} }, + test: func(t *testing.T, p1 *TestTypes) { + assert.Equal(t, "str", p1.String_) + assert.Equal(t, []byte{1}, p1.Binary) + }, + }, + { + name: "case_zero_len_binary", + update: func(p0 *TestTypes) { p0.Binary = []byte{} }, + test: func(t *testing.T, p1 *TestTypes) { assert.Equal(t, []byte{}, p1.Binary) }, }, { name: "case_enum", @@ -136,6 +144,19 @@ func TestDecode(t *testing.T) { assert.Equal(t, []*Msg{{}, {Type: vInt32}}, p1.L2) }, }, + { + name: "case_zero_len_list", + update: func(p0 *TestTypes) { + p0.L0 = []int32{} + p0.L1 = []string{} + p0.L2 = []*Msg{} + }, + test: func(t *testing.T, p1 *TestTypes) { + assert.Equal(t, []int32{}, p1.L0) + assert.Equal(t, []string{}, p1.L1) + assert.Equal(t, []*Msg{}, p1.L2) + }, + }, { name: "case_set", update: 
func(p0 *TestTypes) { diff --git a/internal/reflect/encoder.go b/internal/reflect/encoder.go index 4471f35..69c7196 100644 --- a/internal/reflect/encoder.go +++ b/internal/reflect/encoder.go @@ -145,7 +145,7 @@ func (e *tEncoder) encodeContainerType(t *tType, b []byte, p unsafe.Pointer) (in } binary.BigEndian.PutUint32(b[1:], uint32(h.Len)) i := listHeaderLen - vp := h.Data + vp := unsafe.Pointer(h.Data) // unsafe if h.Data points to non-heap. when? // list elements for j := 0; j < h.Len; j++ { if j != 0 { diff --git a/internal/reflect/hack.go b/internal/reflect/hack.go index 1e5d551..a18b981 100644 --- a/internal/reflect/hack.go +++ b/internal/reflect/hack.go @@ -221,15 +221,15 @@ func rtTypePtr(rt reflect.Type) uintptr { return (*iface)(unsafe.Pointer(&rt)).data } -// same as reflect.StringHeader with Data type is unsafe.Pointer +// same as reflect.StringHeader type stringHeader struct { - Data unsafe.Pointer + Data uintptr Len int } -// same as reflect.SliceHeader with Data type is unsafe.Pointer +// same as reflect.SliceHeader type sliceHeader struct { - Data unsafe.Pointer + Data uintptr Len int Cap int } diff --git a/internal/reflect/ttype.go b/internal/reflect/ttype.go index a3a26a5..0fae917 100644 --- a/internal/reflect/ttype.go +++ b/internal/reflect/ttype.go @@ -49,6 +49,42 @@ const ( tENUM ttype = 0xfe // XXX: kitex issue, int64, but encode as int32 ... 
) +func (t ttype) String() string { + switch t { + case tSTOP: + return "STOP" + case tVOID: + return "VOID" + case tBOOL: + return "BOOL" + case tI08: + return "I08" + case tDOUBLE: + return "DOUBLE" + case tI16: + return "I16" + case tI32: + return "I32" + case tI64: + return "I64" + case tSTRING: + return "STRING" + case tSTRUCT: + return "STRUCT" + case tMAP: + return "MAP" + case tSET: + return "SET" + case tLIST: + return "LIST" + case tUTF8: + return "UTF8" + case tUTF16: + return "UTF16" + } + return "UNKNOWN" +} + var simpleTypes = [256]bool{ tBOOL: true, tBYTE: true, diff --git a/internal/reflect/unknownfields.go b/internal/reflect/unknownfields.go index 341f6c8..264bb6c 100644 --- a/internal/reflect/unknownfields.go +++ b/internal/reflect/unknownfields.go @@ -17,6 +17,7 @@ package reflect import ( + "runtime" "sync" "unsafe" ) @@ -47,9 +48,11 @@ func (p *unknownFields) Size() int { func (p *unknownFields) Copy(b []byte) []byte { sz := p.Size() + data := mallocgc(uintptr(sz), nil, false) // without zeroing + ret := []byte{} h := (*sliceHeader)(unsafe.Pointer(&ret)) - h.Data = mallocgc(uintptr(sz), nil, false) // without zeroing + h.Data = uintptr(data) h.Len = sz h.Cap = sz off := 0 @@ -57,6 +60,9 @@ func (p *unknownFields) Copy(b []byte) []byte { copy(ret[off:], b[x.off:x.off+x.sz]) off += x.sz } + // is this needed? + // just make sure the lifetime of data extends past the updates of ret + runtime.KeepAlive(data) return ret } diff --git a/internal/reflect/utils.go b/internal/reflect/utils.go index 655c140..c7a1bbc 100644 --- a/internal/reflect/utils.go +++ b/internal/reflect/utils.go @@ -19,6 +19,7 @@ package reflect import ( "fmt" "reflect" + "runtime" "sync" "unsafe" ) @@ -28,10 +29,11 @@ import ( func copyn(dst unsafe.Pointer, src []byte, n int) { var b []byte hdr := (*sliceHeader)(unsafe.Pointer(&b)) - hdr.Data = dst + hdr.Data = uintptr(dst) hdr.Cap = n hdr.Len = n copy(b, src) + runtime.KeepAlive(dst) } // only be used when NewRequiredFieldNotSetException