From 7c7d7229fb1bf7d8ab724c7bb32a593071ff4b93 Mon Sep 17 00:00:00 2001 From: Zhengyao Xie Date: Fri, 19 Jul 2024 15:27:31 +0800 Subject: [PATCH] feat(thrift): migrate unknownfields from kitex --- .../thrift/unknownfields/unknownfields.go | 354 ++++++++++++++++++ .../unknownfields/unknownfields_test.go | 123 ++++++ 2 files changed, 477 insertions(+) create mode 100644 protocol/thrift/unknownfields/unknownfields.go create mode 100644 protocol/thrift/unknownfields/unknownfields_test.go diff --git a/protocol/thrift/unknownfields/unknownfields.go b/protocol/thrift/unknownfields/unknownfields.go new file mode 100644 index 0000000..ec4ace7 --- /dev/null +++ b/protocol/thrift/unknownfields/unknownfields.go @@ -0,0 +1,354 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "errors" + "fmt" + "reflect" + + "github.com/cloudwego/gopkg/protocol/thrift" +) + +// UnknownFieldsName is the name of the struct field that GetUnknownFields looks up by reflection to find the raw bytes of the unknown fields. +// The field is generated by thriftgo, check thriftgo for more information. +const UnknownFieldsName = "_unknownFields" + +// UnknownField describes one decoded unknown field, or one element of a decoded container. +type UnknownField struct { + ID int16 // field id from the field header; for set/list/map members it is the element index + Type thrift.TType // Thrift type of the value + KeyType thrift.TType // key type, set only when Type is MAP + ValType thrift.TType // element type for SET and LIST, value type for MAP + Value interface{} // bool, int8, int16, int32, int64, float64 or string for scalars; []UnknownField for SET, LIST and STRUCT; for MAP a flat []UnknownField of alternating keys and values +} + +// GetUnknownFields deserializes the unknown fields stored in v into a list of UnknownField values. 
+func GetUnknownFields(v interface{}) (fields []UnknownField, err error) { + var buf []byte + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Ptr && !rv.IsNil() { + rv = rv.Elem() + } + if rv.Kind() != reflect.Struct { + return nil, fmt.Errorf("%T is not a struct type", v) + } + if unknownField := rv.FieldByName(UnknownFieldsName); !unknownField.IsValid() { + return nil, fmt.Errorf("%T has no field named '%s'", v, UnknownFieldsName) + } else { + buf = unknownField.Bytes() + } + return ConvertUnknownFields(buf) +} + +// ConvertUnknownFields converts buf to deserialized unknown fields. +func ConvertUnknownFields(buf []byte) (fields []UnknownField, err error) { + if len(buf) == 0 { + return nil, errors.New("_unknownFields is empty") + } + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + var f UnknownField + if offset == len(buf) { + return + } + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + return nil, fmt.Errorf("read field %d begin error: %v", fieldId, err) + } + l, err = readUnknownField(&f, buf[offset:], fieldTypeId, fieldId) + offset += l + if err != nil { + return nil, fmt.Errorf("read unknown field %d error: %v", fieldId, err) + } + fields = append(fields, f) + } +} + +func readUnknownField(f *UnknownField, buf []byte, fieldType thrift.TType, id int16) (length int, err error) { + var size int + var l int + f.ID = id + f.Type = fieldType + switch fieldType { + case thrift.BOOL: + f.Value, l, err = thrift.Binary.ReadBool(buf[length:]) + length += l + case thrift.BYTE: + f.Value, l, err = thrift.Binary.ReadByte(buf[length:]) + length += l + case thrift.I16: + f.Value, l, err = thrift.Binary.ReadI16(buf[length:]) + length += l + case thrift.I32: + f.Value, l, err = thrift.Binary.ReadI32(buf[length:]) + length += l + case thrift.I64: + f.Value, l, err = thrift.Binary.ReadI64(buf[length:]) + length += l + case thrift.DOUBLE: + f.Value, l, err = 
thrift.Binary.ReadDouble(buf[length:]) + length += l + case thrift.STRING: + f.Value, l, err = thrift.Binary.ReadString(buf[length:]) + length += l + case thrift.SET: + var ttype thrift.TType + ttype, size, l, err = thrift.Binary.ReadSetBegin(buf[length:]) + length += l + if err != nil { + return length, fmt.Errorf("read set begin error: %w", err) + } + f.ValType = ttype + set := make([]UnknownField, size) + for i := 0; i < size; i++ { + l, err2 := readUnknownField(&set[i], buf[length:], f.ValType, int16(i)) + length += l + if err2 != nil { + return length, fmt.Errorf("read set elem error: %w", err2) + } + } + f.Value = set + case thrift.LIST: + var ttype thrift.TType + ttype, size, l, err = thrift.Binary.ReadListBegin(buf[length:]) + length += l + if err != nil { + return length, fmt.Errorf("read list begin error: %w", err) + } + f.ValType = ttype + list := make([]UnknownField, size) + for i := 0; i < size; i++ { + l, err2 := readUnknownField(&list[i], buf[length:], f.ValType, int16(i)) + length += l + if err2 != nil { + return length, fmt.Errorf("read list elem error: %w", err2) + } + } + f.Value = list + case thrift.MAP: + var kttype, vttype thrift.TType + kttype, vttype, size, l, err = thrift.Binary.ReadMapBegin(buf[length:]) + length += l + if err != nil { + return length, fmt.Errorf("read map begin error: %w", err) + } + f.KeyType = kttype + f.ValType = vttype + flatMap := make([]UnknownField, size*2) + for i := 0; i < size; i++ { + l, err2 := readUnknownField(&flatMap[2*i], buf[length:], f.KeyType, int16(i)) + length += l + if err2 != nil { + return length, fmt.Errorf("read map key error: %w", err2) + } + l, err2 = readUnknownField(&flatMap[2*i+1], buf[length:], f.ValType, int16(i)) + length += l + if err2 != nil { + return length, fmt.Errorf("read map value error: %w", err2) + } + } + f.Value = flatMap + case thrift.STRUCT: + var field UnknownField + var fields []UnknownField + for { + fieldTypeID, fieldID, l, err := 
thrift.Binary.ReadFieldBegin(buf[length:]) + length += l + if err != nil { + return length, fmt.Errorf("read field begin error: %w", err) + } + if fieldTypeID == thrift.STOP { + break + } + l, err = readUnknownField(&field, buf[length:], fieldTypeID, fieldID) + length += l + if err != nil { + return length, fmt.Errorf("read struct field error: %w", err) + } + fields = append(fields, field) + } + f.Value = fields + default: + return length, fmt.Errorf("unknown data type %d", f.Type) + } + if err != nil { + return length, err + } + return +} + +// UnknownFieldsLength returns the length of fs. +func UnknownFieldsLength(fs []UnknownField) (int, error) { + l := 0 + for _, f := range fs { + l += thrift.Binary.FieldBeginLength() + ll, err := unknownFieldLength(&f) + l += ll + if err != nil { + return l, err + } + } + return l, nil +} + +func unknownFieldLength(f *UnknownField) (length int, err error) { + // use constants to avoid some type assert + switch f.Type { + case thrift.BOOL: + length += thrift.Binary.BoolLength() + case thrift.BYTE: + length += thrift.Binary.ByteLength() + case thrift.DOUBLE: + length += thrift.Binary.DoubleLength() + case thrift.I16: + length += thrift.Binary.I16Length() + case thrift.I32: + length += thrift.Binary.I32Length() + case thrift.I64: + length += thrift.Binary.I64Length() + case thrift.STRING: + length += thrift.Binary.StringLength(f.Value.(string)) + case thrift.SET: + vs := f.Value.([]UnknownField) + length += thrift.Binary.SetBeginLength() + for _, v := range vs { + l, err := unknownFieldLength(&v) + length += l + if err != nil { + return length, err + } + } + case thrift.LIST: + vs := f.Value.([]UnknownField) + length += thrift.Binary.ListBeginLength() + for _, v := range vs { + l, err := unknownFieldLength(&v) + length += l + if err != nil { + return length, err + } + } + case thrift.MAP: + kvs := f.Value.([]UnknownField) + length += thrift.Binary.MapBeginLength() + for i := 0; i < len(kvs); i += 2 { + l, err := 
unknownFieldLength(&kvs[i]) + length += l + if err != nil { + return length, err + } + l, err = unknownFieldLength(&kvs[i+1]) + length += l + if err != nil { + return length, err + } + } + case thrift.STRUCT: + fs := f.Value.([]UnknownField) + l, err := UnknownFieldsLength(fs) + length += l + if err != nil { + return length, err + } + length += thrift.Binary.FieldStopLength() + default: + return length, fmt.Errorf("unknown data type %d", f.Type) + } + return +} + +// WriteUnknownFields writes fs into buf, and return written offset of the buf. +func WriteUnknownFields(buf []byte, fs []UnknownField) (offset int, err error) { + for _, f := range fs { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], f.Type, f.ID) + l, err := writeUnknownField(buf[offset:], &f) + offset += l + if err != nil { + return offset, err + } + } + return offset, nil +} + +func writeUnknownField(buf []byte, f *UnknownField) (offset int, err error) { + switch f.Type { + case thrift.BOOL: + offset += thrift.Binary.WriteBool(buf, f.Value.(bool)) + case thrift.BYTE: + offset += thrift.Binary.WriteByte(buf, f.Value.(int8)) + case thrift.DOUBLE: + offset += thrift.Binary.WriteDouble(buf, f.Value.(float64)) + case thrift.I16: + offset += thrift.Binary.WriteI16(buf, f.Value.(int16)) + case thrift.I32: + offset += thrift.Binary.WriteI32(buf, f.Value.(int32)) + case thrift.I64: + offset += thrift.Binary.WriteI64(buf, f.Value.(int64)) + case thrift.STRING: + offset += thrift.Binary.WriteString(buf, f.Value.(string)) + case thrift.SET: + vs := f.Value.([]UnknownField) + offset += thrift.Binary.WriteSetBegin(buf, f.ValType, len(vs)) + for _, v := range vs { + l, err := writeUnknownField(buf[offset:], &v) + offset += l + if err != nil { + return offset, err + } + } + case thrift.LIST: + vs := f.Value.([]UnknownField) + offset += thrift.Binary.WriteListBegin(buf, f.ValType, len(vs)) + for _, v := range vs { + l, err := writeUnknownField(buf[offset:], &v) + offset += l + if err != nil { + return offset, 
err + } + } + case thrift.MAP: + kvs := f.Value.([]UnknownField) + offset += thrift.Binary.WriteMapBegin(buf, f.KeyType, f.ValType, len(kvs)/2) + for i := 0; i < len(kvs); i += 2 { + l, err := writeUnknownField(buf[offset:], &kvs[i]) + offset += l + if err != nil { + return offset, err + } + l, err = writeUnknownField(buf[offset:], &kvs[i+1]) + offset += l + if err != nil { + return offset, err + } + } + case thrift.STRUCT: + fs := f.Value.([]UnknownField) + l, err := WriteUnknownFields(buf[offset:], fs) + offset += l + if err != nil { + return offset, err + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + default: + return offset, fmt.Errorf("unknown data type %d", f.Type) + } + return +} diff --git a/protocol/thrift/unknownfields/unknownfields_test.go b/protocol/thrift/unknownfields/unknownfields_test.go new file mode 100644 index 0000000..955df7d --- /dev/null +++ b/protocol/thrift/unknownfields/unknownfields_test.go @@ -0,0 +1,123 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package thrift + +import ( + "testing" + + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnknownFields(t *testing.T) { + type A struct { + _unknownFields []byte + } + a := &A{} + + // prepare data: b collects the encoded fields, expect collects the value each one should decode to + b := make([]byte, 0, 1024) + expect := make([]UnknownField, 0, 11) // 11 testcases + + // BOOL, fid=1 + b = thrift.Binary.AppendFieldBegin(b, thrift.BOOL, 1) + b = thrift.Binary.AppendBool(b, true) + expect = append(expect, UnknownField{ID: 1, Type: thrift.BOOL, Value: true}) + + // BYTE, fid=2 + b = thrift.Binary.AppendFieldBegin(b, thrift.BYTE, 2) + b = thrift.Binary.AppendByte(b, 2) + expect = append(expect, UnknownField{ID: 2, Type: thrift.BYTE, Value: int8(2)}) + + // I16, fid=3 + b = thrift.Binary.AppendFieldBegin(b, thrift.I16, 3) + b = thrift.Binary.AppendI16(b, 3) + expect = append(expect, UnknownField{ID: 3, Type: thrift.I16, Value: int16(3)}) + + // I32, fid=4 + b = thrift.Binary.AppendFieldBegin(b, thrift.I32, 4) + b = thrift.Binary.AppendI32(b, 4) + expect = append(expect, UnknownField{ID: 4, Type: thrift.I32, Value: int32(4)}) + + // I64, fid=5 + b = thrift.Binary.AppendFieldBegin(b, thrift.I64, 5) + b = thrift.Binary.AppendI64(b, 5) + expect = append(expect, UnknownField{ID: 5, Type: thrift.I64, Value: int64(5)}) + + // DOUBLE, fid=6 + b = thrift.Binary.AppendFieldBegin(b, thrift.DOUBLE, 6) + b = thrift.Binary.AppendDouble(b, 6) + expect = append(expect, UnknownField{ID: 6, Type: thrift.DOUBLE, Value: float64(6)}) + + // STRING, fid=7 + b = thrift.Binary.AppendFieldBegin(b, thrift.STRING, 7) + b = thrift.Binary.AppendString(b, "7") + expect = append(expect, UnknownField{ID: 7, Type: thrift.STRING, Value: "7"}) + + // MAP, fid=8; one entry, expected back as a flat slice holding the key followed by the value + b = thrift.Binary.AppendFieldBegin(b, thrift.MAP, 8) + b = thrift.Binary.AppendMapBegin(b, thrift.DOUBLE, thrift.DOUBLE, 1) + b = thrift.Binary.AppendDouble(b, 8.1) + b = thrift.Binary.AppendDouble(b, 8.2) + expect = 
append(expect, UnknownField{ID: 8, Type: thrift.MAP, + KeyType: thrift.DOUBLE, ValType: thrift.DOUBLE, + Value: []UnknownField{ + {Type: thrift.DOUBLE, Value: float64(8.1)}, + {Type: thrift.DOUBLE, Value: float64(8.2)}, + }}) + + // SET, fid=9; one I64 element + b = thrift.Binary.AppendFieldBegin(b, thrift.SET, 9) + b = thrift.Binary.AppendSetBegin(b, thrift.I64, 1) + b = thrift.Binary.AppendI64(b, 9) + expect = append(expect, UnknownField{ID: 9, Type: thrift.SET, + ValType: thrift.I64, + Value: []UnknownField{ + {Type: thrift.I64, Value: int64(9)}, + }}) + + // LIST, fid=10; one I64 element + b = thrift.Binary.AppendFieldBegin(b, thrift.LIST, 10) + b = thrift.Binary.AppendListBegin(b, thrift.I64, 1) + b = thrift.Binary.AppendI64(b, 10) + expect = append(expect, UnknownField{ID: 10, Type: thrift.LIST, + ValType: thrift.I64, + Value: []UnknownField{ + {Type: thrift.I64, Value: int64(10)}, + }}) + + // STRUCT with 1 field I64, fid=11,1 + b = thrift.Binary.AppendFieldBegin(b, thrift.STRUCT, 11) + b = thrift.Binary.AppendFieldBegin(b, thrift.I64, 1) + b = thrift.Binary.AppendI64(b, 11) + b = thrift.Binary.AppendFieldStop(b) + expect = append(expect, UnknownField{ID: 11, Type: thrift.STRUCT, Value: []UnknownField{ + {ID: 1, Type: thrift.I64, Value: int64(11)}, + }}) + + // decode: store the bytes in the _unknownFields field of a and read them back with GetUnknownFields + a._unknownFields = b + fields, err := GetUnknownFields(a) + require.NoError(t, err) + require.Equal(t, len(expect), len(fields)) + + // test fields: every decoded field must equal the expectation at the same index + for i := range fields { + assert.Equal(t, expect[i], fields[i]) + } +}