From 56d53d212caa8b45838dc3e4e2ee8d4d91e1ca1a Mon Sep 17 00:00:00 2001 From: xiezhengyao Date: Tue, 6 Aug 2024 16:38:27 +0800 Subject: [PATCH] feat: add bufiox interfaces and spilit ttheader codec codes --- bufiox/bufreader.go | 30 +++ bufiox/bufwriter.go | 17 ++ bufiox/defaultbuf.go | 361 +++++++++++++++++++++++++++++++ bufiox/defaultbuf_test.go | 1 + bytex/buffer.go | 43 ++++ protocol/ttheader/decode.go | 226 +++++++++++++++++++ protocol/ttheader/encode.go | 324 +++++++++++++++++++++++++++ protocol/ttheader/encode_test.go | 1 + protocol/util/bytebuffer.go | 48 ++++ 9 files changed, 1051 insertions(+) create mode 100644 bufiox/bufreader.go create mode 100644 bufiox/bufwriter.go create mode 100644 bufiox/defaultbuf.go create mode 100644 bufiox/defaultbuf_test.go create mode 100644 bytex/buffer.go create mode 100644 protocol/ttheader/decode.go create mode 100644 protocol/ttheader/encode.go create mode 100644 protocol/ttheader/encode_test.go create mode 100644 protocol/util/bytebuffer.go diff --git a/bufiox/bufreader.go b/bufiox/bufreader.go new file mode 100644 index 0000000..40abc0c --- /dev/null +++ b/bufiox/bufreader.go @@ -0,0 +1,30 @@ +package bufiox + +// Reader is a buffer IO interface, which provides a user-space zero-copy method to reduce memory allocation and copy overhead. +type Reader interface { + // Next reads the next n bytes sequentially and returns a slice `p` of length `n`, + // otherwise returns an error if it is unable to read a buffer of n bytes. + // The returned `p` can be a shallow copy of the original buffer. + // Must ensure that the data in `p` is not modified before calling Release. + // + // Callers cannot use the returned data after calling Release. + Next(n int) (p []byte, err error) + + // Peek behaves the same as Next, except that it doesn't advance the reader. + // + // Callers cannot use the returned data after calling Release. + Peek(n int) (buf []byte, err error) + + // Skip skips the next n bytes sequentially, otherwise returns an error if it's unable to skip a buffer of n bytes. + Skip(n int) (err error) + + // ReadLen returns the size that has already been read. + // Read/Next/Skip will increase the size. When the release function is called, ReadLen is set to 0. + ReadLen() (n int) + + // Release will free the buffer. After release, buffer read by Next/Skip/Peek is invalid. + // Param e is used when the buffer release depend on error. + // For example, usually the write buffer will be released inside flush, + // but if flush error happen, write buffer may need to be released explicitly. + Release(e error) (err error) +} diff --git a/bufiox/bufwriter.go b/bufiox/bufwriter.go new file mode 100644 index 0000000..f3175fe --- /dev/null +++ b/bufiox/bufwriter.go @@ -0,0 +1,17 @@ +package bufiox + +// Writer is a buffer IO interface, which provides a user-space zero-copy method to reduce memory allocation and copy overhead. +type Writer interface { + // Malloc returns a shallow copy of the write buffer with length n, + // otherwise returns an error if it's unable to get n bytes from the write buffer. + // Must ensure that the data written by the user to buf can be flushed to the underlying io.Writer. + // + // Caller cannot write data to the returned buf after calling Flush. + Malloc(n int) (buf []byte, err error) + + // WrittenLen returns the total length of the buffer written. + WrittenLen() (length int) + + // Flush writes any malloc data to the underlying io.Writer, and reset WrittenLen to zero. + Flush() (err error) +} diff --git a/bufiox/defaultbuf.go b/bufiox/defaultbuf.go new file mode 100644 index 0000000..4aef461 --- /dev/null +++ b/bufiox/defaultbuf.go @@ -0,0 +1,361 @@ +package bufiox + +import ( + "errors" + "io" + "sync" + + "github.com/bytedance/gopkg/lang/dirtmake" + "github.com/bytedance/gopkg/lang/mcache" +) + +const maxConsecutiveEmptyReads = 100 + +var _ Reader = &DefaultReader{} + +type DefaultReader struct { + buf []byte // buf[ri:] is the buffer for reading. + rd io.Reader // reader provided by the client + ri int // buf read positions + err error + + maxSizeStats [statsBucketNum]int + maxSizeStatsIdx int +} + +const ( + statsBucketNum = 10 + defaultBufSize = 4096 +) + +var ( + errNegativeCount = errors.New("bufiox: negative count") + errFakeIOReader = errors.New("bufiox: fake io reader") +) + +// NewDefaultReader returns a new [Reader] whose buffer has the default size. +func NewDefaultReader(rd io.Reader) *DefaultReader { + r := &DefaultReader{} + r.reset(rd, mcache.Malloc(0, defaultBufSize)) + return r +} + +func NewBytesReader(buf []byte) *BytesReader { + r := bytesReaderPool.Get().(*BytesReader) + r.reset(r.fakedIOReader, buf) + return r +} + +var bytesReaderPool = sync.Pool{ + New: func() interface{} { + return &BytesReader{} + }, +} + +type BytesReader struct { + DefaultReader + fakedIOReader fakeIOReader +} + +func (r *BytesReader) Release(e error) error { + if len(r.buf)-r.ri == 0 { + // release buf + r.reset(nil, nil) + bytesReaderPool.Put(r) + } + return nil +} + +func (r *DefaultReader) reset(rd io.Reader, buf []byte) { + *r = DefaultReader{buf: buf, rd: rd} +} + +func (r *DefaultReader) acquireSlow(n int) int { + if r.err != nil { + if r.buf == nil { + return 0 + } + return len(r.buf) - r.ri + } + + if r.buf == nil { + var maxSize int + for size := range r.maxSizeStats { + if maxSize < size { + maxSize = size + } + } + for ; maxSize < n; maxSize *= 2 { + } + r.buf = mcache.Malloc(0, maxSize) + } + + if n > cap(r.buf)-r.ri { + // grow buffer + var ncap int + for ncap = cap(r.buf) * 2; ncap-r.ri < n; ncap *= 2 { + } + nbuf := mcache.Malloc(ncap) + r.buf = nbuf[:copy(nbuf, r.buf)] + } + + for i := 0; i < maxConsecutiveEmptyReads; i++ { + m, err := r.rd.Read(r.buf[len(r.buf):cap(r.buf)]) + r.buf = r.buf[:len(r.buf)+m] + if err != nil { + r.err = err + return len(r.buf) - r.ri + } + if n <= len(r.buf)-r.ri { + return n + } + } + return len(r.buf) - r.ri +} + +// fill reads a new chunk into the buffer. +func (r *DefaultReader) acquire(n int) int { + // fast path, for inline + if n <= len(r.buf)-r.ri { + return n + } + return r.acquireSlow(n) +} + +func (r *DefaultReader) Next(n int) (buf []byte, err error) { + if n < 0 { + err = errNegativeCount + return + } + m := r.acquire(n) + if n > m { + err = r.err + return + } + // nocopy read + buf = r.buf[r.ri : r.ri+n] + r.ri += n + return +} + +func (r *DefaultReader) Peek(n int) (buf []byte, err error) { + if n < 0 { + err = errNegativeCount + return + } + m := r.acquire(n) + if n > m { + err = r.err + return + } + // nocopy read + buf = r.buf[r.ri : r.ri+n] + return +} + +func (r *DefaultReader) Skip(n int) (err error) { + if n < 0 { + err = errNegativeCount + return + } + m := r.acquire(n) + if n > m { + err = r.err + return + } + r.ri += n + return +} + +func (r *DefaultReader) ReadLen() (n int) { + return r.ri +} + +func (r *DefaultReader) Read(p []byte) (m int, err error) { + m = r.acquire(len(p)) + copy(p, r.buf[r.ri:r.ri+m]) + r.ri += m + if len(p) > m { + err = r.err + } + return +} + +func (r *DefaultReader) Release(e error) error { + if len(r.buf)-r.ri == 0 { + // release buf + r.maxSizeStats[r.maxSizeStatsIdx] = cap(r.buf) + r.maxSizeStatsIdx = (r.maxSizeStatsIdx + 1) % 10 + mcache.Free(r.buf) + r.buf = nil + } else { + n := copy(r.buf, r.buf[r.ri:]) + r.buf = r.buf[:n] + } + r.ri = 0 + return nil +} + +type fakeIOReader struct{} + +func (fakeIOReader) Read(p []byte) (n int, err error) { + return 0, errFakeIOReader +} + +var _ Writer = &DefaultWriter{} + +type DefaultWriter struct { + buf []byte + oldBuf [][]byte + wd io.Writer + err error + + maxSizeStats [statsBucketNum]int + maxSizeStatsIdx int + + disableCache bool +} + +func NewDefaultWriter(wd io.Writer) Writer { + w := &DefaultWriter{} + w.reset(wd, nil) + return w +} + +func NewBytesWriter(buf *[]byte) *BytesWriter { + w := bytesWriterPool.Get().(*BytesWriter) + w.fakedIOWriter.bw = w + w.flushBytes = buf + w.reset(&w.fakedIOWriter, nil) + w.disableCache = true + return w +} + +var bytesWriterPool = sync.Pool{ + New: func() interface{} { + return &BytesWriter{} + }, +} + +type BytesWriter struct { + DefaultWriter + fakedIOWriter fakeIOWriter + flushBytes *[]byte +} + +func (w *BytesWriter) Flush() error { + err := w.DefaultWriter.Flush() + *w = BytesWriter{} + bytesWriterPool.Put(w) + return err +} + +func (w *DefaultWriter) reset(wd io.Writer, buf []byte) { + *w = DefaultWriter{buf: buf, wd: wd} +} + +func (w *DefaultWriter) acquire(n int) { + if len(w.buf)+n <= cap(w.buf) { + return + } + w.acquireSlow(n) +} + +func (w *DefaultWriter) acquireSlow(n int) { + if w.buf == nil { + maxSize := defaultBufSize + for size := range w.maxSizeStats { + if maxSize < size { + maxSize = size + } + } + for ; maxSize < n; maxSize *= 2 { + } + if w.disableCache { + w.buf = dirtmake.Bytes(0, maxSize) + } else { + w.buf = mcache.Malloc(0, maxSize) + } + } + + if n > cap(w.buf) { + // grow buffer + var ncap int + for ncap = cap(w.buf) * 2; ncap < n; ncap *= 2 { + } + var nbuf []byte + if w.disableCache { + nbuf = dirtmake.Bytes(ncap, ncap) + } else { + nbuf = mcache.Malloc(ncap) + } + w.oldBuf = append(w.oldBuf, w.buf) + w.buf = nbuf[:len(w.buf)] + } +} + +func (w *DefaultWriter) Malloc(n int) (buf []byte, err error) { + if w.err != nil { + err = w.err + return + } + if n < 0 { + err = errNegativeCount + return + } + w.acquire(n) + buf = w.buf[len(w.buf) : len(w.buf)+n] + w.buf = w.buf[:len(w.buf)+n] + return +} + +func (w *DefaultWriter) Write(p []byte) (n int, err error) { + if w.err != nil { + err = w.err + return + } + w.acquire(len(p)) + n = copy(w.buf[len(w.buf):], p) + w.buf = w.buf[:len(w.buf)+n] + return +} + +func (w *DefaultWriter) WrittenLen() int { + return len(w.buf) +} + +func (w *DefaultWriter) Flush() (err error) { + if w.err != nil { + err = w.err + return + } + if w.buf == nil { + return nil + } + // copy old buffer + var offset int + for _, oldBuf := range w.oldBuf { + offset += copy(w.buf[offset:], oldBuf[offset:]) + } + if _, err = w.wd.Write(w.buf); err != nil { + w.err = err + return err + } + w.maxSizeStats[w.maxSizeStatsIdx] = cap(w.buf) + w.maxSizeStatsIdx = (w.maxSizeStatsIdx + 1) % 10 + if !w.disableCache { + mcache.Free(w.buf) + } + w.buf = nil + w.oldBuf = nil + return nil +} + +type fakeIOWriter struct { + bw *BytesWriter +} + +func (w *fakeIOWriter) Write(p []byte) (n int, err error) { + *w.bw.flushBytes = p + return len(p), nil +} diff --git a/bufiox/defaultbuf_test.go b/bufiox/defaultbuf_test.go new file mode 100644 index 0000000..1601852 --- /dev/null +++ b/bufiox/defaultbuf_test.go @@ -0,0 +1 @@ +package bufiox diff --git a/bytex/buffer.go b/bytex/buffer.go new file mode 100644 index 0000000..3051f34 --- /dev/null +++ b/bytex/buffer.go @@ -0,0 +1,43 @@ +package bytex + +const BinaryInplaceThreshold = 1024 + +// A Buffer is a variable-sized buffer of bytes with [Buffer.Read] and [Buffer.Write] methods. +// The zero value for Buffer is an empty buffer ready to use. +type Buffer struct { + buf [][]byte + off1, off2 int + + releaseBuf []byte +} + +func (b *Buffer) Write(p []byte) (n int, err error) { + // fast path + if len(p) >= BinaryInplaceThreshold { + buf := b.buf[len(b.buf)-1] + buf = buf[len(buf):len(buf):cap(buf)] + b.buf = append(b.buf, p, buf) + return len(p), nil + } + buf := b.buf[len(b.buf)-1] + if cap(buf)-len(buf) >= len(p) { + n = copy(buf, p) + return + } + + // slow path + b.slowWrite() +} + +func (b *Buffer) Read(p []byte) (n int, err error) { + // fast path + buf := b.buf[b.off1] + if len(p) < len(buf)-b.off2 { + n = copy(p, buf[b.off2:]) + b.off2 += n + return + } + + // slow path + b.slowRead() +} diff --git a/protocol/ttheader/decode.go b/protocol/ttheader/decode.go new file mode 100644 index 0000000..deb3837 --- /dev/null +++ b/protocol/ttheader/decode.go @@ -0,0 +1,226 @@ +package ttheader + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/gopkg/protocol/util" +) + +const ( + // MagicMask is bit mask for checking version. + MagicMask = 0xffff0000 +) + +type DecodeParam struct { + Flags HeaderFlags + + SeqID int32 + + ProtocolID ProtocolID + + // IntInfo is used to set up int key-value info into InfoIDIntKeyValue + IntInfo map[uint16]string + + // StrInfo is used to set up string key-value info into InfoIDKeyValue + StrInfo map[string]string + + PayloadLen int +} + +func DecodeFromBytes(ctx context.Context, bs []byte) (param DecodeParam, err error) { + in := bufiox.NewBytesReader(bs) + param, err = Decode(ctx, in) + _ = in.Release(nil) + return +} + +func Decode(ctx context.Context, in bufiox.Reader) (param DecodeParam, err error) { + var headerMeta []byte + headerMeta, err = in.Next(TTHeaderMetaSize) + if err != nil { + return + } + if !IsTTHeader(headerMeta) { + err = errors.New("not TTHeader protocol") + return + } + totalLen := util.Bytes2Uint32NoCheck(headerMeta[:Size32]) + + flags := util.Bytes2Uint16NoCheck(headerMeta[Size16*3:]) + param.Flags = HeaderFlags(flags) + + seqID := util.Bytes2Uint32NoCheck(headerMeta[Size32*2 : Size32*3]) + param.SeqID = int32(seqID) + + headerInfoSize := util.Bytes2Uint16NoCheck(headerMeta[Size32*3:TTHeaderMetaSize]) * 4 + if uint32(headerInfoSize) > MaxHeaderSize || headerInfoSize < 2 { + err = fmt.Errorf("invalid header length[%d]", headerInfoSize) + return + } + + var headerInfo []byte + if headerInfo, err = in.Next(int(headerInfoSize)); err != nil { + return + } + if err = checkProtocolID(headerInfo[0]); err != nil { + return + } + hdIdx := 2 + transformIDNum := int(headerInfo[1]) + if int(headerInfoSize)-hdIdx < transformIDNum { + err = fmt.Errorf("need read %d transformIDs, but not enough", transformIDNum) + return + } + transformIDs := make([]uint8, transformIDNum) + for i := 0; i < transformIDNum; i++ { + transformIDs[i] = headerInfo[hdIdx] + hdIdx++ + } + + param.IntInfo, param.StrInfo, err = readKVInfo(hdIdx, headerInfo) + if err != nil { + err = fmt.Errorf("ttHeader read kv info failed, %s, headerInfo=%#x", err.Error(), headerInfo) + return + } + + param.PayloadLen = int(totalLen - uint32(headerInfoSize) + Size32 - TTHeaderMetaSize) + return +} + +/** + * +------------------------------------------------------------+ + * | 4Byte | 2Byte | + * +------------------------------------------------------------+ + * | Length | HEADER MAGIC | + * +------------------------------------------------------------+ + */ + +func IsTTHeader(flagBuf []byte) bool { + return binary.BigEndian.Uint32(flagBuf[Size32:])&MagicMask == TTHeaderMagic +} + +func readKVInfo(idx int, buf []byte) (intKVMap map[uint16]string, strKVMap map[string]string, err error) { + for { + var infoID uint8 + infoID, err = util.Bytes2Uint8(buf, idx) + idx++ + if err != nil { + // this is the last field, read until there is no more padding + if err == io.EOF { + break + } else { + return + } + } + switch InfoIDType(infoID) { + case InfoIDPadding: + continue + case InfoIDKeyValue: + if strKVMap == nil { + strKVMap = make(map[string]string) + } + _, err = readStrKVInfo(&idx, buf, strKVMap) + if err != nil { + return + } + case InfoIDIntKeyValue: + if intKVMap == nil { + intKVMap = make(map[uint16]string) + } + _, err = readIntKVInfo(&idx, buf, intKVMap) + if err != nil { + return + } + case InfoIDACLToken: + if strKVMap == nil { + strKVMap = make(map[string]string) + } + if err = readACLToken(&idx, buf, strKVMap); err != nil { + return + } + default: + err = fmt.Errorf("invalid infoIDType[%#x]", infoID) + return + } + } + return +} + +func readIntKVInfo(idx *int, buf []byte, info map[uint16]string) (has bool, err error) { + kvSize, err := util.Bytes2Uint16(buf, *idx) + *idx += 2 + if err != nil { + return false, fmt.Errorf("error reading int kv info size: %s", err.Error()) + } + if kvSize <= 0 { + return false, nil + } + for i := uint16(0); i < kvSize; i++ { + key, err := util.Bytes2Uint16(buf, *idx) + *idx += 2 + if err != nil { + return false, fmt.Errorf("error reading int kv info: %s", err.Error()) + } + val, n, err := util.ReadString2BLen(buf, *idx) + *idx += n + if err != nil { + return false, fmt.Errorf("error reading int kv info: %s", err.Error()) + } + info[key] = val + } + return true, nil +} + +func readStrKVInfo(idx *int, buf []byte, info map[string]string) (has bool, err error) { + kvSize, err := util.Bytes2Uint16(buf, *idx) + *idx += 2 + if err != nil { + return false, fmt.Errorf("error reading str kv info size: %s", err.Error()) + } + if kvSize <= 0 { + return false, nil + } + for i := uint16(0); i < kvSize; i++ { + key, n, err := util.ReadString2BLen(buf, *idx) + *idx += n + if err != nil { + return false, fmt.Errorf("error reading str kv info: %s", err.Error()) + } + val, n, err := util.ReadString2BLen(buf, *idx) + *idx += n + if err != nil { + return false, fmt.Errorf("error reading str kv info: %s", err.Error()) + } + info[key] = val + } + return true, nil +} + +// readACLToken reads acl token +func readACLToken(idx *int, buf []byte, info map[string]string) error { + val, n, err := util.ReadString2BLen(buf, *idx) + *idx += n + if err != nil { + return fmt.Errorf("error reading acl token: %s", err.Error()) + } + info[GDPRToken] = val + return nil +} + +// protoID just for ttheader +func checkProtocolID(protoID uint8) error { + switch protoID { + case uint8(ProtocolIDThriftBinary): + case uint8(ProtocolIDKitexProtobuf): + case uint8(ProtocolIDThriftCompactV2): + // just for compatibility + default: + return fmt.Errorf("unsupported ProtocolID[%d]", protoID) + } + return nil +} diff --git a/protocol/ttheader/encode.go b/protocol/ttheader/encode.go new file mode 100644 index 0000000..874ead3 --- /dev/null +++ b/protocol/ttheader/encode.go @@ -0,0 +1,324 @@ +package ttheader + +import ( + "context" + "encoding/binary" + "fmt" + + "github.com/bytedance/gopkg/cloud/metainfo" + "github.com/cloudwego/gopkg/bufiox" +) + +/** + * TTHeader Protocol + * +-------------2Byte--------------|-------------2Byte-------------+ + * +----------------------------------------------------------------+ + * | 0| LENGTH | + * +----------------------------------------------------------------+ + * | 0| HEADER MAGIC | FLAGS | + * +----------------------------------------------------------------+ + * | SEQUENCE NUMBER | + * +----------------------------------------------------------------+ + * | 0| Header Size(/32) | ... + * +--------------------------------- + * + * Header is of variable size: + * (and starts at offset 14) + * + * +----------------------------------------------------------------+ + * | PROTOCOL ID |NUM TRANSFORMS . |TRANSFORM 0 ID (uint8)| + * +----------------------------------------------------------------+ + * | TRANSFORM 0 DATA ... + * +----------------------------------------------------------------+ + * | ... ... | + * +----------------------------------------------------------------+ + * | INFO 0 ID (uint8) | INFO 0 DATA ... + * +----------------------------------------------------------------+ + * | ... ... | + * +----------------------------------------------------------------+ + * | | + * | PAYLOAD | + * | | + * +----------------------------------------------------------------+ + */ + +// The byte count of 32 and 16 integer values. +const ( + Size32 = 4 + Size16 = 2 +) + +// Header keys +const ( + // Header Magics + // 0 and 16th bits must be 0 to differentiate from framed & unframed + TTHeaderMagic uint32 = 0x10000000 + MeshHeaderMagic uint32 = 0xFFAF0000 + MeshHeaderLenMask uint32 = 0x0000FFFF + + // HeaderMask uint32 = 0xFFFF0000 + FlagsMask uint32 = 0x0000FFFF + MethodMask uint32 = 0x41000000 // method first byte [A-Za-z_] + MaxFrameSize uint32 = 0x3FFFFFFF + MaxHeaderSize uint32 = 65536 + + initialBufferSize = 256 +) + +type HeaderFlags uint16 + +const ( + HeaderFlagsKey string = "HeaderFlags" + HeaderFlagSupportOutOfOrder HeaderFlags = 0x01 + HeaderFlagDuplexReverse HeaderFlags = 0x08 + HeaderFlagSASL HeaderFlags = 0x10 +) + +const ( + TTHeaderMetaSize = 14 +) + +// ProtocolID is the wrapped protocol id used in THeader. +type ProtocolID uint8 + +// Supported ProtocolID values. +const ( + ProtocolIDThriftBinary ProtocolID = 0x00 + ProtocolIDThriftCompact ProtocolID = 0x02 // Kitex not support + ProtocolIDThriftCompactV2 ProtocolID = 0x03 // Kitex not support + ProtocolIDKitexProtobuf ProtocolID = 0x04 + ProtocolIDDefault = ProtocolIDThriftBinary +) + +type InfoIDType uint8 // uint8 + +const ( + InfoIDPadding InfoIDType = 0 + InfoIDKeyValue InfoIDType = 0x01 + InfoIDIntKeyValue InfoIDType = 0x10 + InfoIDACLToken InfoIDType = 0x11 +) + +// key of acl token +// You can set up acl token through metainfo. +// eg: +// +// ctx = metainfo.WithValue(ctx, "gdpr-token", "your token") +const ( + // GDPRToken is used to set up gdpr token into InfoIDACLToken + GDPRToken = metainfo.PrefixTransient + "gdpr-token" +) + +type EncodeParam struct { + Flags HeaderFlags + + SeqID int32 + + ProtocolID ProtocolID + + // IntInfo is used to set up int key-value info into InfoIDIntKeyValue + IntInfo map[uint16]string + + // StrInfo is used to set up string key-value info into InfoIDKeyValue + StrInfo map[string]string +} + +// EncodeToBytes encode ttheader to bytes. +// NOTICE: Must call +// +// `binary.BigEndian.PutUint32(buf, uint32(totalLen))` +// +// after encoding both header and payload data to set total length of a request/response. +// And `totalLen` should be the length of header + payload - 4. +// You may refer to unit tests for examples. +func EncodeToBytes(ctx context.Context, param EncodeParam) (buf []byte, err error) { + out := bufiox.NewBytesWriter(&buf) + if _, err = Encode(ctx, param, out); err != nil { + return + } + if err = out.Flush(); err != nil { + return + } + return +} + +func Encode(ctx context.Context, param EncodeParam, out bufiox.Writer) (totalLenField []byte, err error) { + // 1. header meta + var headerMeta []byte + headerMeta, err = out.Malloc(TTHeaderMetaSize) + if err != nil { + return nil, fmt.Errorf("ttHeader malloc header meta failed, %s", err.Error()) + } + + totalLenField = headerMeta[0:4] + headerInfoSizeField := headerMeta[12:14] + binary.BigEndian.PutUint32(headerMeta[4:8], TTHeaderMagic+uint32(param.Flags)) + binary.BigEndian.PutUint32(headerMeta[8:12], uint32(param.Flags)) + + var transformIDs []uint8 // transformIDs not support TODO compress + // 2. header info, malloc and write + if err = WriteByte(byte(param.ProtocolID), out); err != nil { + return nil, fmt.Errorf("ttHeader write protocol id failed, %s", err.Error()) + } + if err = WriteByte(byte(len(transformIDs)), out); err != nil { + return nil, fmt.Errorf("ttHeader write transformIDs length failed, %s", err.Error()) + } + for tid := range transformIDs { + if err = WriteByte(byte(tid), out); err != nil { + return nil, fmt.Errorf("ttHeader write transformIDs failed, %s", err.Error()) + } + } + // PROTOCOL ID(u8) + NUM TRANSFORMS(always 0)(u8) + TRANSFORM IDs([]u8) + headerInfoSize := 1 + 1 + len(transformIDs) + headerInfoSize, err = writeKVInfo(headerInfoSize, param.IntInfo, param.StrInfo, out) + if err != nil { + return nil, fmt.Errorf("ttHeader write kv info failed, %s", err.Error()) + } + + if uint32(headerInfoSize) > MaxHeaderSize { + return nil, fmt.Errorf("invalid header length[%d]", headerInfoSize) + } + binary.BigEndian.PutUint16(headerInfoSizeField, uint16(headerInfoSize/4)) + return totalLenField, nil +} + +func writeKVInfo(writtenSize int, intKVMap map[uint16]string, strKVMap map[string]string, out bufiox.Writer) (writeSize int, err error) { + writeSize = writtenSize + // str kv info + strKVSize := len(strKVMap) + // write gdpr token into InfoIDACLToken + // supplementary doc: https://www.cloudwego.io/docs/kitex/reference/transport_protocol_ttheader/ + if gdprToken, ok := strKVMap[GDPRToken]; ok { + strKVSize-- + // INFO ID TYPE(u8) + if err = WriteByte(byte(InfoIDACLToken), out); err != nil { + return writeSize, err + } + writeSize += 1 + + wLen, err := WriteString2BLen(gdprToken, out) + if err != nil { + return writeSize, err + } + writeSize += wLen + } + + if strKVSize > 0 { + // INFO ID TYPE(u8) + NUM HEADERS(u16) + if err = WriteByte(byte(InfoIDKeyValue), out); err != nil { + return writeSize, err + } + if err = WriteUint16(uint16(strKVSize), out); err != nil { + return writeSize, err + } + writeSize += 3 + for key, val := range strKVMap { + if key == GDPRToken { + continue + } + keyWLen, err := WriteString2BLen(key, out) + if err != nil { + return writeSize, err + } + valWLen, err := WriteString2BLen(val, out) + if err != nil { + return writeSize, err + } + writeSize = writeSize + keyWLen + valWLen + } + } + + // int kv info + intKVSize := len(intKVMap) + if intKVSize > 0 { + // INFO ID TYPE(u8) + NUM HEADERS(u16) + if err = WriteByte(byte(InfoIDIntKeyValue), out); err != nil { + return writeSize, err + } + if err = WriteUint16(uint16(intKVSize), out); err != nil { + return writeSize, err + } + writeSize += 3 + for key, val := range intKVMap { + if err = WriteUint16(key, out); err != nil { + return writeSize, err + } + valWLen, err := WriteString2BLen(val, out) + if err != nil { + return writeSize, err + } + writeSize = writeSize + 2 + valWLen + } + } + + // padding = (4 - headerInfoSize%4) % 4 + padding := (4 - writeSize%4) % 4 + paddingBuf, err := out.Malloc(padding) + if err != nil { + return writeSize, err + } + for i := 0; i < len(paddingBuf); i++ { + paddingBuf[i] = byte(0) + } + writeSize += padding + return +} + +// WriteByte ... +func WriteByte(val byte, out bufiox.Writer) error { + var buf []byte + var err error + if buf, err = out.Malloc(1); err != nil { + return err + } + buf[0] = val + return nil +} + +// WriteUint32 ... +func WriteUint32(val uint32, out bufiox.Writer) error { + var buf []byte + var err error + if buf, err = out.Malloc(Size32); err != nil { + return err + } + binary.BigEndian.PutUint32(buf, val) + return nil +} + +// WriteUint16 ... +func WriteUint16(val uint16, out bufiox.Writer) error { + var buf []byte + var err error + if buf, err = out.Malloc(Size16); err != nil { + return err + } + binary.BigEndian.PutUint16(buf, val) + return nil +} + +// WriteString ... +func WriteString(val string, out bufiox.Writer) (int, error) { + var buf []byte + var err error + strLen := len(val) + if buf, err = out.Malloc(Size32 + strLen); err != nil { + return 0, err + } + binary.BigEndian.PutUint32(buf, uint32(strLen)) + copy(buf[Size32:], val) + return Size32 + strLen, nil +} + +// WriteString2BLen ... +func WriteString2BLen(val string, out bufiox.Writer) (int, error) { + var buf []byte + var err error + strLen := len(val) + if buf, err = out.Malloc(Size16 + strLen); err != nil { + return 0, err + } + binary.BigEndian.PutUint16(buf, uint16(strLen)) + copy(buf[Size16:], val) + return Size16 + strLen, nil +} diff --git a/protocol/ttheader/encode_test.go b/protocol/ttheader/encode_test.go new file mode 100644 index 0000000..cc1339e --- /dev/null +++ b/protocol/ttheader/encode_test.go @@ -0,0 +1 @@ +package ttheader diff --git a/protocol/util/bytebuffer.go b/protocol/util/bytebuffer.go new file mode 100644 index 0000000..5e9955f --- /dev/null +++ b/protocol/util/bytebuffer.go @@ -0,0 +1,48 @@ +package util + +import ( + "encoding/binary" + "io" +) + +// Bytes2Uint32NoCheck ... +func Bytes2Uint32NoCheck(bytes []byte) uint32 { + return binary.BigEndian.Uint32(bytes) +} + +// Bytes2Uint16NoCheck ... +func Bytes2Uint16NoCheck(bytes []byte) uint16 { + return binary.BigEndian.Uint16(bytes) +} + +// Bytes2Uint8 ... +func Bytes2Uint8(bytes []byte, off int) (uint8, error) { + if len(bytes)-off < 1 { + return 0, io.EOF + } + return bytes[off], nil +} + +// Bytes2Uint16 ... +func Bytes2Uint16(bytes []byte, off int) (uint16, error) { + if len(bytes)-off < 2 { + return 0, io.EOF + } + return binary.BigEndian.Uint16(bytes[off:]), nil +} + +// ReadString2BLen ... +func ReadString2BLen(bytes []byte, off int) (string, int, error) { + length, err := Bytes2Uint16(bytes, off) + strLen := int(length) + if err != nil { + return "", 0, err + } + off += 2 + if len(bytes)-off < strLen { + return "", 0, io.EOF + } + + buf := bytes[off : off+strLen] + return string(buf), int(length) + 2, nil +}