From 8e02424292690c802bef99ce772336d7393064ef Mon Sep 17 00:00:00 2001 From: Edward Viaene Date: Fri, 13 Sep 2024 14:10:03 -0500 Subject: [PATCH] make buffer thread safe --- pkg/observability/buffer.go | 39 +++++-- pkg/observability/buffer_test.go | 171 +++++++++++++++++++++++++++---- pkg/observability/constants.go | 2 +- pkg/observability/new.go | 13 ++- pkg/observability/types.go | 10 +- 5 files changed, 198 insertions(+), 37 deletions(-) diff --git a/pkg/observability/buffer.go b/pkg/observability/buffer.go index dc34ced..07b63de 100644 --- a/pkg/observability/buffer.go +++ b/pkg/observability/buffer.go @@ -1,6 +1,7 @@ package observability import ( + "bytes" "fmt" "io" "strconv" @@ -10,17 +11,24 @@ import ( ) func (o *Observability) WriteBufferToStorage(n int64) error { - o.BufferMu.Lock() - defer o.BufferMu.Unlock() + o.ActiveBufferWriters.Add(1) + defer o.ActiveBufferWriters.Done() + // copy first to temporary buffer (storage might have latency) + tempBuf := bytes.NewBuffer(make([]byte, n)) + _, err := io.CopyN(tempBuf, o.Buffer, n) + o.LastFlushed = time.Now() + if err != nil && err != io.EOF { + return fmt.Errorf("write error from buffer to temporary buffer: %s", err) + } + file, err := o.Storage.OpenFileForWriting("data-" + time.Now().Format("2003-01-02T15:04:05") + "-" + strconv.FormatUint(o.FlushOverflowSequence.Add(1), 10)) if err != nil { return fmt.Errorf("open file for writing error: %s", err) } - _, err = io.CopyN(file, &o.Buffer, n) + _, err = io.Copy(file, tempBuf) if err != nil { return fmt.Errorf("file write error: %s", err) } - o.LastFlushed = time.Now() return file.Close() } @@ -50,13 +58,11 @@ func (o *Observability) Ingest(data io.ReadCloser) error { _, err = o.Buffer.Write(encodeMessage(msgs)) if err != nil { return fmt.Errorf("write error: %s", err) - } - fmt.Printf("Buffer size: %d\n", o.Buffer.Len()) - if o.Buffer.Len() >= MAX_BUFFER_SIZE { + if o.Buffer.Len() >= o.MaxBufferSize { if o.FlushOverflow.CompareAndSwap(false, true) { go func() { // write to storage - if n := o.Buffer.Len(); n >= MAX_BUFFER_SIZE { + if n := o.Buffer.Len(); n >= o.MaxBufferSize { err := o.WriteBufferToStorage(int64(n)) if err != nil { logging.ErrorLog(fmt.Errorf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err)) @@ -68,3 +74,20 @@ func (o *Observability) Ingest(data io.ReadCloser) error { } return nil } + +func (c *ConcurrentRWBuffer) Write(p []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() + return c.buffer.Write(p) +} +func (c *ConcurrentRWBuffer) Read(p []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() + return c.buffer.Read(p) +} +func (c *ConcurrentRWBuffer) Len() int { + return c.buffer.Len() +} +func (c *ConcurrentRWBuffer) Cap() int { + return c.buffer.Cap() +} diff --git a/pkg/observability/buffer_test.go b/pkg/observability/buffer_test.go index 1b84004..b58a4e4 100644 --- a/pkg/observability/buffer_test.go +++ b/pkg/observability/buffer_test.go @@ -7,17 +7,15 @@ import ( "io" "strconv" "testing" - "time" memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestIngestion(t *testing.T) { - t.Skip() // working on this test + totalMessagesToGenerate := 1000 storage := &memorystorage.MockMemoryStorage{} - o := &Observability{ - Storage: storage, - } + o := NewWithoutMonitor(20) + o.Storage = storage payload := IncomingData{ { "date": 1720613813.197045, @@ -25,7 +23,7 @@ func TestIngestion(t *testing.T) { }, } - for i := 0; i < MAX_BUFFER_SIZE; i++ { + for i := 0; i < totalMessagesToGenerate; i++ { payload[0]["log"] = "this is string: " + strconv.Itoa(i) payloadBytes, err := json.Marshal(payload) if err != nil { @@ -38,17 +36,14 @@ func TestIngestion(t *testing.T) { } } - // flush remaining data - time.Sleep(1 * time.Second) - if o.Buffer.Len() >= MAX_BUFFER_SIZE { - if o.FlushOverflow.CompareAndSwap(false, true) { - if n := o.Buffer.Len(); n >= MAX_BUFFER_SIZE { - err := o.WriteBufferToStorage(int64(n)) - if err != nil { - t.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err) - } - } - o.FlushOverflow.Swap(false) + // wait until all data is flushed + o.ActiveBufferWriters.Wait() + + // flush remaining data that hasn't been flushed + if n := o.Buffer.Len(); n >= 0 { + err := o.WriteBufferToStorage(int64(n)) + if err != nil { + t.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err) } } @@ -64,13 +59,145 @@ func TestIngestion(t *testing.T) { t.Fatalf("read file error: %s", err) } decodedMessages := decodeMessage(messages) - for _, message := range decodedMessages { - fmt.Printf("decoded message: %s\n", message.Data["log"]) + totalMessages += len(decodedMessages) + } + if len(dirlist) == 0 { + t.Fatalf("expected multiple files in directory, got %d", len(dirlist)) + } + + if totalMessages != totalMessagesToGenerate { + t.Fatalf("Tried to generate total message count of: %d; got: %d", totalMessagesToGenerate, totalMessages) + } +} + +func TestIngestionMoreMessages(t *testing.T) { + totalMessagesToGenerate := 10000000 // 10,000,000 + storage := &memorystorage.MockMemoryStorage{} + o := NewWithoutMonitor(MAX_BUFFER_SIZE) + o.Storage = storage + payload := IncomingData{ + { + "date": 1720613813.197045, + "log": "this is string: ", + }, + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + + for i := 0; i < totalMessagesToGenerate; i++ { + data := io.NopCloser(bytes.NewReader(payloadBytes)) + err := o.Ingest(data) + if err != nil { + t.Fatalf("ingest error: %s", err) + } + } + + // wait until all data is flushed + o.ActiveBufferWriters.Wait() + + // flush remaining data that hasn't been flushed + if n := o.Buffer.Len(); n >= 0 { + err := o.WriteBufferToStorage(int64(n)) + if err != nil { + t.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err) } + } + + dirlist, err := storage.ReadDir("") + if err != nil { + t.Fatalf("read dir error: %s", err) + } + + totalMessages := 0 + for _, file := range dirlist { + messages, err := storage.ReadFile(file) + if err != nil { + t.Fatalf("read file error: %s", err) + } + decodedMessages := decodeMessage(messages) totalMessages += len(decodedMessages) } - fmt.Printf("totalmessages: %d", totalMessages) - if len(dirlist) != 3 { - t.Fatalf("expected 3 files in directory, got %d", len(dirlist)) + if len(dirlist) == 0 { + t.Fatalf("expected multiple files in directory, got %d", len(dirlist)) + } + + if totalMessages != totalMessagesToGenerate { + t.Fatalf("Tried to generate total message count of: %d; got: %d", totalMessagesToGenerate, totalMessages) + } + fmt.Printf("Buffer size (read+unread): %d in %d files\n", o.Buffer.Cap(), len(dirlist)) + +} + +func BenchmarkIngest10000000(b *testing.B) { + totalMessagesToGenerate := 10000000 // 10,000,000 + storage := &memorystorage.MockMemoryStorage{} + o := NewWithoutMonitor(MAX_BUFFER_SIZE) + o.Storage = storage + payload := IncomingData{ + { + "date": 1720613813.197045, + "log": "this is string", + }, + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + b.Fatalf("marshal error: %s", err) + } + + for i := 0; i < totalMessagesToGenerate; i++ { + data := io.NopCloser(bytes.NewReader(payloadBytes)) + err := o.Ingest(data) + if err != nil { + b.Fatalf("ingest error: %s", err) + } + } + + // wait until all data is flushed + o.ActiveBufferWriters.Wait() + + // flush remaining data that hasn't been flushed + if n := o.Buffer.Len(); n >= 0 { + err := o.WriteBufferToStorage(int64(n)) + if err != nil { + b.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err) + } + } +} + +func BenchmarkIngest100000000(b *testing.B) { + totalMessagesToGenerate := 10000000 // 10,000,000 + storage := &memorystorage.MockMemoryStorage{} + o := NewWithoutMonitor(MAX_BUFFER_SIZE) + o.Storage = storage + payload := IncomingData{ + { + "date": 1720613813.197045, + "log": "this is string", + }, + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + b.Fatalf("marshal error: %s", err) + } + + for i := 0; i < totalMessagesToGenerate; i++ { + data := io.NopCloser(bytes.NewReader(payloadBytes)) + err := o.Ingest(data) + if err != nil { + b.Fatalf("ingest error: %s", err) + } + } + + // wait until all data is flushed + o.ActiveBufferWriters.Wait() + + // flush remaining data that hasn't been flushed + if n := o.Buffer.Len(); n >= 0 { + err := o.WriteBufferToStorage(int64(n)) + if err != nil { + b.Fatalf("write log buffer to storage error (buffer: %d): %s", o.Buffer.Len(), err) + } } } diff --git a/pkg/observability/constants.go b/pkg/observability/constants.go index 1fcbb16..0427072 100644 --- a/pkg/observability/constants.go +++ b/pkg/observability/constants.go @@ -1,4 +1,4 @@ package observability -const MAX_BUFFER_SIZE = 100 +const MAX_BUFFER_SIZE = 1024 * 1024 // 1 MB const FLUSH_TIME_MAX_MINUTES = 5 diff --git a/pkg/observability/new.go b/pkg/observability/new.go index ae7dee6..e630747 100644 --- a/pkg/observability/new.go +++ b/pkg/observability/new.go @@ -1,14 +1,19 @@ package observability -import "net/http" +import ( + "net/http" +) func New() *Observability { - o := &Observability{} + o := NewWithoutMonitor(MAX_BUFFER_SIZE) go o.monitorBuffer() return o } -func NewWithoutMonitor() *Observability { - o := &Observability{} +func NewWithoutMonitor(maxBufferSize int) *Observability { + o := &Observability{ + Buffer: &ConcurrentRWBuffer{}, + MaxBufferSize: maxBufferSize, + } return o } diff --git a/pkg/observability/types.go b/pkg/observability/types.go index be85ef7..99a0f42 100644 --- a/pkg/observability/types.go +++ b/pkg/observability/types.go @@ -18,9 +18,15 @@ type FluentBitMessage struct { type Observability struct { Storage storage.Iface - Buffer bytes.Buffer + Buffer *ConcurrentRWBuffer LastFlushed time.Time - BufferMu sync.Mutex FlushOverflow atomic.Bool FlushOverflowSequence atomic.Uint64 + ActiveBufferWriters sync.WaitGroup + MaxBufferSize int +} + +type ConcurrentRWBuffer struct { + buffer bytes.Buffer + mu sync.Mutex }