diff --git a/internal/db/postgres/restorers/base_test.go b/internal/db/postgres/restorers/base_test.go index c2141b58..73c6ccc3 100644 --- a/internal/db/postgres/restorers/base_test.go +++ b/internal/db/postgres/restorers/base_test.go @@ -35,8 +35,8 @@ CREATE TABLE users ( CREATE TABLE orders ( id SERIAL PRIMARY KEY, user_id INT NOT NULL, - raise_error TEXT, order_amount NUMERIC(10, 2) NOT NULL, + raise_error TEXT, order_date DATE DEFAULT CURRENT_DATE, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, CONSTRAINT fk_user FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE diff --git a/internal/db/postgres/restorers/table_insert_format_test.go b/internal/db/postgres/restorers/table_insert_format_test.go new file mode 100644 index 00000000..87880972 --- /dev/null +++ b/internal/db/postgres/restorers/table_insert_format_test.go @@ -0,0 +1,195 @@ +package restorers + +import ( + "bytes" + "compress/gzip" + "context" + + "github.com/stretchr/testify/mock" + + "github.com/greenmaskio/greenmask/internal/db/postgres/pgrestore" + "github.com/greenmaskio/greenmask/internal/db/postgres/toc" + "github.com/greenmaskio/greenmask/internal/domains" + "github.com/greenmaskio/greenmask/internal/utils/testutils" + "github.com/greenmaskio/greenmask/pkg/toolkit" +) + +func (s *restoresSuite) Test_TableRestorerInsertFormat_check_triggers_errors() { + s.Run("check triggers causes error by default", func() { + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: ©Stmt, + } + data := "6\t1\t100.50\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := &pgrestore.DataSectionSettings{ + ExitOnError: true, + } + t := &toolkit.Table{ + Schema: schemaName, + Name: tableName, + Columns: []*toolkit.Column{ + { + Name: "id", + TypeName: "int4", + }, + { + Name: "user_id", + TypeName: "int4", + }, + { + Name: "order_amount", + TypeName: "numeric", + }, + { + Name: "raise_error", + TypeName: "text", + }, + }, + } + + tr := NewTableRestorerInsertFormat(entry, t, st, opt, new(domains.DataRestorationErrorExclusions)) + + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + err = tr.Execute(ctx, conn) + s.Require().ErrorContains(err, "Test exception (SQLSTATE P0001)") + }) + + s.Run("disable triggers", func() { + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: ©Stmt, + } + data := "7\t1\t100.50\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := &pgrestore.DataSectionSettings{ + ExitOnError: true, + DisableTriggers: true, + SuperUser: s.GetSuperUser(), + } + t := &toolkit.Table{ + Schema: schemaName, + Name: tableName, + Columns: []*toolkit.Column{ + { + Name: "id", + TypeName: "int4", + }, + { + Name: "user_id", + TypeName: "int4", + }, + { + Name: "order_amount", + TypeName: "numeric", + }, + { + Name: "raise_error", + TypeName: "text", + }, + }, + } + + tr := NewTableRestorerInsertFormat(entry, t, st, opt, new(domains.DataRestorationErrorExclusions)) + + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + err = tr.Execute(ctx, conn) + s.Require().NoError(err) + }) + + s.Run("session_replication_role is replica", func() { + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: ©Stmt, + } + data := "8\t1\t100.50\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := &pgrestore.DataSectionSettings{ + ExitOnError: true, + UseSessionReplicationRoleReplica: true, + SuperUser: s.GetSuperUser(), + } + t := &toolkit.Table{ + Schema: schemaName, + Name: tableName, + Columns: []*toolkit.Column{ + { + Name: "id", + TypeName: "int4", + }, + { + Name: "user_id", + TypeName: "int4", + }, + { + Name: "order_amount", + TypeName: "numeric", + }, + { + Name: "raise_error", + TypeName: "text", + }, + }, + } + + tr := NewTableRestorerInsertFormat(entry, t, st, opt, new(domains.DataRestorationErrorExclusions)) + + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + err = tr.Execute(ctx, conn) + s.Require().NoError(err) + }) +} diff --git a/internal/db/postgres/restorers/table_test.go b/internal/db/postgres/restorers/table_test.go index 6246462d..00f8b07d 100644 --- a/internal/db/postgres/restorers/table_test.go +++ b/internal/db/postgres/restorers/table_test.go @@ -26,8 +26,7 @@ func (s *restoresSuite) Test_TableRestorer_check_triggers_errors() { FileName: &fileName, CopyStmt: ©Stmt, } - data := "3\t1\t100.50\tTest exception\n" + - "4\t1\t200.75\tTest exception\n" + data := "3\t1\t100.50\tTest exception\n" buf := new(bytes.Buffer) gzData := gzip.NewWriter(buf) _, err := gzData.Write([]byte(data)) @@ -62,8 +61,7 @@ func (s *restoresSuite) Test_TableRestorer_check_triggers_errors() { FileName: &fileName, CopyStmt: ©Stmt, } - data := "3\t1\t100.50\tTest exception\n" + - "4\t1\t200.75\tTest exception\n" + data := "4\t1\t100.50\tTest exception\n" buf := new(bytes.Buffer) gzData := gzip.NewWriter(buf) _, err := gzData.Write([]byte(data)) @@ -100,8 +98,7 @@ func (s *restoresSuite) Test_TableRestorer_check_triggers_errors() { FileName: &fileName, CopyStmt: ©Stmt, } - data := "3\t1\t100.50\tTest exception\n" + - "4\t1\t200.75\tTest exception\n" + data := "5\t1\t100.50\tTest exception\n" buf := new(bytes.Buffer) gzData := gzip.NewWriter(buf) _, err := gzData.Write([]byte(data)) @@ -125,5 +122,4 @@ func (s *restoresSuite) Test_TableRestorer_check_triggers_errors() { err = tr.Execute(ctx, conn) s.Require().NoError(err) }) - }