From 2abdd6ae18d0b8d7a44d7650bdf390bfb69da0c1 Mon Sep 17 00:00:00 2001
From: xiezhengyao
Date: Tue, 6 Aug 2024 16:38:27 +0800
Subject: [PATCH] feat: add bufiox interfaces and split ttheader codec code

---
 bufiox/bufreader.go              |  44 ++++
 bufiox/bufwriter.go              |  31 +++
 bufiox/defaultbuf.go             | 355 +++++++++++++++++++++++++++++++
 bufiox/defaultbuf_test.go        | 248 +++++++++++++++++++++
 bytex/buffer.go                  |  43 ++++
 go.mod                           |   2 +
 go.sum                           |   2 +
 protocol/ttheader/decode.go      | 240 +++++++++++++++++++++
 protocol/ttheader/encode.go      | 338 +++++++++++++++++++++++++++++
 protocol/ttheader/encode_test.go |  15 ++
 protocol/util/bytebuffer.go      |  48 +++++
 11 files changed, 1366 insertions(+)
 create mode 100644 bufiox/bufreader.go
 create mode 100644 bufiox/bufwriter.go
 create mode 100644 bufiox/defaultbuf.go
 create mode 100644 bufiox/defaultbuf_test.go
 create mode 100644 bytex/buffer.go
 create mode 100644 protocol/ttheader/decode.go
 create mode 100644 protocol/ttheader/encode.go
 create mode 100644 protocol/ttheader/encode_test.go
 create mode 100644 protocol/util/bytebuffer.go

