From f40ffab66dd74a1e8633278c54124fac6d7b88fd Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Wed, 4 Dec 2024 12:59:37 +0800 Subject: [PATCH] fix(reflect): deep nested struct cause panic --- internal/reflect/desc.go | 35 ++++++++++++++++++++++---------- internal/reflect/encoder_test.go | 25 +++++++++++++++++++++++ 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/internal/reflect/desc.go b/internal/reflect/desc.go index c584fbd..8b4a958 100644 --- a/internal/reflect/desc.go +++ b/internal/reflect/desc.go @@ -92,23 +92,36 @@ func newStructDescAndPrefetch(t reflect.Type) (*structDesc, error) { func prefetchSubStructDesc(d *structDesc) error { for i := range d.fields { - var t *tType f := d.fields[i] - if f.Type.T == tSTRUCT { - t = f.Type - } else if f.Type.T == tMAP && f.Type.V.T == tSTRUCT { - t = f.Type.V - } else if f.Type.T == tLIST && f.Type.V.T == tSTRUCT { - t = f.Type.V - } else { - continue + switch f.Type.T { + case tSTRUCT, tMAP, tLIST, tSET: + if err := fetchStructDesc(f.Type); err != nil { + return err + } } - sd, err := newStructDescAndPrefetch(t.RT) + } + return nil +} + +func fetchStructDesc(t *tType) error { + if t.T == tMAP { + err := fetchStructDesc(t.K) if err != nil { return err } - t.Sd = sd + return fetchStructDesc(t.V) + } + if t.T == tLIST || t.T == tSET { + return fetchStructDesc(t.V) + } + if t.T != tSTRUCT || t.Sd != nil { + return nil + } + sd, err := newStructDescAndPrefetch(t.RT) + if err != nil { + return err } + t.Sd = sd return nil } diff --git a/internal/reflect/encoder_test.go b/internal/reflect/encoder_test.go index 5cc0b26..34dad4e 100644 --- a/internal/reflect/encoder_test.go +++ b/internal/reflect/encoder_test.go @@ -145,3 +145,28 @@ func TestEncodeUnknownFields(t *testing.T) { assert.Equal(t, n, len(b)) assert.Contains(t, string(b), string(append([]byte("helloworld")[:], byte(tSTOP)))) } + +func TestNestedListMapStruct(t *testing.T) { + type Msg1 struct { + A string `frugal:"1,default,string"` + B string `frugal:"2,default,string"` + } + type Msg2 struct { + Msgs []map[string]*Msg1 `thrift:"item_list,2" frugal:"2,default,list>" json:"item_list"` + } + p := &Msg2{} + p.Msgs = make([]map[string]*Msg1, 0, 1) + p.Msgs = append(p.Msgs, map[string]*Msg1{}) + p.Msgs[0]["32"] = &Msg1{A: "Hello", B: "World"} + + b := make([]byte, EncodedSize(p)) + i, err := Encode(b, p) + require.NoError(t, err) + require.Equal(t, i, len(b)) + + p2 := &Msg2{} + i, err = Decode(b, p2) + require.NoError(t, err) + require.Equal(t, i, len(b)) + require.Equal(t, p, p2) +}