Skip to content

Commit

Permalink
refactor(api/sessionrecording): move size and message bytes to Reader…
Browse files Browse the repository at this point in the history
… struct

This will prevent Go from needing to allocate and zero out 64K for each
invocation of Read. This may reduce the chances of Go doing a
heap allocation.
  • Loading branch information
dustinspecker committed Jan 8, 2025
1 parent 09b0884 commit 4ca0689
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions api/sessionrecording/session_recording.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ type Reader struct {
padding int64
// rawReader is the raw data source we read from
rawReader io.Reader
// sizeBytes is used to hold the header of the current event being parsed
sizeBytes [Int64Size]byte
// messageBytes holds the current decompressed event being parsed
messageBytes [MaxProtoMessageSizeBytes]byte
// state tracks where the Reader is at in consuming a session recording
state int
// error holds any error encountered while reading a session recording
Expand Down Expand Up @@ -147,8 +151,6 @@ func (r *Reader) GetStats() ReaderStats {

// Read returns next event or io.EOF in case of the end of the parts
func (r *Reader) Read(ctx context.Context) (apievents.AuditEvent, error) {
var sizeBytes [Int64Size]byte

// periodic checks of context after fixed amount of iterations
// is an extra precaution to avoid
// accidental endless loop due to logic error crashing the system
Expand All @@ -172,28 +174,28 @@ func (r *Reader) Read(ctx context.Context) (apievents.AuditEvent, error) {
case protoReaderStateInit:
// read the part header that consists of the protocol version
// and the part size (for the V1 version of the protocol)
if _, err := io.ReadFull(r.rawReader, sizeBytes[:Int64Size]); err != nil {
if _, err := io.ReadFull(r.rawReader, r.sizeBytes[:Int64Size]); err != nil {
// reached the end of the stream
if errors.Is(err, io.EOF) {
r.state = protoReaderStateEOF
return nil, err
}
return nil, r.setError(trace.ConvertSystemError(err))
}
protocolVersion := binary.BigEndian.Uint64(sizeBytes[:Int64Size])
protocolVersion := binary.BigEndian.Uint64(r.sizeBytes[:Int64Size])
if protocolVersion != ProtoStreamV1 {
return nil, trace.BadParameter("unsupported protocol version %v", protocolVersion)
}
// read size of this gzipped part as encoded by V1 protocol version
if _, err := io.ReadFull(r.rawReader, sizeBytes[:Int64Size]); err != nil {
if _, err := io.ReadFull(r.rawReader, r.sizeBytes[:Int64Size]); err != nil {
return nil, r.setError(trace.ConvertSystemError(err))
}
partSize := binary.BigEndian.Uint64(sizeBytes[:Int64Size])
partSize := binary.BigEndian.Uint64(r.sizeBytes[:Int64Size])
// read padding size (could be 0)
if _, err := io.ReadFull(r.rawReader, sizeBytes[:Int64Size]); err != nil {
if _, err := io.ReadFull(r.rawReader, r.sizeBytes[:Int64Size]); err != nil {
return nil, r.setError(trace.ConvertSystemError(err))
}
r.padding = int64(binary.BigEndian.Uint64(sizeBytes[:Int64Size]))
r.padding = int64(binary.BigEndian.Uint64(r.sizeBytes[:Int64Size]))
r.partReader = io.LimitReader(r.rawReader, int64(partSize))
gzipReader, err := gzip.NewReader(r.partReader)
// older bugged versions of teleport would sometimes incorrectly inject padding bytes into
Expand All @@ -209,19 +211,17 @@ func (r *Reader) Read(ctx context.Context) (apievents.AuditEvent, error) {
continue
// read the next version from the gzip reader
case protoReaderStateCurrent:
var messageBytes [MaxProtoMessageSizeBytes]byte

// the record consists of length of the protobuf encoded
// message and the message itself
if _, err := io.ReadFull(r.gzipReader, sizeBytes[:Int32Size]); err != nil {
if _, err := io.ReadFull(r.gzipReader, r.sizeBytes[:Int32Size]); err != nil {
if !errors.Is(err, io.EOF) {
return nil, r.setError(trace.ConvertSystemError(err))
}

// due to a bug in older versions of teleport it was possible that padding
// bytes would end up inside of the gzip section of the archive. we should
skip any dangling data in the gzip section.
n, err := io.CopyBuffer(io.Discard, r.partReader, messageBytes[:])
n, err := io.CopyBuffer(io.Discard, r.partReader, r.messageBytes[:])
if err != nil {
return nil, r.setError(trace.ConvertSystemError(err))
}
Expand All @@ -238,7 +238,7 @@ func (r *Reader) Read(ctx context.Context) (apievents.AuditEvent, error) {
return nil, r.setError(trace.ConvertSystemError(err))
}
if r.padding != 0 {
skipped, err := io.CopyBuffer(io.Discard, io.LimitReader(r.rawReader, r.padding), messageBytes[:])
skipped, err := io.CopyBuffer(io.Discard, io.LimitReader(r.rawReader, r.padding), r.messageBytes[:])
if err != nil {
return nil, r.setError(trace.ConvertSystemError(err))
}
Expand All @@ -252,19 +252,19 @@ func (r *Reader) Read(ctx context.Context) (apievents.AuditEvent, error) {
r.state = protoReaderStateInit
continue
}
messageSize := binary.BigEndian.Uint32(sizeBytes[:Int32Size])
messageSize := binary.BigEndian.Uint32(r.sizeBytes[:Int32Size])
// zero message size indicates end of the part
// that sometimes is present in partially submitted parts
// that have to be filled with zeroes for parts smaller
// than minimum allowed size
if messageSize == 0 {
return nil, r.setError(trace.BadParameter("unexpected message size 0"))
}
if _, err := io.ReadFull(r.gzipReader, messageBytes[:messageSize]); err != nil {
if _, err := io.ReadFull(r.gzipReader, r.messageBytes[:messageSize]); err != nil {
return nil, r.setError(trace.ConvertSystemError(err))
}
var oneof apievents.OneOf
if err := oneof.Unmarshal(messageBytes[:messageSize]); err != nil {
if err := oneof.Unmarshal(r.messageBytes[:messageSize]); err != nil {
return nil, trace.Wrap(err)
}
event, err := apievents.FromOneOf(oneof)
Expand Down

0 comments on commit 4ca0689

Please sign in to comment.