Skip to content

Commit

Permalink
feat(thrift): migrate unknownfields from kitex
Browse files Browse the repository at this point in the history
  • Loading branch information
jayantxie authored and xiaost committed Jul 19, 2024
1 parent dbcaae7 commit 7822bba
Show file tree
Hide file tree
Showing 2 changed files with 461 additions and 0 deletions.
354 changes: 354 additions & 0 deletions protocol/thrift/unknownfields/unknownfields.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,354 @@
/*
* Copyright 2023 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package thrift

import (
"errors"
"fmt"
"reflect"

"github.com/cloudwego/gopkg/protocol/thrift"
)

// UnknownFieldsName ... generated by thriftgo, check thriftgo for more information
const UnknownFieldsName = "_unknownFields"

// UnknownField is used to describe an unknown field.
type UnknownField struct {
ID int16
Type thrift.TType
KeyType thrift.TType
ValType thrift.TType
Value interface{}
}

// GetUnknownFields deserialize unknownFields stored in v to a list of *UnknownFields.
func GetUnknownFields(v interface{}) (fields []UnknownField, err error) {
var buf []byte
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Ptr && !rv.IsNil() {
rv = rv.Elem()
}
if rv.Kind() != reflect.Struct {
return nil, fmt.Errorf("%T is not a struct type", v)
}
if unknownField := rv.FieldByName(UnknownFieldsName); !unknownField.IsValid() {
return nil, fmt.Errorf("%T has no field named '%s'", v, UnknownFieldsName)
} else {
buf = unknownField.Bytes()
}
return ConvertUnknownFields(buf)
}

// ConvertUnknownFields converts buf to deserialized unknown fields.
func ConvertUnknownFields(buf []byte) (fields []UnknownField, err error) {
if len(buf) == 0 {
return nil, errors.New("_unknownFields is empty")
}
var offset int
var l int
var fieldTypeId thrift.TType
var fieldId int16
for {
var f UnknownField
if offset == len(buf) {
return
}
fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:])
offset += l
if err != nil {
return nil, fmt.Errorf("read field %d begin error: %v", fieldId, err)
}
l, err = readUnknownField(&f, buf[offset:], fieldTypeId, fieldId)
offset += l
if err != nil {
return nil, fmt.Errorf("read unknown field %d error: %v", fieldId, err)
}
fields = append(fields, f)
}
}

func readUnknownField(f *UnknownField, buf []byte, fieldType thrift.TType, id int16) (length int, err error) {
var size int
var l int
f.ID = id
f.Type = fieldType
switch fieldType {
case thrift.BOOL:
f.Value, l, err = thrift.Binary.ReadBool(buf[length:])
length += l
case thrift.BYTE:
f.Value, l, err = thrift.Binary.ReadByte(buf[length:])
length += l
case thrift.I16:
f.Value, l, err = thrift.Binary.ReadI16(buf[length:])
length += l
case thrift.I32:
f.Value, l, err = thrift.Binary.ReadI32(buf[length:])
length += l
case thrift.I64:
f.Value, l, err = thrift.Binary.ReadI64(buf[length:])
length += l
case thrift.DOUBLE:
f.Value, l, err = thrift.Binary.ReadDouble(buf[length:])
length += l
case thrift.STRING:
f.Value, l, err = thrift.Binary.ReadString(buf[length:])
length += l
case thrift.SET:
var ttype thrift.TType
ttype, size, l, err = thrift.Binary.ReadSetBegin(buf[length:])
length += l
if err != nil {
return length, fmt.Errorf("read set begin error: %w", err)
}
f.ValType = ttype
set := make([]UnknownField, size)
for i := 0; i < size; i++ {
l, err2 := readUnknownField(&set[i], buf[length:], f.ValType, int16(i))
length += l
if err2 != nil {
return length, fmt.Errorf("read set elem error: %w", err2)
}
}
f.Value = set
case thrift.LIST:
var ttype thrift.TType
ttype, size, l, err = thrift.Binary.ReadListBegin(buf[length:])
length += l
if err != nil {
return length, fmt.Errorf("read list begin error: %w", err)
}
f.ValType = ttype
list := make([]UnknownField, size)
for i := 0; i < size; i++ {
l, err2 := readUnknownField(&list[i], buf[length:], f.ValType, int16(i))
length += l
if err2 != nil {
return length, fmt.Errorf("read list elem error: %w", err2)
}
}
f.Value = list
case thrift.MAP:
var kttype, vttype thrift.TType
kttype, vttype, size, l, err = thrift.Binary.ReadMapBegin(buf[length:])
length += l
if err != nil {
return length, fmt.Errorf("read map begin error: %w", err)
}
f.KeyType = kttype
f.ValType = vttype
flatMap := make([]UnknownField, size*2)
for i := 0; i < size; i++ {
l, err2 := readUnknownField(&flatMap[2*i], buf[length:], f.KeyType, int16(i))
length += l
if err2 != nil {
return length, fmt.Errorf("read map key error: %w", err2)
}
l, err2 = readUnknownField(&flatMap[2*i+1], buf[length:], f.ValType, int16(i))
length += l
if err2 != nil {
return length, fmt.Errorf("read map value error: %w", err2)
}
}
f.Value = flatMap
case thrift.STRUCT:
var field UnknownField
var fields []UnknownField
for {
fieldTypeID, fieldID, l, err := thrift.Binary.ReadFieldBegin(buf[length:])
length += l
if err != nil {
return length, fmt.Errorf("read field begin error: %w", err)
}
if fieldTypeID == thrift.STOP {
break
}
l, err = readUnknownField(&field, buf[length:], fieldTypeID, fieldID)
length += l
if err != nil {
return length, fmt.Errorf("read struct field error: %w", err)
}
fields = append(fields, field)
}
f.Value = fields
default:
return length, fmt.Errorf("unknown data type %d", f.Type)
}
if err != nil {
return length, err
}
return
}

// UnknownFieldsLength returns the length of fs.
func UnknownFieldsLength(fs []UnknownField) (int, error) {
l := 0
for _, f := range fs {
l += thrift.Binary.FieldBeginLength()
ll, err := unknownFieldLength(&f)
l += ll
if err != nil {
return l, err
}
}
return l, nil
}

func unknownFieldLength(f *UnknownField) (length int, err error) {
// use constants to avoid some type assert
switch f.Type {
case thrift.BOOL:
length += thrift.Binary.BoolLength()
case thrift.BYTE:
length += thrift.Binary.ByteLength()
case thrift.DOUBLE:
length += thrift.Binary.DoubleLength()
case thrift.I16:
length += thrift.Binary.I16Length()
case thrift.I32:
length += thrift.Binary.I32Length()
case thrift.I64:
length += thrift.Binary.I64Length()
case thrift.STRING:
length += thrift.Binary.StringLength(f.Value.(string))
case thrift.SET:
vs := f.Value.([]UnknownField)
length += thrift.Binary.SetBeginLength()
for _, v := range vs {
l, err := unknownFieldLength(&v)
length += l
if err != nil {
return length, err
}
}
case thrift.LIST:
vs := f.Value.([]UnknownField)
length += thrift.Binary.ListBeginLength()
for _, v := range vs {
l, err := unknownFieldLength(&v)
length += l
if err != nil {
return length, err
}
}
case thrift.MAP:
kvs := f.Value.([]UnknownField)
length += thrift.Binary.MapBeginLength()
for i := 0; i < len(kvs); i += 2 {
l, err := unknownFieldLength(&kvs[i])
length += l
if err != nil {
return length, err
}
l, err = unknownFieldLength(&kvs[i+1])
length += l
if err != nil {
return length, err
}
}
case thrift.STRUCT:
fs := f.Value.([]UnknownField)
l, err := UnknownFieldsLength(fs)
length += l
if err != nil {
return length, err
}
length += thrift.Binary.FieldStopLength()
default:
return length, fmt.Errorf("unknown data type %d", f.Type)
}
return
}

// WriteUnknownFields writes fs into buf, and return written offset of the buf.
func WriteUnknownFields(buf []byte, fs []UnknownField) (offset int, err error) {
for _, f := range fs {
offset += thrift.Binary.WriteFieldBegin(buf[offset:], f.Type, f.ID)
l, err := writeUnknownField(buf[offset:], &f)
offset += l
if err != nil {
return offset, err
}
}
return offset, nil
}

func writeUnknownField(buf []byte, f *UnknownField) (offset int, err error) {
switch f.Type {
case thrift.BOOL:
offset += thrift.Binary.WriteBool(buf, f.Value.(bool))
case thrift.BYTE:
offset += thrift.Binary.WriteByte(buf, f.Value.(int8))
case thrift.DOUBLE:
offset += thrift.Binary.WriteDouble(buf, f.Value.(float64))
case thrift.I16:
offset += thrift.Binary.WriteI16(buf, f.Value.(int16))
case thrift.I32:
offset += thrift.Binary.WriteI32(buf, f.Value.(int32))
case thrift.I64:
offset += thrift.Binary.WriteI64(buf, f.Value.(int64))
case thrift.STRING:
offset += thrift.Binary.WriteString(buf, f.Value.(string))
case thrift.SET:
vs := f.Value.([]UnknownField)
offset += thrift.Binary.WriteSetBegin(buf, f.ValType, len(vs))
for _, v := range vs {
l, err := writeUnknownField(buf[offset:], &v)
offset += l
if err != nil {
return offset, err
}
}
case thrift.LIST:
vs := f.Value.([]UnknownField)
offset += thrift.Binary.WriteListBegin(buf, f.ValType, len(vs))
for _, v := range vs {
l, err := writeUnknownField(buf[offset:], &v)
offset += l
if err != nil {
return offset, err
}
}
case thrift.MAP:
kvs := f.Value.([]UnknownField)
offset += thrift.Binary.WriteMapBegin(buf, f.KeyType, f.ValType, len(kvs)/2)
for i := 0; i < len(kvs); i += 2 {
l, err := writeUnknownField(buf[offset:], &kvs[i])
offset += l
if err != nil {
return offset, err
}
l, err = writeUnknownField(buf[offset:], &kvs[i+1])
offset += l
if err != nil {
return offset, err
}
}
case thrift.STRUCT:
fs := f.Value.([]UnknownField)
l, err := WriteUnknownFields(buf[offset:], fs)
offset += l
if err != nil {
return offset, err
}
offset += thrift.Binary.WriteFieldStop(buf[offset:])
default:
return offset, fmt.Errorf("unknown data type %d", f.Type)
}
return
}
Loading

0 comments on commit 7822bba

Please sign in to comment.