From 5b679338fbb6db4c85c265f94d326083f5c45334 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Sat, 7 Dec 2024 16:52:13 +0200 Subject: [PATCH] feat: Implemented gzip_reader, added tests --- internal/db/postgres/restorers/base.go | 66 ++++++++--- internal/db/postgres/restorers/base_test.go | 49 ++++++++ internal/db/postgres/restorers/table.go | 11 +- .../postgres/restorers/table_insert_format.go | 12 +- internal/utils/ioutils/gzip_reader.go | 52 ++++++++ internal/utils/ioutils/gzip_reader_test.go | 61 ++++++++++ internal/utils/ioutils/gzip_writer.go | 23 ++-- internal/utils/ioutils/gzip_writer_test.go | 111 ++++++++++++++++++ 8 files changed, 351 insertions(+), 34 deletions(-) create mode 100644 internal/utils/ioutils/gzip_reader.go create mode 100644 internal/utils/ioutils/gzip_reader_test.go create mode 100644 internal/utils/ioutils/gzip_writer_test.go diff --git a/internal/db/postgres/restorers/base.go b/internal/db/postgres/restorers/base.go index 1040e67b..2e7138e4 100644 --- a/internal/db/postgres/restorers/base.go +++ b/internal/db/postgres/restorers/base.go @@ -14,6 +14,47 @@ import ( "github.com/greenmaskio/greenmask/internal/utils/ioutils" ) +type GzipObjectWrapper struct { + gz io.ReadCloser + r io.ReadCloser +} + +func NewGzipObjectWrapper(r io.ReadCloser, usePgzip bool) (*GzipObjectWrapper, error) { + gz, err := ioutils.GetGzipReadCloser(r, usePgzip) + if err != nil { + if err := r.Close(); err != nil { + log.Warn(). + Err(err). + Msg("error closing dump file") + } + return nil, fmt.Errorf("cannot create gzip reader: %w", err) + } + + return &GzipObjectWrapper{ + gz: gz, + r: r, + }, nil + +} + +func (o *GzipObjectWrapper) Read(p []byte) (n int, err error) { + return o.gz.Read(p) +} + +func (o *GzipObjectWrapper) Close() error { + if err := o.gz.Close(); err != nil { + log.Warn(). + Err(err). + Msg("error closing gzip reader") + } + if err := o.r.Close(); err != nil { + log.Warn(). + Err(err). + Msg("error closing dump file") + } + return nil +} + type restoreBase struct { opt *pgrestore.DataSectionSettings entry *toc.Entry @@ -155,36 +196,25 @@ func (rb *restoreBase) resetTx(ctx context.Context, tx pgx.Tx) error { } // getObject returns a reader for the dump file. It warps the file in a gzip reader. -func (rb *restoreBase) getObject(ctx context.Context) (io.ReadCloser, func(), error) { +func (rb *restoreBase) getObject(ctx context.Context) (io.ReadCloser, error) { if rb.entry.FileName == nil { - return nil, nil, fmt.Errorf("file name in toc.Entry is empty") + return nil, fmt.Errorf("file name in toc.Entry is empty") } r, err := rb.st.GetObject(ctx, *rb.entry.FileName) if err != nil { - return nil, nil, fmt.Errorf("cannot open dump file: %w", err) + return nil, fmt.Errorf("cannot open dump file: %w", err) } - gz, err := ioutils.GetGzipReadCloser(r, rb.opt.UsePgzip) + + gz, err := ioutils.NewGzipReader(r, rb.opt.UsePgzip) if err != nil { if err := r.Close(); err != nil { log.Warn(). Err(err). Msg("error closing dump file") } - return nil, nil, fmt.Errorf("cannot create gzip reader: %w", err) + return nil, fmt.Errorf("cannot create gzip reader: %w", err) } - closingFunc := func() { - if err := gz.Close(); err != nil { - log.Warn(). - Err(err). - Msg("error closing gzip reader") - } - if err := r.Close(); err != nil { - log.Warn(). - Err(err). - Msg("error closing dump file") - } - } - return gz, closingFunc, nil + return gz, nil } diff --git a/internal/db/postgres/restorers/base_test.go b/internal/db/postgres/restorers/base_test.go index 1cf9922b..6e89f82d 100644 --- a/internal/db/postgres/restorers/base_test.go +++ b/internal/db/postgres/restorers/base_test.go @@ -1,10 +1,13 @@ package restorers import ( + "bytes" + "compress/gzip" "context" "fmt" "testing" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "github.com/greenmaskio/greenmask/internal/db/postgres/pgrestore" @@ -71,6 +74,14 @@ DROP TABLE IF EXISTS users; ` ) +type readCloserMock struct { + *bytes.Buffer +} + +func (r *readCloserMock) Close() error { + return nil +} + type restoresSuite struct { nonSuperUserPassword string nonSuperUser string @@ -519,6 +530,44 @@ WHERE n.nspname = $1 AND c.relname = $2 s.NoError(tx.Rollback(cxt)) } +func (s *restoresSuite) Test_restoreBase_getObject() { + schemaName := "public" + tableName := "orders" + fileName := "test.tar.gz" + + data := `20383 24ca7574-0adb-4b17-8777-93f5589dbea2 2017-12-13 13:46:49.39 +20384 d0d4a55c-7752-453e-8334-772a889fb917 2017-12-13 13:46:49.453 +20385 ac8617aa-5a2d-4bb8-a9a5-ed879a4b33cd 2017-12-13 13:46:49.5 +` + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", mock.Anything, mock.Anything). + Return(objSrc, nil) + + rb := newRestoreBase(&toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + }, st, &pgrestore.DataSectionSettings{}) + ctx := context.Background() + obj, err := rb.getObject(ctx) + s.Require().NoError(err) + readBuf := make([]byte, 1024) + n, err := obj.Read(readBuf) + s.Require().NoError(err) + s.Assert().Equal(data, string(readBuf[:n])) + s.NoError(obj.Close()) +} + func TestRestorers(t *testing.T) { suite.Run(t, new(restoresSuite)) } diff --git a/internal/db/postgres/restorers/table.go b/internal/db/postgres/restorers/table.go index 294235af..5ac2b689 100644 --- a/internal/db/postgres/restorers/table.go +++ b/internal/db/postgres/restorers/table.go @@ -57,11 +57,18 @@ func (td *TableRestorer) Execute(ctx context.Context, conn *pgx.Conn) error { return fmt.Errorf("cannot get file name from toc Entry") } - r, complete, err := td.getObject(ctx) + r, err := td.getObject(ctx) if err != nil { return fmt.Errorf("cannot get storage object: %w", err) } - defer complete() + defer func() { + if err := r.Close(); err != nil { + log.Warn(). + Err(err). + Str("objectName", td.DebugInfo()). + Msg("cannot close storage object") + } + }() // Open new transaction for each task tx, err := conn.Begin(ctx) diff --git a/internal/db/postgres/restorers/table_insert_format.go b/internal/db/postgres/restorers/table_insert_format.go index d6a958de..1b1c124c 100644 --- a/internal/db/postgres/restorers/table_insert_format.go +++ b/internal/db/postgres/restorers/table_insert_format.go @@ -85,12 +85,18 @@ func (td *TableRestorerInsertFormat) GetEntry() *toc.Entry { } func (td *TableRestorerInsertFormat) Execute(ctx context.Context, conn *pgx.Conn) error { - - r, complete, err := td.getObject(ctx) + r, err := td.getObject(ctx) if err != nil { return fmt.Errorf("cannot get storage object: %w", err) } - defer complete() + defer func() { + if err := r.Close(); err != nil { + log.Warn(). + Err(err). + Str("objectName", td.DebugInfo()). + Msg("cannot close storage object") + } + }() if err = td.streamInsertData(ctx, conn, r); err != nil { if td.opt.ExitOnError { diff --git a/internal/utils/ioutils/gzip_reader.go b/internal/utils/ioutils/gzip_reader.go new file mode 100644 index 00000000..141430f5 --- /dev/null +++ b/internal/utils/ioutils/gzip_reader.go @@ -0,0 +1,52 @@ +package ioutils + +import ( + "fmt" + "io" + + "github.com/rs/zerolog/log" +) + +type GzipReader struct { + gz io.ReadCloser + r io.ReadCloser +} + +func NewGzipReader(r io.ReadCloser, usePgzip bool) (*GzipReader, error) { + gz, err := GetGzipReadCloser(r, usePgzip) + if err != nil { + if err := r.Close(); err != nil { + log.Warn(). + Err(err). + Msg("error closing dump file") + } + return nil, fmt.Errorf("cannot create gzip reader: %w", err) + } + + return &GzipReader{ + gz: gz, + r: r, + }, nil + +} + +func (r *GzipReader) Read(p []byte) (n int, err error) { + return r.gz.Read(p) +} + +func (r *GzipReader) Close() error { + var lastErr error + if err := r.gz.Close(); err != nil { + lastErr = fmt.Errorf("error closing gzip reader: %w", err) + log.Warn(). + Err(err). + Msg("error closing gzip reader") + } + if err := r.r.Close(); err != nil { + lastErr = fmt.Errorf("error closing dump file: %w", err) + log.Warn(). + Err(err). + Msg("error closing dump file") + } + return lastErr +} diff --git a/internal/utils/ioutils/gzip_reader_test.go b/internal/utils/ioutils/gzip_reader_test.go new file mode 100644 index 00000000..43b540aa --- /dev/null +++ b/internal/utils/ioutils/gzip_reader_test.go @@ -0,0 +1,61 @@ +package ioutils + +import ( + "bytes" + "compress/gzip" + "testing" + + "github.com/stretchr/testify/require" +) + +type readCloserMock struct { + *bytes.Buffer + closeCallCount int +} + +func (r *readCloserMock) Close() error { + r.closeCallCount++ + return nil +} + +func TestNewGzipReader_Read(t *testing.T) { + data := `20383 24ca7574-0adb-4b17-8777-93f5589dbea2 2017-12-13 13:46:49.39 +20384 d0d4a55c-7752-453e-8334-772a889fb917 2017-12-13 13:46:49.453 +20385 ac8617aa-5a2d-4bb8-a9a5-ed879a4b33cd 2017-12-13 13:46:49.5 +` + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + require.NoError(t, err) + err = gzData.Flush() + require.NoError(t, err) + err = gzData.Close() + require.NoError(t, err) + objSrc := &readCloserMock{Buffer: buf} + r, err := NewGzipReader(objSrc, false) + require.NoError(t, err) + readBuf := make([]byte, 1024) + n, err := r.Read(readBuf) + require.NoError(t, err) + require.Equal(t, []byte(data), readBuf[:n]) +} + +func TestNewGzipReader_Close(t *testing.T) { + data := "" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + require.NoError(t, err) + err = gzData.Flush() + require.NoError(t, err) + err = gzData.Close() + objSrc := &readCloserMock{Buffer: buf, closeCallCount: 0} + r, err := NewGzipReader(objSrc, false) + require.NoError(t, err) + err = r.Close() + require.NoError(t, err) + require.Equal(t, 1, objSrc.closeCallCount) + gz := r.gz.(*gzip.Reader) + _, err = gz.Read([]byte{}) + require.Error(t, err) +} diff --git a/internal/utils/ioutils/gzip_writer.go b/internal/utils/ioutils/gzip_writer.go index 5b8390f7..51dd5f54 100644 --- a/internal/utils/ioutils/gzip_writer.go +++ b/internal/utils/ioutils/gzip_writer.go @@ -52,17 +52,18 @@ func (gw *GzipWriter) Write(p []byte) (int, error) { // Close - closing method with gz buffer flushing func (gw *GzipWriter) Close() error { - defer gw.w.Close() - flushErr := gw.gz.Flush() - if flushErr != nil { - log.Warn().Err(flushErr).Msg("error flushing gzip buffer") + var globalErr error + if err := gw.gz.Flush(); err != nil { + globalErr = fmt.Errorf("error flushing gzip buffer: %w", err) + log.Warn().Err(err).Msg("error flushing gzip buffer") } - if closeErr := gw.gz.Close(); closeErr != nil || flushErr != nil { - err := closeErr - if flushErr != nil { - err = flushErr - } - return fmt.Errorf("error closing gzip writer: %w", err) + if err := gw.gz.Close(); err != nil { + globalErr = fmt.Errorf("error closing gzip writer: %w", err) + log.Warn().Err(err).Msg("error closing gzip writer") } - return nil + if err := gw.w.Close(); err != nil { + globalErr = fmt.Errorf("error closing dump file: %w", err) + log.Warn().Err(err).Msg("error closing dump file") + } + return globalErr } diff --git a/internal/utils/ioutils/gzip_writer_test.go b/internal/utils/ioutils/gzip_writer_test.go new file mode 100644 index 00000000..f0800dcf --- /dev/null +++ b/internal/utils/ioutils/gzip_writer_test.go @@ -0,0 +1,111 @@ +package ioutils + +import ( + "bytes" + "compress/gzip" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +type writeCloserMock struct { + data []byte + writeCallCount int + writeCallFunc func(callCount int) error + closeCallCount int + closeCallFunc func(callCount int) error +} + +func (w *writeCloserMock) Write(p []byte) (n int, err error) { + w.writeCallCount++ + if w.writeCallFunc != nil { + return 0, w.writeCallFunc(w.writeCallCount) + } + w.data = append(w.data, p...) + return len(p), nil +} + +func (w *writeCloserMock) Close() error { + w.closeCallCount++ + if w.closeCallFunc != nil { + return w.closeCallFunc(w.closeCallCount) + } + return nil +} + +func TestNewGzipWriter_Write(t *testing.T) { + data := `20383 24ca7574-0adb-4b17-8777-93f5589dbea2 2017-12-13 13:46:49.39 +20384 d0d4a55c-7752-453e-8334-772a889fb917 2017-12-13 13:46:49.453 +20385 ac8617aa-5a2d-4bb8-a9a5-ed879a4b33cd 2017-12-13 13:46:49.5 +` + testDataBuf := new(bytes.Buffer) + gzData := gzip.NewWriter(testDataBuf) + _, err := gzData.Write([]byte(data)) + require.NoError(t, err) + err = gzData.Flush() + require.NoError(t, err) + err = gzData.Close() + require.NoError(t, err) + expectedData := testDataBuf.Bytes() + + require.NoError(t, err) + err = gzData.Close() + require.NoError(t, err) + objSrc := &writeCloserMock{} + r := NewGzipWriter(objSrc, false) + require.NoError(t, err) + _, err = r.Write([]byte(data)) + require.NoError(t, err) + err = r.Close() + require.NoError(t, err) + + require.Equal(t, expectedData, objSrc.data) +} + +func TestNewGzipWriter_Close(t *testing.T) { + data := `20383 24ca7574-0adb-4b17-8777-93f5589dbea2 2017-12-13 13:46:49.39 +20384 d0d4a55c-7752-453e-8334-772a889fb917 2017-12-13 13:46:49.453 +20385 ac8617aa-5a2d-4bb8-a9a5-ed879a4b33cd 2017-12-13 13:46:49.5 +` + t.Run("Success", func(t *testing.T) { + objSrc := &writeCloserMock{} + r := NewGzipWriter(objSrc, false) + err := r.Close() + require.NoError(t, err) + require.Equal(t, 1, objSrc.closeCallCount) + }) + + t.Run("Flush Error", func(t *testing.T) { + objSrc := &writeCloserMock{ + writeCallFunc: func(c int) error { + if c == 2 { + return errors.New("storage object error") + } + return nil + }, + } + r := NewGzipWriter(objSrc, false) + _, err := r.Write([]byte(data)) + require.NoError(t, err) + + err = r.Close() + require.Error(t, err) + require.ErrorContains(t, err, "error closing gzip writer") + require.Equal(t, 1, objSrc.closeCallCount) + require.Equal(t, 2, objSrc.writeCallCount) + }) + + t.Run("Storage object close Error", func(t *testing.T) { + objSrc := &writeCloserMock{ + closeCallFunc: func(c int) error { + return errors.New("storage object error") + }, + } + r := NewGzipWriter(objSrc, false) + err := r.Close() + require.Error(t, err) + require.Equal(t, 1, objSrc.closeCallCount) + require.ErrorContains(t, err, "error closing dump file") + }) +}