diff --git a/bufiox/bufreader.go b/bufiox/bufreader.go
new file mode 100644
index 0000000..0afd9c1
--- /dev/null
+++ b/bufiox/bufreader.go
@@ -0,0 +1,44 @@
+// Copyright 2024 CloudWeGo Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bufiox
+
+// Reader is a buffer IO interface, which provides a user-space zero-copy method to reduce memory allocation and copy overhead.
+type Reader interface {
+	// Next consumes the next n bytes and returns them as a slice `p` of length `n`;
+	// if it is unable to read n bytes, it returns an error instead.
+	// The returned `p` may be a shallow copy of the original buffer.
+	// The data in `p` must stay unmodified until Release is called.
+	//
+	// The returned data must not be used once Release has been called.
+	Next(n int) (p []byte, err error)
+
+	// Peek returns the same data as Next would, but does not advance the reader.
+	//
+	// The returned data must not be used once Release has been called.
+	Peek(n int) (buf []byte, err error)
+
+	// Skip discards the next n bytes; it returns an error if it is unable to skip n bytes.
+	Skip(n int) (err error)
+
+	// ReadLen returns the number of bytes that have already been read.
+	// Read/Next/Skip increase it, and calling Release sets it back to 0.
+	ReadLen() (n int)
+
+	// Release frees the buffer. Afterwards, data obtained from Next/Skip/Peek is invalid.
+	// The parameter e is for cases where releasing the buffer depends on an error.
+	// For example, a write buffer is usually released inside flush,
+	// but if the flush fails, the write buffer may have to be released explicitly.
+	Release(e error) (err error)
+}
diff --git a/bufiox/bufwriter.go b/bufiox/bufwriter.go
new file mode 100644
index 0000000..fe57fa6
--- /dev/null
+++ b/bufiox/bufwriter.go
@@ -0,0 +1,31 @@
+// Copyright 2024 CloudWeGo Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bufiox
+
+// Writer is a buffer IO interface, which provides a user-space zero-copy method to reduce memory allocation and copy overhead.
+type Writer interface {
+	// Malloc returns a shallow copy of the write buffer with length n,
+	// otherwise returns an error if it's unable to get n bytes from the write buffer.
+	// Must ensure that the data written by the user to buf can be flushed to the underlying io.Writer.
+	//
+	// Caller cannot write data to the returned buf after calling Flush.
+	Malloc(n int) (buf []byte, err error)
+
+	// WrittenLen returns the total length of the buffer written.
+	WrittenLen() (length int)
+
+	// Flush writes any malloc data to the underlying io.Writer, and resets WrittenLen to zero.
+	Flush() (err error)
+}
diff --git a/bufiox/defaultbuf.go b/bufiox/defaultbuf.go
new file mode 100644
index 0000000..a6ccab6
--- /dev/null
+++ b/bufiox/defaultbuf.go
@@ -0,0 +1,419 @@
+// Copyright 2024 CloudWeGo Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bufiox
+
+import (
+	"errors"
+	"io"
+
+	"github.com/bytedance/gopkg/lang/dirtmake"
+	"github.com/bytedance/gopkg/lang/mcache"
+)
+
+const maxConsecutiveEmptyReads = 100
+
+var _ Reader = &DefaultReader{}
+
+// DefaultReader implements Reader by buffering data read from an io.Reader.
+type DefaultReader struct {
+	buf []byte    // buf[ri:] is the buffer for reading.
+	rd  io.Reader // reader provided by the client
+	ri  int       // buf read positions
+	err error
+
+	maxSizeStats    [statsBucketNum]int
+	maxSizeStatsIdx int
+
+	disableCache bool
+}
+
+const (
+	statsBucketNum = 10
+	defaultBufSize = 4096
+)
+
+var errNegativeCount = errors.New("bufiox: negative count")
+
+// NewDefaultReader returns a new [Reader] whose buffer has the default size.
+func NewDefaultReader(rd io.Reader) *DefaultReader {
+	r := &DefaultReader{}
+	r.reset(rd, nil, false)
+	return r
+}
+
+// NewBytesReader returns a reader over buf. Once buf is consumed, reads fail with io.EOF.
+func NewBytesReader(buf []byte) *BytesReader {
+	r := &BytesReader{}
+	r.reset(r.fakedIOReader, buf, true)
+	return r
+}
+
+// BytesReader is a DefaultReader over a byte slice; its underlying reader always returns io.EOF.
+type BytesReader struct {
+	DefaultReader
+	fakedIOReader fakeIOReader
+}
+
+// reset reinitializes r with the given reader and buffer, clearing all other state.
+func (r *DefaultReader) reset(rd io.Reader, buf []byte, disableCache bool) {
+	*r = DefaultReader{buf: buf, rd: rd, disableCache: disableCache}
+}
+
+// acquireSlow tries to make n unread bytes available, allocating or growing
+// the buffer and reading from rd as needed. It returns how many unread bytes
+// the caller may consume: n on success, or fewer than n when the data could
+// not be read, in which case r.err holds the reason.
+func (r *DefaultReader) acquireSlow(n int) int {
+	if r.err != nil {
+		if r.buf == nil {
+			return 0
+		}
+		return len(r.buf) - r.ri
+	}
+
+	if r.buf == nil {
+		maxSize := defaultBufSize
+		// start from the largest capacity recorded by recent releases
+		for _, size := range r.maxSizeStats {
+			if maxSize < size {
+				maxSize = size
+			}
+		}
+		for ; maxSize < n; maxSize *= 2 {
+		}
+		if r.disableCache {
+			r.buf = dirtmake.Bytes(0, maxSize)
+		} else {
+			r.buf = mcache.Malloc(0, maxSize)
+		}
+	}
+
+	if n > cap(r.buf)-r.ri {
+		// grow buffer
+		ncap := cap(r.buf) * 2
+		if ncap == 0 {
+			// an empty non-nil buffer has no capacity to double
+			ncap = defaultBufSize
+		}
+		for ; ncap-r.ri < n; ncap *= 2 {
+		}
+		var nbuf []byte
+		if r.disableCache {
+			nbuf = dirtmake.Bytes(ncap, ncap)
+		} else {
+			nbuf = mcache.Malloc(ncap)
+		}
+		r.buf = nbuf[:copy(nbuf, r.buf)]
+	}
+
+	// read until n bytes are buffered, the reader fails, or it returns no
+	// data maxConsecutiveEmptyReads times in a row
+	for empty := 0; len(r.buf)-r.ri < n; {
+		m, err := r.rd.Read(r.buf[len(r.buf):cap(r.buf)])
+		r.buf = r.buf[:len(r.buf)+m]
+		if err != nil {
+			r.err = err
+			break
+		}
+		if m > 0 {
+			empty = 0
+			continue
+		}
+		if empty++; empty >= maxConsecutiveEmptyReads {
+			r.err = io.ErrNoProgress
+			break
+		}
+	}
+	if avail := len(r.buf) - r.ri; avail < n {
+		return avail
+	}
+	return n
+}
+
+// acquire returns how many of the next n bytes can be consumed, filling the
+// buffer through acquireSlow when it does not already hold n unread bytes.
+func (r *DefaultReader) acquire(n int) int {
+	// fast path, for inline
+	if n <= len(r.buf)-r.ri {
+		return n
+	}
+	return r.acquireSlow(n)
+}
+
+// Next returns the next n bytes without copying and advances the read position.
+// If n bytes cannot be read, it returns the reader's error and consumes nothing.
+func (r *DefaultReader) Next(n int) (buf []byte, err error) {
+	if n < 0 {
+		err = errNegativeCount
+		return
+	}
+	m := r.acquire(n)
+	if n > m {
+		err = r.err
+		return
+	}
+	// nocopy read
+	buf = r.buf[r.ri : r.ri+n]
+	r.ri += n
+	return
+}
+
+// Peek returns the next n bytes like Next, but does not advance the read position.
+func (r *DefaultReader) Peek(n int) (buf []byte, err error) {
+	if n < 0 {
+		err = errNegativeCount
+		return
+	}
+	m := r.acquire(n)
+	if n > m {
+		err = r.err
+		return
+	}
+	// nocopy read
+	buf = r.buf[r.ri : r.ri+n]
+	return
+}
+
+// Skip advances the read position by n bytes, or returns an error if n bytes cannot be read.
+func (r *DefaultReader) Skip(n int) (err error) {
+	if n < 0 {
+		err = errNegativeCount
+		return
+	}
+	m := r.acquire(n)
+	if n > m {
+		err = r.err
+		return
+	}
+	r.ri += n
+	return
+}
+
+// ReadLen returns the number of bytes consumed since the last Release.
+func (r *DefaultReader) ReadLen() (n int) {
+	return r.ri
+}
+
+// Read copies up to len(p) bytes into p and advances the read position.
+// The reader's error is returned only when fewer than len(p) bytes were copied.
+func (r *DefaultReader) Read(p []byte) (m int, err error) {
+	m = r.acquire(len(p))
+	copy(p, r.buf[r.ri:r.ri+m])
+	r.ri += m
+	if len(p) > m {
+		err = r.err
+	}
+	return
+}
+
+// Release ends the current read cycle and resets ReadLen to zero; slices
+// returned by Next and Peek must not be used afterwards. A fully consumed
+// buffer is dropped (and handed back to mcache when caching is enabled),
+// otherwise the unread bytes are kept for the next cycle.
+// The e argument is not used and the returned error is always nil.
+func (r *DefaultReader) Release(e error) error {
+	if len(r.buf)-r.ri == 0 {
+		// release buf
+		r.maxSizeStats[r.maxSizeStatsIdx] = cap(r.buf)
+		r.maxSizeStatsIdx = (r.maxSizeStatsIdx + 1) % statsBucketNum
+		if !r.disableCache && r.buf != nil {
+			mcache.Free(r.buf)
+		}
+		r.buf = nil
+	} else if r.disableCache {
+		// the buffer may be the caller's own slice (BytesReader), so drop the
+		// consumed prefix by reslicing instead of moving data inside it
+		r.buf = r.buf[r.ri:]
+	} else {
+		n := copy(r.buf, r.buf[r.ri:])
+		r.buf = r.buf[:n]
+	}
+	r.ri = 0
+	return nil
+}
+
+// fakeIOReader is an io.Reader that never provides data and always returns io.EOF.
+type fakeIOReader struct{}
+
+func (fakeIOReader) Read(p []byte) (n int, err error) {
+	return 0, io.EOF
+}
+
+var _ Writer = &DefaultWriter{}
+
+// DefaultWriter implements Writer by collecting data in memory until Flush writes it to an io.Writer.
+type DefaultWriter struct {
+	buf    []byte
+	oldBuf [][]byte
+	wd     io.Writer
+	err    error
+
+	maxSizeStats    [statsBucketNum]int
+	maxSizeStatsIdx int
+
+	disableCache bool
+}
+
+// NewDefaultWriter returns a new DefaultWriter that flushes to wd.
+func NewDefaultWriter(wd io.Writer) *DefaultWriter {
+	w := &DefaultWriter{}
+	w.reset(wd, nil, false)
+	return w
+}
+
+// NewBytesWriter returns a writer whose Flush stores the written bytes in *buf
+// instead of sending them to an io.Writer.
+func NewBytesWriter(buf *[]byte) *BytesWriter {
+	w := &BytesWriter{}
+	w.fakedIOWriter.bw = w
+	w.flushBytes = buf
+	w.reset(&w.fakedIOWriter, nil, true)
+	return w
+}
+
+// BytesWriter is a DefaultWriter that delivers the flushed data through a *[]byte.
+type BytesWriter struct {
+	DefaultWriter
+	fakedIOWriter fakeIOWriter
+	flushBytes    *[]byte
+}
+
+// reset reinitializes w with the given writer and buffer, clearing all other state.
+func (w *DefaultWriter) reset(wd io.Writer, buf []byte, disableCache bool) {
+	*w = DefaultWriter{buf: buf, wd: wd, disableCache: disableCache}
+}
+
+// acquire makes sure the buffer has room for n more bytes.
+func (w *DefaultWriter) acquire(n int) {
+	if len(w.buf)+n <= cap(w.buf) {
+		return
+	}
+	w.acquireSlow(n)
+}
+
+// acquireSlow allocates the buffer, or replaces it with a larger one, so that
+// n more bytes fit. A replaced buffer is kept in oldBuf because slices handed
+// out by Malloc still point into it; Flush gathers its data later.
+func (w *DefaultWriter) acquireSlow(n int) {
+	if w.buf == nil {
+		maxSize := defaultBufSize
+		for _, size := range w.maxSizeStats {
+			if maxSize < size {
+				maxSize = size
+			}
+		}
+		for ; maxSize < n; maxSize *= 2 {
+		}
+		if w.disableCache {
+			w.buf = dirtmake.Bytes(0, maxSize)
+		} else {
+			w.buf = mcache.Malloc(0, maxSize)
+		}
+	}
+
+	if n > cap(w.buf)-len(w.buf) {
+		// grow buffer
+		var ncap int
+		for ncap = cap(w.buf) * 2; ncap-len(w.buf) < n; ncap *= 2 {
+		}
+		var nbuf []byte
+		if w.disableCache {
+			nbuf = dirtmake.Bytes(ncap, ncap)
+		} else {
+			nbuf = mcache.Malloc(ncap)
+		}
+		w.oldBuf = append(w.oldBuf, w.buf)
+		w.buf = nbuf[:len(w.buf)]
+	}
+}
+
+// Malloc reserves the next n bytes of the write buffer and returns them for
+// the caller to fill in. The returned slice must not be written to after Flush.
+func (w *DefaultWriter) Malloc(n int) (buf []byte, err error) {
+	if w.err != nil {
+		err = w.err
+		return
+	}
+	if n < 0 {
+		err = errNegativeCount
+		return
+	}
+	w.acquire(n)
+	buf = w.buf[len(w.buf) : len(w.buf)+n]
+	w.buf = w.buf[:len(w.buf)+n]
+	return
+}
+
+// Write copies p into the write buffer and returns len(p).
+// It fails only with the error of an earlier failed Flush.
+func (w *DefaultWriter) Write(p []byte) (n int, err error) {
+	if w.err != nil {
+		return 0, w.err
+	}
+	w.acquire(len(p))
+	// the destination must span the reserved space: w.buf[len(w.buf):] alone
+	// is a zero-length slice and copy would write nothing into it
+	n = copy(w.buf[len(w.buf):len(w.buf)+len(p)], p)
+	w.buf = w.buf[:len(w.buf)+n]
+	return n, nil
+}
+
+// WrittenLen returns the number of bytes written since the last successful Flush.
+func (w *DefaultWriter) WrittenLen() int {
+	return len(w.buf)
+}
+
+// Flush writes all buffered data to the underlying io.Writer and drops the
+// buffers. If the write fails, the error is remembered and returned by every
+// later Malloc, Write and Flush.
+func (w *DefaultWriter) Flush() (err error) {
+	if w.err != nil {
+		err = w.err
+		return
+	}
+	if w.buf == nil {
+		return nil
+	}
+	// gather the data written into older, smaller buffers into the final one
+	var offset int
+	for _, oldBuf := range w.oldBuf {
+		offset += copy(w.buf[offset:], oldBuf[offset:])
+	}
+	if _, err = w.wd.Write(w.buf); err != nil {
+		w.err = err
+		return err
+	}
+	w.maxSizeStats[w.maxSizeStatsIdx] = cap(w.buf)
+	w.maxSizeStatsIdx = (w.maxSizeStatsIdx + 1) % statsBucketNum
+	if !w.disableCache {
+		mcache.Free(w.buf)
+		for _, oldBuf := range w.oldBuf {
+			mcache.Free(oldBuf)
+		}
+	}
+	w.buf = nil
+	w.oldBuf = nil
+	return nil
+}
+
+// fakeIOWriter is the io.Writer behind a BytesWriter: it publishes the flushed slice through flushBytes.
+type fakeIOWriter struct {
+	bw *BytesWriter
+}
+
+func (w *fakeIOWriter) Write(p []byte) (n int, err error) {
+	*w.bw.flushBytes = p
+	return len(p), nil
+}
diff --git a/bufiox/defaultbuf_test.go b/bufiox/defaultbuf_test.go
new file mode 100644
index 0000000..f6c1871
--- /dev/null
+++ b/bufiox/defaultbuf_test.go
@@ -0,0 +1,248 @@
+// Copyright 2024 CloudWeGo Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bufiox
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"testing"
+)
+
+type mockReader struct {
+	dataSize int
+}
+
+func (r *mockReader) Read(p []byte) (n int, err error) {
+	n = r.dataSize
+	if n > len(p) {
+		n = len(p)
+	}
+	if n == 0 {
+		return 0, io.EOF
+	}
+	for i := range p[:n] {
+		p[i] = byte(0xff)
+	}
+	r.dataSize -= n
+	return
+}
+
+func TestDefaultReader(t *testing.T) {
+	tcases := []struct {
+		dataSize int
+		handle   func(reader Reader)
+	}{
+		{
+			dataSize: 1024,
+			handle: func(reader Reader) {
+				buf, err := reader.Next(1024)
+				if err != nil {
+					t.Fatal(err)
+				}
+				for _, b := range buf {
+					if b != 0xff {
+						t.Fatal("data not equal")
+					}
+				}
+			},
+		},
+		{
+			dataSize: 1024,
+			handle: func(reader Reader) {
+				buf, err := reader.Next(1025)
+				if err != io.EOF {
+					t.Fatal("err is not io.EOF", err)
+				}
+				if buf != nil {
+					t.Fatal("buf is not nil")
+				}
+			},
+		},
+		{
+			dataSize: 1024 * 16,
+			handle: func(reader Reader) {
+				buf, err := reader.Next(1024)
+				if err != nil {
+					t.Fatal(err)
+				}
+				for _, b := range buf {
+					if b != 0xff {
+						t.Fatal("data not equal")
+					}
+				}
+				if reader.ReadLen() != 1024 {
+					t.Fatal("read len is not 1024")
+				}
+				buf, err = reader.Next(1024 * 14)
+				if err != nil {
+					t.Fatal(err)
+				}
+				for _, b := range buf {
+					if b != 0xff {
+						t.Fatal("data not equal")
+					}
+				}
+				if reader.ReadLen() != 1024*15 {
+					t.Fatal("read len is not 1024*15")
+				}
+				err = reader.Release(nil)
+				if err != nil {
+					t.Fatal(err)
+				}
+				if reader.ReadLen() != 0 {
+					t.Fatal("read len is not 0")
+				}
+				buf, err = reader.Peek(1024)
+				if err != nil {
+					t.Fatal(err)
+				}
+				for _, b := range buf {
+					if b != 0xff {
+						t.Fatal("data not equal")
+					}
+				}
+				if reader.ReadLen() != 0 {
+					t.Fatal("read len is not 0")
+				}
+				err = reader.Skip(1024)
+				if err != nil {
+					t.Fatal(err)
+				}
+				if reader.ReadLen() != 1024 {
+					t.Fatal("read len is not 1024")
+				}
+				err = reader.Release(nil)
+				if err != nil {
+					t.Fatal(err)
+				}
+				switch r := reader.(type) {
+				case *DefaultReader:
+					if r.buf != nil {
+						t.Fatal("buf is not nil")
+					}
+				case *BytesReader:
+					if r.buf != nil {
+						t.Fatal("buf is not nil")
+					}
+				}
+				buf, err = reader.Next(1)
+				if err != io.EOF {
+					t.Fatal("err is not io.EOF", err)
+				}
+				buf, err = reader.Peek(1)
+				if err != io.EOF {
+					t.Fatal("err is not io.EOF", err)
+				}
+				err = reader.Skip(1)
+				if err != io.EOF {
+					t.Fatal("err is not io.EOF", err)
+				}
+			},
+		},
+	}
+	for _, tcase := range tcases {
+		r := NewDefaultReader(&mockReader{dataSize: tcase.dataSize})
+		tcase.handle(r)
+
+		buf := make([]byte, tcase.dataSize)
+		for i := range buf {
+			buf[i] = 0xff
+		}
+		br := NewBytesReader(buf)
+		tcase.handle(br)
+	}
+}
+
+type mockWriter struct {
+	dataSize int
+}
+
+func (w *mockWriter) Write(p []byte) (n int, err error) {
+	if w.dataSize != len(p) {
+		return 0, fmt.Errorf("length is not %d", w.dataSize)
+	}
+	for _, b := range p {
+		if b != 0xff {
+			return 0, errors.New("data not equal")
+		}
+	}
+	return len(p), nil
+}
+
+func setBytes(b []byte, v byte) {
+	for i := range b {
+		b[i] = v
+	}
+}
+
+func TestDefaultWriter(t *testing.T) {
+	tcases := []struct {
+		dataSize int
+		handle   func(writer Writer)
+	}{
+		{
+			dataSize: 1024 * 18,
+			handle: func(writer Writer) {
+				buf, err := writer.Malloc(1024)
+				if err != nil {
+					t.Fatal(err)
+				}
+				if writer.WrittenLen() != 1024 {
+					t.Fatal("written len is not 1024")
+				}
+				buf1, err := writer.Malloc(1024)
+				if err != nil {
+					t.Fatal(err)
+				}
+				if writer.WrittenLen() != 1024*2 {
+					t.Fatal("written len is not 1024*2")
+				}
+				buf2, err := writer.Malloc(1024 * 4)
+				if err != nil {
+					t.Fatal(err)
+				}
+				if writer.WrittenLen() != 1024*6 {
+					t.Fatal("written len is not 1024*6")
+				}
+				buf3, err := writer.Malloc(1024 * 12)
+				if err != nil {
+					t.Fatal(err)
+				}
+				if writer.WrittenLen() != 1024*18 {
+					t.Fatal("written len is not 1024*18")
+				}
+				setBytes(buf3, 0xff)
+				setBytes(buf2, 0xff)
+				setBytes(buf1, 0xff)
+				setBytes(buf, 0xff)
+				if err = writer.Flush(); err != nil {
+					t.Fatal(err)
+				}
+			},
+		},
+	}
+	for _, tcase := range tcases {
+		w := NewDefaultWriter(&mockWriter{dataSize: tcase.dataSize})
+		tcase.handle(w)
+
+		var buf []byte
+		bw := NewBytesWriter(&buf)
+		tcase.handle(bw)
+		if len(buf) != tcase.dataSize {
+			t.Fatal("write data size is not equal!")
+		}
+	}
+}
diff --git a/bytex/buffer.go b/bytex/buffer.go
new file mode 100644
index 0000000..3051f34
--- /dev/null
+++ b/bytex/buffer.go
@@ -0,0 +1,100 @@
+package bytex
+
+import "io"
+
+// BinaryInplaceThreshold is the size from which Write keeps a reference to
+// the written slice instead of copying it.
+const BinaryInplaceThreshold = 1024
+
+// chunkSize is the capacity of the chunks allocated for copied writes.
+const chunkSize = 4096
+
+// A Buffer is a variable-sized buffer of bytes with [Buffer.Read] and [Buffer.Write] methods.
+// The zero value for Buffer is an empty buffer ready to use.
+type Buffer struct {
+	// buf is the list of data chunks. When it is not empty, its last element
+	// is owned by the Buffer and is the one copied writes are appended to.
+	buf [][]byte
+	// off1 is the index of the chunk being read, off2 the read offset in it.
+	off1, off2 int
+
+	releaseBuf []byte
+}
+
+// Write appends p to the buffer and always returns len(p) and a nil error.
+//
+// If len(p) >= BinaryInplaceThreshold, p is not copied: the Buffer refers to
+// it directly, so the caller must leave p unmodified until it has been read.
+func (b *Buffer) Write(p []byte) (n int, err error) {
+	if len(b.buf) == 0 {
+		// zero value: start with an empty chunk so that a last chunk always exists
+		b.buf = append(b.buf, nil)
+	}
+	last := len(b.buf) - 1
+	buf := b.buf[last]
+	// fast path
+	if len(p) >= BinaryInplaceThreshold {
+		// reference p, then let the unused tail of buf take further copied writes
+		b.buf = append(b.buf, p, buf[len(buf):len(buf):cap(buf)])
+		return len(p), nil
+	}
+	if cap(buf)-len(buf) >= len(p) {
+		b.buf[last] = append(buf, p...)
+		return len(p), nil
+	}
+
+	// slow path
+	b.slowWrite(p)
+	return len(p), nil
+}
+
+// slowWrite copies p into a newly allocated chunk, which becomes the last chunk.
+func (b *Buffer) slowWrite(p []byte) {
+	size := chunkSize
+	if size < len(p) {
+		size = len(p)
+	}
+	b.buf = append(b.buf, append(make([]byte, 0, size), p...))
+}
+
+// Read copies buffered data into p, advancing the read position, and returns
+// the number of bytes copied. It returns io.EOF when p is not empty and no
+// data is left to read.
+func (b *Buffer) Read(p []byte) (n int, err error) {
+	// fast path
+	if b.off1 < len(b.buf) {
+		buf := b.buf[b.off1]
+		if len(p) < len(buf)-b.off2 {
+			n = copy(p, buf[b.off2:])
+			b.off2 += n
+			return
+		}
+	}
+
+	// slow path
+	return b.slowRead(p)
+}
+
+// slowRead fills p from as many chunks as needed, moving on to the next
+// chunk whenever the current one is exhausted.
+func (b *Buffer) slowRead(p []byte) (n int, err error) {
+	for n < len(p) && b.off1 < len(b.buf) {
+		buf := b.buf[b.off1]
+		if b.off2 < len(buf) {
+			m := copy(p[n:], buf[b.off2:])
+			n += m
+			b.off2 += m
+			continue
+		}
+		if b.off1 == len(b.buf)-1 {
+			// stay on the last chunk: later writes may still append to it
+			break
+		}
+		b.off1++
+		b.off2 = 0
+	}
+	if n == 0 && len(p) > 0 {
+		return 0, io.EOF
+	}
+	return n, nil
+}
diff --git a/go.mod b/go.mod
index db9c117..a7121fa 100644
--- a/go.mod
+++ b/go.mod
@@ -11,6 +11,8 @@ require (
 	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
 	github.com/kr/pretty v0.1.0 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
+	golang.org/x/net v0.0.0-20221014081412-f15817d10f9b // indirect
+	golang.org/x/text v0.3.7 // indirect
 	gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/go.sum b/go.sum
index c69d798..8796672 100644
--- a/go.sum
+++ b/go.sum
@@ -21,11 +21,13 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+golang.org/x/net v0.0.0-20221014081412-f15817d10f9b h1:tvrvnPFcdzp294diPnrdZZZ8XUt2Tyj7svb7X52iDuU=
 golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
 golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/protocol/ttheader/decode.go b/protocol/ttheader/decode.go
new file mode 100644
index 0000000..954185a
--- /dev/null
+++ b/protocol/ttheader/decode.go
@@ -0,0 +1,260 @@
+// Copyright 2024 CloudWeGo Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ttheader
+
+import (
+	"context"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+
+	"github.com/cloudwego/gopkg/bufiox"
+	"github.com/cloudwego/gopkg/protocol/util"
+)
+
+const (
+	// MagicMask is bit mask for checking version.
+	MagicMask = 0xffff0000
+)
+
+// DecodeParam holds the header fields extracted by Decode.
+type DecodeParam struct {
+	Flags HeaderFlags
+
+	SeqID int32
+
+	ProtocolID ProtocolID
+
+	// IntInfo holds the int key-value info read from InfoIDIntKeyValue
+	IntInfo map[uint16]string
+
+	// StrInfo holds the string key-value info read from InfoIDKeyValue (and the acl token, if any)
+	StrInfo map[string]string
+
+	// PayloadLen is the number of payload bytes that follow the header
+	PayloadLen int
+}
+
+// DecodeFromBytes decodes the TTHeader found at the beginning of bs.
+// bs is only read; the payload that follows the header is left untouched.
+func DecodeFromBytes(ctx context.Context, bs []byte) (param DecodeParam, err error) {
+	// Release is deliberately not called: the reader wraps bs itself and has no
+	// pooled memory to return, and bs must stay intact so the caller can read the payload.
+	return Decode(ctx, bufiox.NewBytesReader(bs))
+}
+
+// Decode reads a TTHeader from in and returns its fields.
+// It consumes exactly the header (meta and info sections); the PayloadLen
+// bytes of payload that follow are left in the reader for the caller.
+func Decode(ctx context.Context, in bufiox.Reader) (param DecodeParam, err error) {
+	var headerMeta []byte
+	headerMeta, err = in.Next(TTHeaderMetaSize)
+	if err != nil {
+		return
+	}
+	if !IsTTHeader(headerMeta) {
+		err = errors.New("not TTHeader protocol")
+		return
+	}
+	totalLen := util.Bytes2Uint32NoCheck(headerMeta[:Size32])
+
+	flags := util.Bytes2Uint16NoCheck(headerMeta[Size16*3:])
+	param.Flags = HeaderFlags(flags)
+
+	seqID := util.Bytes2Uint32NoCheck(headerMeta[Size32*2 : Size32*3])
+	param.SeqID = int32(seqID)
+
+	// the 16-bit size field counts 4-byte words; widen it before multiplying so it cannot wrap
+	headerInfoSize := int(util.Bytes2Uint16NoCheck(headerMeta[Size32*3:TTHeaderMetaSize])) * 4
+	if uint32(headerInfoSize) > MaxHeaderSize || headerInfoSize < 2 {
+		err = fmt.Errorf("invalid header length[%d]", headerInfoSize)
+		return
+	}
+	// totalLen counts every byte after the length field itself
+	minTotalLen := uint32(headerInfoSize) + TTHeaderMetaSize - Size32
+	if totalLen < minTotalLen {
+		err = fmt.Errorf("invalid total length[%d], less than header length[%d]", totalLen, minTotalLen)
+		return
+	}
+
+	var headerInfo []byte
+	if headerInfo, err = in.Next(headerInfoSize); err != nil {
+		return
+	}
+	if err = checkProtocolID(headerInfo[0]); err != nil {
+		return
+	}
+	param.ProtocolID = ProtocolID(headerInfo[0])
+	hdIdx := 2
+	transformIDNum := int(headerInfo[1])
+	if headerInfoSize-hdIdx < transformIDNum {
+		err = fmt.Errorf("need read %d transformIDs, but not enough", transformIDNum)
+		return
+	}
+	// transforms are not supported, so their IDs are skipped without being interpreted
+	hdIdx += transformIDNum
+
+	param.IntInfo, param.StrInfo, err = readKVInfo(hdIdx, headerInfo)
+	if err != nil {
+		err = fmt.Errorf("ttHeader read kv info failed, %s, headerInfo=%#x", err.Error(), headerInfo)
+		return
+	}
+
+	param.PayloadLen = int(totalLen - minTotalLen)
+	return
+}
+
+/**
+ * +------------------------------------------------------------+
+ * |                  4Byte                 |       2Byte       |
+ * +------------------------------------------------------------+
+ * |                  Length                |   HEADER MAGIC    |
+ * +------------------------------------------------------------+
+ */
+
+// IsTTHeader reports whether flagBuf, the first bytes of a frame, carries the
+// TTHeader magic after the 4-byte length field. It needs at least 8 bytes.
+func IsTTHeader(flagBuf []byte) bool {
+	return len(flagBuf) >= Size32*2 && binary.BigEndian.Uint32(flagBuf[Size32:])&MagicMask == TTHeaderMagic
+}
+
+// readKVInfo parses the info sections of buf starting at idx and returns the
+// int and string key-value pairs found. It stops successfully at the end of
+// buf; padding bytes are skipped.
+func readKVInfo(idx int, buf []byte) (intKVMap map[uint16]string, strKVMap map[string]string, err error) {
+	for {
+		var infoID uint8
+		infoID, err = util.Bytes2Uint8(buf, idx)
+		idx++
+		if err != nil {
+			// this is the last field, read until there is no more padding
+			if err == io.EOF {
+				// reaching the end of the header info is the normal way out, not a failure
+				err = nil
+			}
+			return
+		}
+		switch InfoIDType(infoID) {
+		case InfoIDPadding:
+			continue
+		case InfoIDKeyValue:
+			if strKVMap == nil {
+				strKVMap = make(map[string]string)
+			}
+			_, err = readStrKVInfo(&idx, buf, strKVMap)
+			if err != nil {
+				return
+			}
+		case InfoIDIntKeyValue:
+			if intKVMap == nil {
+				intKVMap = make(map[uint16]string)
+			}
+			_, err = readIntKVInfo(&idx, buf, intKVMap)
+			if err != nil {
+				return
+			}
+		case InfoIDACLToken:
+			if strKVMap == nil {
+				strKVMap = make(map[string]string)
+			}
+			if err = readACLToken(&idx, buf, strKVMap); err != nil {
+				return
+			}
+		default:
+			err = fmt.Errorf("invalid infoIDType[%#x]", infoID)
+			return
+		}
+	}
+}
+
+// readIntKVInfo reads a counted list of (uint16 key, string value) pairs from buf at *idx
+// into info, advancing *idx. has reports whether the list contained any pair.
+func readIntKVInfo(idx *int, buf []byte, info map[uint16]string) (has bool, err error) {
+	kvSize, err := util.Bytes2Uint16(buf, *idx)
+	*idx += 2
+	if err != nil {
+		return false, fmt.Errorf("error reading int kv info size: %s", err.Error())
+	}
+	if kvSize == 0 {
+		return false, nil
+	}
+	for i := uint16(0); i < kvSize; i++ {
+		key, err := util.Bytes2Uint16(buf, *idx)
+		*idx += 2
+		if err != nil {
+			return false, fmt.Errorf("error reading int kv info: %s", err.Error())
+		}
+		val, n, err := util.ReadString2BLen(buf, *idx)
+		*idx += n
+		if err != nil {
+			return false, fmt.Errorf("error reading int kv info: %s", err.Error())
+		}
+		info[key] = val
+	}
+	return true, nil
+}
+
+// readStrKVInfo reads a counted list of (string key, string value) pairs from buf at *idx
+// into info, advancing *idx. has reports whether the list contained any pair.
+func readStrKVInfo(idx *int, buf []byte, info map[string]string) (has bool, err error) {
+	kvSize, err := util.Bytes2Uint16(buf, *idx)
+	*idx += 2
+	if err != nil {
+		return false, fmt.Errorf("error reading str kv info size: %s", err.Error())
+	}
+	if kvSize == 0 {
+		return false, nil
+	}
+	for i := uint16(0); i < kvSize; i++ {
+		key, n, err := util.ReadString2BLen(buf, *idx)
+		*idx += n
+		if err != nil {
+			return false, fmt.Errorf("error reading str kv info: %s", err.Error())
+		}
+		val, n, err := util.ReadString2BLen(buf, *idx)
+		*idx += n
+		if err != nil {
+			return false, fmt.Errorf("error reading str kv info: %s", err.Error())
+		}
+		info[key] = val
+	}
+	return true, nil
+}
+
+// readACLToken reads the acl token string from buf at *idx, advancing *idx,
+// and stores it in info under the GDPRToken key.
+func readACLToken(idx *int, buf []byte, info map[string]string) error {
+	val, n, err := util.ReadString2BLen(buf, *idx)
+	*idx += n
+	if err != nil {
+		return fmt.Errorf("error reading acl token: %s", err.Error())
+	}
+	info[GDPRToken] = val
+	return nil
+}
+
+// checkProtocolID returns an error unless protoID is a protocol id accepted
+// in a ttheader frame.
+func checkProtocolID(protoID uint8) error {
+	switch ProtocolID(protoID) {
+	case ProtocolIDThriftBinary, ProtocolIDKitexProtobuf:
+		return nil
+	case ProtocolIDThriftCompactV2:
+		// just for compatibility
+		return nil
+	}
+	return fmt.Errorf("unsupported ProtocolID[%d]", protoID)
+}
diff --git a/protocol/ttheader/encode.go b/protocol/ttheader/encode.go
new file mode 100644
index 0000000..18964cb
--- /dev/null
+++ b/protocol/ttheader/encode.go
@@
-0,0 +1,338 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ttheader + +import ( + "context" + "encoding/binary" + "fmt" + + "github.com/bytedance/gopkg/cloud/metainfo" + "github.com/cloudwego/gopkg/bufiox" +) + +/** + * TTHeader Protocol + * +-------------2Byte--------------|-------------2Byte-------------+ + * +----------------------------------------------------------------+ + * | 0| LENGTH | + * +----------------------------------------------------------------+ + * | 0| HEADER MAGIC | FLAGS | + * +----------------------------------------------------------------+ + * | SEQUENCE NUMBER | + * +----------------------------------------------------------------+ + * | 0| Header Size(/32) | ... + * +--------------------------------- + * + * Header is of variable size: + * (and starts at offset 14) + * + * +----------------------------------------------------------------+ + * | PROTOCOL ID |NUM TRANSFORMS . |TRANSFORM 0 ID (uint8)| + * +----------------------------------------------------------------+ + * | TRANSFORM 0 DATA ... + * +----------------------------------------------------------------+ + * | ... ... | + * +----------------------------------------------------------------+ + * | INFO 0 ID (uint8) | INFO 0 DATA ... + * +----------------------------------------------------------------+ + * | ... ... 
| + * +----------------------------------------------------------------+ + * | | + * | PAYLOAD | + * | | + * +----------------------------------------------------------------+ + */ + +// The byte count of 32 and 16 integer values. +const ( + Size32 = 4 + Size16 = 2 +) + +// Header keys +const ( + // Header Magics + // 0 and 16th bits must be 0 to differentiate from framed & unframed + TTHeaderMagic uint32 = 0x10000000 + MeshHeaderMagic uint32 = 0xFFAF0000 + MeshHeaderLenMask uint32 = 0x0000FFFF + + // HeaderMask uint32 = 0xFFFF0000 + FlagsMask uint32 = 0x0000FFFF + MethodMask uint32 = 0x41000000 // method first byte [A-Za-z_] + MaxFrameSize uint32 = 0x3FFFFFFF + MaxHeaderSize uint32 = 65536 + + initialBufferSize = 256 +) + +type HeaderFlags uint16 + +const ( + HeaderFlagsKey string = "HeaderFlags" + HeaderFlagSupportOutOfOrder HeaderFlags = 0x01 + HeaderFlagDuplexReverse HeaderFlags = 0x08 + HeaderFlagSASL HeaderFlags = 0x10 +) + +const ( + TTHeaderMetaSize = 14 +) + +// ProtocolID is the wrapped protocol id used in THeader. +type ProtocolID uint8 + +// Supported ProtocolID values. +const ( + ProtocolIDThriftBinary ProtocolID = 0x00 + ProtocolIDThriftCompact ProtocolID = 0x02 // Kitex not support + ProtocolIDThriftCompactV2 ProtocolID = 0x03 // Kitex not support + ProtocolIDKitexProtobuf ProtocolID = 0x04 + ProtocolIDDefault = ProtocolIDThriftBinary +) + +type InfoIDType uint8 // uint8 + +const ( + InfoIDPadding InfoIDType = 0 + InfoIDKeyValue InfoIDType = 0x01 + InfoIDIntKeyValue InfoIDType = 0x10 + InfoIDACLToken InfoIDType = 0x11 +) + +// key of acl token +// You can set up acl token through metainfo. 
+// eg: +// +// ctx = metainfo.WithValue(ctx, "gdpr-token", "your token") +const ( + // GDPRToken is used to set up gdpr token into InfoIDACLToken + GDPRToken = metainfo.PrefixTransient + "gdpr-token" +) + +type EncodeParam struct { + Flags HeaderFlags + + SeqID int32 + + ProtocolID ProtocolID + + // IntInfo is used to set up int key-value info into InfoIDIntKeyValue + IntInfo map[uint16]string + + // StrInfo is used to set up string key-value info into InfoIDKeyValue + StrInfo map[string]string +} + +// EncodeToBytes encode ttheader to bytes. +// NOTICE: Must call +// +// `binary.BigEndian.PutUint32(buf, uint32(totalLen))` +// +// after encoding both header and payload data to set total length of a request/response. +// And `totalLen` should be the length of header + payload - 4. +// You may refer to unit tests for examples. +func EncodeToBytes(ctx context.Context, param EncodeParam) (buf []byte, err error) { + out := bufiox.NewBytesWriter(&buf) + if _, err = Encode(ctx, param, out); err != nil { + return + } + if err = out.Flush(); err != nil { + return + } + return +} + +func Encode(ctx context.Context, param EncodeParam, out bufiox.Writer) (totalLenField []byte, err error) { + // 1. header meta + var headerMeta []byte + headerMeta, err = out.Malloc(TTHeaderMetaSize) + if err != nil { + return nil, fmt.Errorf("ttHeader malloc header meta failed, %s", err.Error()) + } + + totalLenField = headerMeta[0:4] + headerInfoSizeField := headerMeta[12:14] + binary.BigEndian.PutUint32(headerMeta[4:8], TTHeaderMagic+uint32(param.Flags)) + binary.BigEndian.PutUint32(headerMeta[8:12], uint32(param.Flags)) + + var transformIDs []uint8 // transformIDs not support TODO compress + // 2. 
header info, malloc and write + if err = WriteByte(byte(param.ProtocolID), out); err != nil { + return nil, fmt.Errorf("ttHeader write protocol id failed, %s", err.Error()) + } + if err = WriteByte(byte(len(transformIDs)), out); err != nil { + return nil, fmt.Errorf("ttHeader write transformIDs length failed, %s", err.Error()) + } + for tid := range transformIDs { + if err = WriteByte(byte(tid), out); err != nil { + return nil, fmt.Errorf("ttHeader write transformIDs failed, %s", err.Error()) + } + } + // PROTOCOL ID(u8) + NUM TRANSFORMS(always 0)(u8) + TRANSFORM IDs([]u8) + headerInfoSize := 1 + 1 + len(transformIDs) + headerInfoSize, err = writeKVInfo(headerInfoSize, param.IntInfo, param.StrInfo, out) + if err != nil { + return nil, fmt.Errorf("ttHeader write kv info failed, %s", err.Error()) + } + + if uint32(headerInfoSize) > MaxHeaderSize { + return nil, fmt.Errorf("invalid header length[%d]", headerInfoSize) + } + binary.BigEndian.PutUint16(headerInfoSizeField, uint16(headerInfoSize/4)) + return totalLenField, nil +} + +// writeKVInfo writes the key/value info sections of the header to out. +// Counting starts at writtenSize; the returned writeSize is that value plus every +// byte written here, including the zero padding that rounds it up to a multiple of 4. +// If strKVMap contains GDPRToken, that entry is written as its own InfoIDACLToken +// section and is left out of the InfoIDKeyValue section. +// On failure it returns the error together with the size counted up to that point. +func writeKVInfo(writtenSize int, intKVMap map[uint16]string, strKVMap map[string]string, out bufiox.Writer) (writeSize int, err error) { + writeSize = writtenSize + // string key/value info: strKVSize counts the entries that go into the InfoIDKeyValue section + strKVSize := len(strKVMap) + // write gdpr token into InfoIDACLToken + // supplementary doc: https://www.cloudwego.io/docs/kitex/reference/transport_protocol_ttheader/ + if gdprToken, ok := strKVMap[GDPRToken]; ok { + strKVSize-- + // INFO ID TYPE(u8) + if err = WriteByte(byte(InfoIDACLToken), out); err != nil { + return writeSize, err + } + writeSize += 1 + + wLen, err := WriteString2BLen(gdprToken, out) + if err != nil { + return writeSize, err + } + writeSize += wLen + } + + if strKVSize > 0 { + // INFO ID TYPE(u8) + NUM HEADERS(u16): 3 bytes in total + if err = WriteByte(byte(InfoIDKeyValue), out); err != nil { + return writeSize, err + } + if err = WriteUint16(uint16(strKVSize), out); err != nil { + return writeSize, err + } + writeSize += 3 + for key, val := range strKVMap {
+ if key == GDPRToken { + continue + } + keyWLen, err := WriteString2BLen(key, out) + if err != nil { + return writeSize, err + } + valWLen, err := WriteString2BLen(val, out) + if err != nil { + return writeSize, err + } + writeSize = writeSize + keyWLen + valWLen + } + } + + // integer-keyed key/value info + intKVSize := len(intKVMap) + if intKVSize > 0 { + // INFO ID TYPE(u8) + NUM HEADERS(u16): 3 bytes in total + if err = WriteByte(byte(InfoIDIntKeyValue), out); err != nil { + return writeSize, err + } + if err = WriteUint16(uint16(intKVSize), out); err != nil { + return writeSize, err + } + writeSize += 3 + for key, val := range intKVMap { + if err = WriteUint16(key, out); err != nil { + return writeSize, err + } + valWLen, err := WriteString2BLen(val, out) + if err != nil { + return writeSize, err + } + // 2 bytes for the uint16 key plus the length-prefixed value + writeSize = writeSize + 2 + valWLen + } + } + + // pad with zero bytes so that writeSize becomes a multiple of 4: padding = (4 - writeSize%4) % 4 + padding := (4 - writeSize%4) % 4 + paddingBuf, err := out.Malloc(padding) + if err != nil { + return writeSize, err + } + for i := 0; i < len(paddingBuf); i++ { + paddingBuf[i] = byte(0) + } + writeSize += padding + return +} + +// WriteByte obtains a 1-byte buffer from out via Malloc and stores val in it. +func WriteByte(val byte, out bufiox.Writer) error { + var buf []byte + var err error + if buf, err = out.Malloc(1); err != nil { + return err + } + buf[0] = val + return nil +} + +// WriteUint32 obtains Size32 bytes from out via Malloc and stores val in them in big-endian order. +func WriteUint32(val uint32, out bufiox.Writer) error { + var buf []byte + var err error + if buf, err = out.Malloc(Size32); err != nil { + return err + } + binary.BigEndian.PutUint32(buf, val) + return nil +} + +// WriteUint16 obtains Size16 bytes from out via Malloc and stores val in them in big-endian order. +func WriteUint16(val uint16, out bufiox.Writer) error { + var buf []byte + var err error + if buf, err = out.Malloc(Size16); err != nil { + return err + } + binary.BigEndian.PutUint16(buf, val) + return nil +} + +// WriteString writes val to out as a big-endian uint32 length followed by the bytes of val, and returns the number of bytes written (Size32 + len(val)).
+func WriteString(val string, out bufiox.Writer) (int, error) { + var buf []byte + var err error + strLen := len(val) + if buf, err = out.Malloc(Size32 + strLen); err != nil { + return 0, err + } + binary.BigEndian.PutUint32(buf, uint32(strLen)) + copy(buf[Size32:], val) + return Size32 + strLen, nil +} + +// WriteString2BLen writes val to out as a big-endian uint16 length followed by the bytes of val, and returns the number of bytes written (Size16 + len(val)). +// NOTE(review): len(val) is converted to uint16, so a value longer than 65535 bytes gets a truncated length prefix while all of its bytes are still copied - confirm that callers bound the length. +func WriteString2BLen(val string, out bufiox.Writer) (int, error) { + var buf []byte + var err error + strLen := len(val) + if buf, err = out.Malloc(Size16 + strLen); err != nil { + return 0, err + } + binary.BigEndian.PutUint16(buf, uint16(strLen)) + copy(buf[Size16:], val) + return Size16 + strLen, nil +} diff --git a/protocol/ttheader/encode_test.go b/protocol/ttheader/encode_test.go new file mode 100644 index 0000000..8b19b68 --- /dev/null +++ b/protocol/ttheader/encode_test.go @@ -0,0 +1,15 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ttheader diff --git a/protocol/util/bytebuffer.go b/protocol/util/bytebuffer.go new file mode 100644 index 0000000..5e9955f --- /dev/null +++ b/protocol/util/bytebuffer.go @@ -0,0 +1,48 @@ +package util + +import ( + "encoding/binary" + "io" +) + +// Bytes2Uint32NoCheck decodes the first four bytes of bytes as a big-endian uint32 without checking the length; binary.BigEndian.Uint32 panics on a shorter slice. +func Bytes2Uint32NoCheck(bytes []byte) uint32 { + return binary.BigEndian.Uint32(bytes) +} + +// Bytes2Uint16NoCheck decodes the first two bytes of bytes as a big-endian uint16 without checking the length; binary.BigEndian.Uint16 panics on a shorter slice. +func Bytes2Uint16NoCheck(bytes []byte) uint16 { + return binary.BigEndian.Uint16(bytes) +} + +// Bytes2Uint8 ...
// Bytes2Uint8 returns the byte at bytes[off].
//
// It returns io.EOF when off does not address a byte inside bytes. That
// includes a negative off, which would otherwise slip past the length check
// and panic on the index expression.
func Bytes2Uint8(bytes []byte, off int) (uint8, error) {
	if off < 0 || len(bytes)-off < 1 {
		return 0, io.EOF
	}
	return bytes[off], nil
}

// Bytes2Uint16 decodes the two bytes starting at bytes[off] as a big-endian
// uint16.
//
// It returns io.EOF when off is negative or fewer than two bytes are
// available at off.
func Bytes2Uint16(bytes []byte, off int) (uint16, error) {
	if off < 0 || len(bytes)-off < 2 {
		return 0, io.EOF
	}
	return binary.BigEndian.Uint16(bytes[off:]), nil
}

// ReadString2BLen reads a string stored at bytes[off] as a big-endian uint16
// length followed by that many bytes.
//
// It returns the string (a copy of the underlying bytes) and the number of
// bytes consumed, i.e. the string length plus the 2-byte prefix. It returns
// io.EOF when the prefix or the announced payload does not fit in bytes.
func ReadString2BLen(bytes []byte, off int) (string, int, error) {
	length, err := Bytes2Uint16(bytes, off)
	if err != nil {
		return "", 0, err
	}
	strLen := int(length)
	off += 2 // skip the length prefix
	if len(bytes)-off < strLen {
		return "", 0, io.EOF
	}
	return string(bytes[off : off+strLen]), strLen + 2, nil
}