From 91aea48d3f7cc49adbffaef78f86d358c83b6645 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Mon, 19 Aug 2024 13:20:12 +0300 Subject: [PATCH 1/9] fix: fixed validate command fatal error caused by filtered objects --- internal/db/postgres/cmd/dump.go | 13 +++++++--- internal/db/postgres/cmd/validate.go | 34 ++++++++++--------------- internal/db/postgres/context/context.go | 26 ++++++++++++++----- internal/db/postgres/dumpers/table.go | 22 ++++++++-------- internal/db/postgres/entries/table.go | 6 ++--- 5 files changed, 55 insertions(+), 46 deletions(-) diff --git a/internal/db/postgres/cmd/dump.go b/internal/db/postgres/cmd/dump.go index 6a3b9f45..0a4f6ae6 100644 --- a/internal/db/postgres/cmd/dump.go +++ b/internal/db/postgres/cmd/dump.go @@ -67,7 +67,8 @@ type Dump struct { // sortedTablesDumpIds - sorted tables dump ids in topological order sortedTablesDumpIds []int32 // validate shows that dump worker must be in validation mode - validate bool + validate bool + validateRowsLimit uint64 } func NewDump(cfg *domains.Config, st storages.Storager, registry *utils.TransformerRegistry) *Dump { @@ -257,13 +258,17 @@ func (d *Dump) dumpWorkerRunner( func (d *Dump) taskProducer(ctx context.Context, tasks chan<- dumpers.DumpTask) func() error { return func() error { defer close(tasks) + dataObjects := d.context.DataSectionObjects + if d.validate { + dataObjects = d.context.DataSectionObjectsToValidate + } - for _, dumpObj := range d.context.DataSectionObjects { + for _, dumpObj := range dataObjects { dumpObj.SetDumpId(d.dumpIdSequence) var task dumpers.DumpTask switch v := dumpObj.(type) { case *entries.Table: - task = dumpers.NewTableDumper(v, d.validate, d.pgDumpOptions.Pgzip) + task = dumpers.NewTableDumper(v, d.validate, d.validateRowsLimit, d.pgDumpOptions.Pgzip) case *entries.Sequence: task = dumpers.NewSequenceDumper(v) case *entries.Blobs: @@ -332,7 +337,7 @@ func (d *Dump) setDumpDependenciesGraph(tables []*entries.Table) { return entry.Oid == oid }) if idx == -1 { - panic("table not found") + panic(fmt.Sprintf("table not found: oid=%d", oid)) } t := tables[idx] // Create dependencies graph with DumpId sequence for easier restoration coordination diff --git a/internal/db/postgres/cmd/validate.go b/internal/db/postgres/cmd/validate.go index 90e90078..d756570f 100644 --- a/internal/db/postgres/cmd/validate.go +++ b/internal/db/postgres/cmd/validate.go @@ -63,6 +63,7 @@ func NewValidate(cfg *domains.Config, registry *utils.TransformerRegistry, st st d.dumpIdSequence = toc.NewDumpSequence(0) d.validate = true + d.validateRowsLimit = cfg.Validate.RowsLimit return &Validate{ Dump: d, tmpDir: tmpDirName, @@ -139,7 +140,7 @@ func (v *Validate) Run(ctx context.Context) (int, error) { return v.exitCode, nil } - if err = v.dumpTables(ctx); err != nil { + if err = v.dataDump(ctx); err != nil { return nonZeroExitCode, err } @@ -152,12 +153,20 @@ func (v *Validate) Run(ctx context.Context) (int, error) { func (v *Validate) print(ctx context.Context) error { for _, e := range v.dataEntries { - idx := slices.IndexFunc(v.context.DataSectionObjects, func(entry entries.Entry) bool { - t := entry.(*entries.Table) + idx := slices.IndexFunc(v.context.DataSectionObjectsToValidate, func(entry entries.Entry) bool { + t, ok := entry.(*entries.Table) + if !ok { + return false + } return t.DumpId == e.DumpId }) - t := v.context.DataSectionObjects[idx].(*entries.Table) + if idx == -1 { + // skip if not in DataSectionObjectsToValidate + continue + } + + t := 
v.context.DataSectionObjectsToValidate[idx].(*entries.Table) doc, err := v.createDocument(ctx, t) if err != nil { return fmt.Errorf("unable to create validation document: %w", err) @@ -269,23 +278,6 @@ func (v *Validate) createDocument(ctx context.Context, t *entries.Table) (valida return doc, nil } -func (v *Validate) dumpTables(ctx context.Context) error { - var tablesWithTransformers []entries.Entry - for _, item := range v.context.DataSectionObjects { - - if t, ok := item.(*entries.Table); ok && len(t.TransformersContext) > 0 { - t.ValidateLimitedRecords = v.config.Validate.RowsLimit - tablesWithTransformers = append(tablesWithTransformers, t) - } - } - v.context.DataSectionObjects = tablesWithTransformers - - if err := v.dataDump(ctx); err != nil { - return fmt.Errorf("data stage dumping error: %w", err) - } - return nil -} - func (v *Validate) printValidationWarnings() error { // TODO: Implement warnings hook, such as logging and HTTP sender for _, w := range v.context.Warnings { diff --git a/internal/db/postgres/context/context.go b/internal/db/postgres/context/context.go index c534a813..a8234152 100644 --- a/internal/db/postgres/context/context.go +++ b/internal/db/postgres/context/context.go @@ -42,6 +42,8 @@ type RuntimeContext struct { Types []*toolkit.Type // DataSectionObjects - list of objects to dump in data-section. There are sequences, tables and large objects DataSectionObjects []entries.Entry + // DataSectionObjectsToValidate - list of objects to validate in data-section + DataSectionObjectsToValidate []entries.Entry // Warnings - list of occurred ValidationWarning during validation and config building Warnings toolkit.ValidationWarnings // Registry - registry of all the registered transformers definition @@ -125,14 +127,24 @@ func NewRuntimeContext( dataSectionObjects = append(dataSectionObjects, blobEntries) } + // Generate list of tables that might be validated during the validate command call + var dataSectionObjectsToValidate []entries.Entry + for _, item := range dataSectionObjects { + + if t, ok := item.(*entries.Table); ok && len(t.TransformersContext) > 0 { + dataSectionObjectsToValidate = append(dataSectionObjectsToValidate, t) + } + } + return &RuntimeContext{ - Tables: tables, - Types: types, - DataSectionObjects: dataSectionObjects, - Warnings: warnings, - Registry: r, - DatabaseSchema: schema, - Graph: graph, + Tables: tables, + Types: types, + DataSectionObjects: dataSectionObjects, + Warnings: warnings, + Registry: r, + DatabaseSchema: schema, + Graph: graph, + DataSectionObjectsToValidate: dataSectionObjectsToValidate, }, nil } diff --git a/internal/db/postgres/dumpers/table.go b/internal/db/postgres/dumpers/table.go index 5ed121d5..48cc97e1 100644 --- a/internal/db/postgres/dumpers/table.go +++ b/internal/db/postgres/dumpers/table.go @@ -19,7 +19,6 @@ import ( "fmt" "io" - "github.com/greenmaskio/greenmask/internal/utils/ioutils" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgproto3" "github.com/rs/zerolog/log" @@ -27,20 +26,23 @@ import ( "github.com/greenmaskio/greenmask/internal/db/postgres/entries" "github.com/greenmaskio/greenmask/internal/storages" + "github.com/greenmaskio/greenmask/internal/utils/ioutils" ) type TableDumper struct { - table *entries.Table - recordNum uint64 - validate bool - usePgzip bool + table *entries.Table + recordNum uint64 + validate bool + validateRowsLimit uint64 + usePgzip bool } -func NewTableDumper(table *entries.Table, validate bool, usePgzip bool) *TableDumper { +func NewTableDumper(table *entries.Table, 
validate bool, rowsLimit uint64, usePgzip bool) *TableDumper { return &TableDumper{ - table: table, - validate: validate, - usePgzip: usePgzip, + table: table, + validate: validate, + usePgzip: usePgzip, + validateRowsLimit: rowsLimit, } } @@ -162,7 +164,7 @@ func (td *TableDumper) process(ctx context.Context, tx pgx.Tx, w io.WriteCloser, if td.validate { // Logic for validation limiter - exit after recordNum rows td.recordNum++ - if td.recordNum == td.table.ValidateLimitedRecords { + if td.recordNum == td.validateRowsLimit { return nil } } diff --git a/internal/db/postgres/entries/table.go b/internal/db/postgres/entries/table.go index 42109663..bc389a40 100644 --- a/internal/db/postgres/entries/table.go +++ b/internal/db/postgres/entries/table.go @@ -44,10 +44,8 @@ type Table struct { CompressedSize int64 ExcludeData bool Driver *toolkit.Driver - // ValidateLimitedRecords - perform dumping and transformation only for N records and exit - ValidateLimitedRecords uint64 - Scores int64 - SubsetConds []string + Scores int64 + SubsetConds []string } func (t *Table) HasCustomTransformer() bool { From 12bf1180d975509489ac3028eceb89520edbd6f9 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Mon, 19 Aug 2024 16:08:55 +0300 Subject: [PATCH 2/9] fix: fixed zero bytes that were written in the buffer due to wrong buffer limit --- internal/db/postgres/transformers/email.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/db/postgres/transformers/email.go b/internal/db/postgres/transformers/email.go index ca4fc820..3ec79f3c 100644 --- a/internal/db/postgres/transformers/email.go +++ b/internal/db/postgres/transformers/email.go @@ -10,10 +10,11 @@ import ( "slices" "text/template" + "github.com/rs/zerolog/log" + "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" "github.com/greenmaskio/greenmask/internal/generators" "github.com/greenmaskio/greenmask/pkg/toolkit" - "github.com/rs/zerolog/log" ) const emailTransformerGeneratorSize = 64 @@ -207,7 +208,7 @@ func NewEmailTransformer(ctx context.Context, driver *toolkit.Driver, parameters domainTemplate: domainTmpl, validate: validate, buf: bytes.NewBuffer(nil), - hexEncodedRandomBytesBuf: make([]byte, hex.EncodedLen(emailTransformerGeneratorSize)), + hexEncodedRandomBytesBuf: make([]byte, hex.EncodedLen(maxLength)), rrctx: rrctx, }, nil, nil } From d3ee98c942d8d48bd2652d4c78d4e9e37b05cc6c Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Tue, 20 Aug 2024 12:00:30 +0300 Subject: [PATCH 3/9] feat: restore data in batches * Introduced the --batch-size flag for the restore command. * The COPY command will complete after reaching the specified batch size, allowing for transaction state checks. * The transaction spans across all batches. * If an error occurs in any batch, all previous batches will be rolled back. 
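Below is a rough, hedged sketch of the per-batch COPY cycle this patch introduces; it mirrors the new streamCopyDataByBatch and completeBatch logic in internal/db/postgres/restorers/table.go. The sendMessage helper and the startCopy callback here are simplified stand-ins for the restorer's internal sendMessage, postStreamingHandle and initCopy helpers, not the exact implementation.

```go
// Simplified sketch only; the real code also handles context cancellation,
// the COPY termination sequence ("\.") and pgzip-compressed readers.
package sketch

import (
	"bufio"
	"errors"
	"fmt"
	"io"

	"github.com/jackc/pgx/v5/pgproto3"
)

// sendMessage approximates the restorer's helper: queue a frontend message
// and flush it to the PostgreSQL connection.
func sendMessage(f *pgproto3.Frontend, msg pgproto3.FrontendMessage) error {
	f.Send(msg)
	return f.Flush()
}

// copyInBatches streams newline-terminated COPY rows and closes the COPY
// statement every batchSize rows, so a backend error surfaces right after the
// failing batch instead of at transaction commit. startCopy is assumed to
// drain the backend replies of the finished COPY and re-issue
// "COPY ... FROM stdin" (what completeBatch does via postStreamingHandle and
// initCopy). batchSize must be > 0; batchSize == 0 keeps the old single-batch path.
func copyInBatches(f *pgproto3.Frontend, dump io.Reader, batchSize int64, startCopy func() error) error {
	br := bufio.NewReader(dump)
	var rows int64
	for {
		line, err := br.ReadBytes('\n')
		if len(line) > 0 {
			if sendErr := sendMessage(f, &pgproto3.CopyData{Data: line}); sendErr != nil {
				return fmt.Errorf("error sending CopyData message: %w", sendErr)
			}
			rows++
		}
		if err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return fmt.Errorf("error reading from table dump: %w", err)
		}
		if rows%batchSize == 0 {
			// Complete the current batch: the backend reports any COPY error
			// here, while all batches still share one transaction.
			if err := sendMessage(f, &pgproto3.CopyDone{}); err != nil {
				return fmt.Errorf("error sending CopyDone message: %w", err)
			}
			if err := startCopy(); err != nil {
				return fmt.Errorf("error starting next batch: %w", err)
			}
		}
	}
	// Last (possibly partial) batch.
	return sendMessage(f, &pgproto3.CopyDone{})
}
```

Because every batch runs inside the same transaction, a failure in any batch aborts that transaction and rolls back all previously streamed batches, which is the behaviour described in the bullets above.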
--- cmd/greenmask/cmd/restore/restore.go | 6 +- docs/commands/restore.md | 26 +++++- internal/db/postgres/cmd/restore.go | 4 +- internal/db/postgres/cmd/validate.go | 4 +- internal/db/postgres/pgrestore/pgrestore.go | 3 +- internal/db/postgres/restorers/table.go | 93 ++++++++++++++++--- .../postgres/restorers/table_insert_format.go | 2 +- .../transformers/custom/dynamic_definition.go | 4 +- .../utils/cmd_transformer_base.go | 4 +- internal/utils/cmd_runner/cmd_runner.go | 4 +- internal/utils/reader/reader.go | 8 +- 11 files changed, 126 insertions(+), 32 deletions(-) diff --git a/cmd/greenmask/cmd/restore/restore.go b/cmd/greenmask/cmd/restore/restore.go index 039725ba..302b2916 100644 --- a/cmd/greenmask/cmd/restore/restore.go +++ b/cmd/greenmask/cmd/restore/restore.go @@ -171,6 +171,10 @@ func init() { "pgzip", "", false, "use pgzip decompression instead of gzip", ) + Cmd.Flags().Int64P( + "batch-size", "", 0, + "the number of rows to insert in a single batch during the COPY command (0 - all rows will be inserted in a single batch)", + ) // Connection options: Cmd.Flags().StringP("host", "h", "/var/run/postgres", "database server host or socket directory") @@ -185,7 +189,7 @@ func init() { "disable-triggers", "enable-row-security", "if-exists", "no-comments", "no-data-for-failed-tables", "no-security-labels", "no-subscriptions", "no-table-access-method", "no-tablespaces", "section", "strict-names", "use-set-session-authorization", "inserts", "on-conflict-do-nothing", "restore-in-order", - "pgzip", + "pgzip", "batch-size", "host", "port", "username", } { diff --git a/docs/commands/restore.md b/docs/commands/restore.md index 6db4a469..3103da91 100644 --- a/docs/commands/restore.md +++ b/docs/commands/restore.md @@ -18,6 +18,7 @@ allowing you to configure the restoration process as needed. Mostly it supports the same flags as the `pg_restore` utility, with some extra flags for Greenmask-specific features. ```text title="Supported flags" + --batch-size int the number of rows to insert in a single batch during the COPY command (0 - all rows will be inserted in a single batch) -c, --clean clean (drop) database objects before recreating -C, --create create the target database -a, --data-only restore only the data, no schema @@ -112,5 +113,28 @@ If your database has cyclic dependencies you will be notified about it but the r By default, Greenmask uses gzip decompression to restore data. In most cases it is quite slow and does not utilize all available resources and is a bottleneck for IO operations. To speed up the restoration process, you can use the `--pgzip` flag to use pgzip decompression instead of gzip. This method splits the data into blocks, which are -decompressed in parallel, making it ideal for handling large volumes of data. The output remains a standard gzip file. +decompressed in parallel, making it ideal for handling large volumes of data. +```shell title="example with pgzip decompression" +greenmask --config=config.yml restore latest --pgzip +``` + +### Restore data batching + +The COPY command returns the error only on transaction commit. This means that if you have a large dump and an error +occurs, you will have to wait until the end of the transaction to see the error message. To avoid this, you can use the +`--batch-size` flag to specify the number of rows to insert in a single batch during the COPY command. If an error occurs +during the batch insertion, the error message will be displayed immediately. 
The data will be committed **only if all +batches are inserted successfully**. + +!!! warning + + The batch size should be chosen carefully. If the batch size is too small, the restoration process will be slow. If + the batch size is too large, you may not be able to identify the error row. + +In the example below, the batch size is set to 1000 rows. This means that 1000 rows will be inserted in a single batch, +so you will be notified of any errors immediately after each batch is inserted. + +```shell title="example with batch size" +greenmask --config=config.yml restore latest --batch-size 1000 +``` diff --git a/internal/db/postgres/cmd/restore.go b/internal/db/postgres/cmd/restore.go index fc35946e..d95c62d6 100644 --- a/internal/db/postgres/cmd/restore.go +++ b/internal/db/postgres/cmd/restore.go @@ -646,7 +646,9 @@ func (r *Restore) taskPusher(ctx context.Context, tasks chan restorers.RestoreTa r.cfg.ErrorExclusions, r.restoreOpt.Pgzip, ) } else { - task = restorers.NewTableRestorer(entry, r.st, r.restoreOpt.ExitOnError, r.restoreOpt.Pgzip) + task = restorers.NewTableRestorer( + entry, r.st, r.restoreOpt.ExitOnError, r.restoreOpt.Pgzip, r.restoreOpt.BatchSize, + ) } case toc.SequenceSetDesc: diff --git a/internal/db/postgres/cmd/validate.go b/internal/db/postgres/cmd/validate.go index d756570f..0c93e0e8 100644 --- a/internal/db/postgres/cmd/validate.go +++ b/internal/db/postgres/cmd/validate.go @@ -223,7 +223,7 @@ func (v *Validate) readRecords(r *bufio.Reader, t *entries.Table) (original, tra originalRow = pgcopy.NewRow(len(t.Columns)) transformedRow = pgcopy.NewRow(len(t.Columns)) - originalLine, err = reader.ReadLine(r) + originalLine, err = reader.ReadLine(r, nil) if err != nil { if errors.Is(err, io.EOF) { return nil, nil, err @@ -235,7 +235,7 @@ func (v *Validate) readRecords(r *bufio.Reader, t *entries.Table) (original, tra return nil, nil, io.EOF } - transformedLine, err = reader.ReadLine(r) + transformedLine, err = reader.ReadLine(r, nil) if err != nil { return nil, nil, fmt.Errorf("unable to read line: %w", err) } diff --git a/internal/db/postgres/pgrestore/pgrestore.go b/internal/db/postgres/pgrestore/pgrestore.go index c52895ea..4396054d 100644 --- a/internal/db/postgres/pgrestore/pgrestore.go +++ b/internal/db/postgres/pgrestore/pgrestore.go @@ -97,7 +97,8 @@ type Options struct { Inserts bool `mapstructure:"inserts"` RestoreInOrder bool `mapstructure:"restore-in-order"` // Use pgzip decompression instead of gzip - Pgzip bool `mapstructure:"pgzip"` + Pgzip bool `mapstructure:"pgzip"` + BatchSize int64 `mapstructure:"batch-size"` // Connection options: Host string `mapstructure:"host"` diff --git a/internal/db/postgres/restorers/table.go b/internal/db/postgres/restorers/table.go index 326f75f6..be30718c 100644 --- a/internal/db/postgres/restorers/table.go +++ b/internal/db/postgres/restorers/table.go @@ -15,6 +15,7 @@ package restorers import ( + "bufio" "context" "errors" "fmt" @@ -22,6 +23,7 @@ import ( "github.com/greenmaskio/greenmask/internal/utils/ioutils" "github.com/greenmaskio/greenmask/internal/utils/pgerrors" + "github.com/greenmaskio/greenmask/internal/utils/reader" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgproto3" "github.com/rs/zerolog/log" @@ -37,14 +39,18 @@ type TableRestorer struct { St storages.Storager exitOnError bool usePgzip bool + batchSize int64 } -func NewTableRestorer(entry *toc.Entry, st storages.Storager, exitOnError bool, usePgzip bool) *TableRestorer { +func NewTableRestorer( + entry *toc.Entry, st storages.Storager, exitOnError 
bool, usePgzip bool, batchSize int64, +) *TableRestorer { return &TableRestorer{ Entry: entry, St: st, exitOnError: exitOnError, usePgzip: usePgzip, + batchSize: batchSize, } } @@ -117,8 +123,14 @@ func (td *TableRestorer) restoreCopy(ctx context.Context, f *pgproto3.Frontend, return fmt.Errorf("error initializing pgcopy: %w", err) } - if err := td.streamCopyData(ctx, f, r); err != nil { - return fmt.Errorf("error streaming pgcopy data: %w", err) + if td.batchSize > 0 { + if err := td.streamCopyDataByBatch(ctx, f, r); err != nil { + return fmt.Errorf("error streaming pgcopy data: %w", err) + } + } else { + if err := td.streamCopyData(ctx, f, r); err != nil { + return fmt.Errorf("error streaming pgcopy data: %w", err) + } } if err := td.postStreamingHandle(ctx, f); err != nil { @@ -134,8 +146,7 @@ func (td *TableRestorer) initCopy(ctx context.Context, f *pgproto3.Frontend) err } // Prepare for streaming the pgcopy data - process := true - for process { + for { select { case <-ctx.Done(): return ctx.Err() @@ -148,35 +159,67 @@ func (td *TableRestorer) initCopy(ctx context.Context, f *pgproto3.Frontend) err } switch v := msg.(type) { case *pgproto3.CopyInResponse: - process = false + return nil case *pgproto3.ErrorResponse: return fmt.Errorf("error from postgres connection: %w", pgerrors.NewPgError(v)) default: return fmt.Errorf("unknown message %+v", v) } } - return nil } -func (td *TableRestorer) streamCopyData(ctx context.Context, f *pgproto3.Frontend, r io.Reader) error { - // Streaming pgcopy data from table dump - +// streamCopyDataByBatch - stream pgcopy data from table dump in batches. It handles errors only on the end each batch +// If the batch size is reached it completes the batch and starts a new one. If an error occurs during the batch it +// stops immediately and returns the error +func (td *TableRestorer) streamCopyDataByBatch(ctx context.Context, f *pgproto3.Frontend, r io.Reader) (err error) { + bi := bufio.NewReader(r) buf := make([]byte, DefaultBufferSize) + var lineNum int64 for { - var n int + buf, err = reader.ReadLine(bi, buf) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return fmt.Errorf("error readimg from table dump: %w", err) + } + if isTerminationSeq(buf) { + break + } + lineNum++ + buf = append(buf, '\n') + + err = sendMessage(f, &pgproto3.CopyData{Data: buf}) + if err != nil { + return fmt.Errorf("error sending CopyData message: %w", err) + } + + if lineNum%td.batchSize == 0 { + if err = td.completeBatch(ctx, f); err != nil { + return fmt.Errorf("error completing batch: %w", err) + } + } + select { case <-ctx.Done(): return ctx.Err() default: } + } + return nil +} + +// streamCopyData - stream pgcopy data from table dump in classic way. 
It handles errors only on the end of the stream +func (td *TableRestorer) streamCopyData(ctx context.Context, f *pgproto3.Frontend, r io.Reader) error { + // Streaming pgcopy data from table dump + + buf := make([]byte, DefaultBufferSize) + for { + var n int n, err := r.Read(buf) if err != nil { if errors.Is(err, io.EOF) { - completionErr := sendMessage(f, &pgproto3.CopyDone{}) - if completionErr != nil { - return fmt.Errorf("error sending CopyDone message: %w", err) - } break } return fmt.Errorf("error readimg from table dump: %w", err) @@ -186,12 +229,32 @@ func (td *TableRestorer) streamCopyData(ctx context.Context, f *pgproto3.Fronten if err != nil { return fmt.Errorf("error sending DopyData message: %w", err) } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + } + return nil +} + +// completeBatch - complete batch of pgcopy data and initiate new one +func (td *TableRestorer) completeBatch(ctx context.Context, f *pgproto3.Frontend) error { + if err := td.postStreamingHandle(ctx, f); err != nil { + return err + } + if err := td.initCopy(ctx, f); err != nil { + return err } return nil } func (td *TableRestorer) postStreamingHandle(ctx context.Context, f *pgproto3.Frontend) error { // Perform post streaming handling + err := sendMessage(f, &pgproto3.CopyDone{}) + if err != nil { + return fmt.Errorf("error sending CopyDone message: %w", err) + } var mainErr error for { select { diff --git a/internal/db/postgres/restorers/table_insert_format.go b/internal/db/postgres/restorers/table_insert_format.go index 7d3359cf..5700493e 100644 --- a/internal/db/postgres/restorers/table_insert_format.go +++ b/internal/db/postgres/restorers/table_insert_format.go @@ -140,7 +140,7 @@ func (td *TableRestorerInsertFormat) streamInsertData(ctx context.Context, conn default: } - line, err := reader.ReadLine(buf) + line, err := reader.ReadLine(buf, nil) if err != nil { if errors.Is(err, io.EOF) { break diff --git a/internal/db/postgres/transformers/custom/dynamic_definition.go b/internal/db/postgres/transformers/custom/dynamic_definition.go index a562b810..16c488f1 100644 --- a/internal/db/postgres/transformers/custom/dynamic_definition.go +++ b/internal/db/postgres/transformers/custom/dynamic_definition.go @@ -85,7 +85,7 @@ func GetDynamicTransformerDefinition(ctx context.Context, executable string, arg buf := bufio.NewReader(bytes.NewBuffer(stdoutData)) for { - line, err := reader.ReadLine(buf) + line, err := reader.ReadLine(buf, nil) if err != nil { break } @@ -102,7 +102,7 @@ func GetDynamicTransformerDefinition(ctx context.Context, executable string, arg buf := bufio.NewReader(bytes.NewBuffer(stderrData)) for { - line, err := reader.ReadLine(buf) + line, err := reader.ReadLine(buf, nil) if err != nil { break } diff --git a/internal/db/postgres/transformers/utils/cmd_transformer_base.go b/internal/db/postgres/transformers/utils/cmd_transformer_base.go index 5599e6b3..5207481e 100644 --- a/internal/db/postgres/transformers/utils/cmd_transformer_base.go +++ b/internal/db/postgres/transformers/utils/cmd_transformer_base.go @@ -315,7 +315,7 @@ func (ctb *CmdTransformerBase) init() error { func (ctb *CmdTransformerBase) ReceiveStderrLine(ctx context.Context) (line []byte, err error) { go func() { - line, err = reader.ReadLine(ctb.StderrReader) + line, err = reader.ReadLine(ctb.StderrReader, nil) ctb.receiveChan <- struct{}{} }() select { @@ -333,7 +333,7 @@ func (ctb *CmdTransformerBase) ReceiveStderrLine(ctx context.Context) (line []by func (ctb *CmdTransformerBase) ReceiveStdoutLine(ctx 
context.Context) (line []byte, err error) { go func() { - line, err = reader.ReadLine(ctb.StdoutReader) + line, err = reader.ReadLine(ctb.StdoutReader, nil) ctb.receiveChan <- struct{}{} }() select { diff --git a/internal/utils/cmd_runner/cmd_runner.go b/internal/utils/cmd_runner/cmd_runner.go index ceec1863..3a7021f7 100644 --- a/internal/utils/cmd_runner/cmd_runner.go +++ b/internal/utils/cmd_runner/cmd_runner.go @@ -53,7 +53,7 @@ func Run(ctx context.Context, logger *zerolog.Logger, name string, args ...strin return gtx.Err() default: } - line, err := reader.ReadLine(lineScanner) + line, err := reader.ReadLine(lineScanner, nil) if err != nil { if errors.Is(err, io.EOF) { return nil @@ -73,7 +73,7 @@ func Run(ctx context.Context, logger *zerolog.Logger, name string, args ...strin return gtx.Err() default: } - line, err := reader.ReadLine(lineScanner) + line, err := reader.ReadLine(lineScanner, nil) if err != nil { if errors.Is(err, io.EOF) { return nil diff --git a/internal/utils/reader/reader.go b/internal/utils/reader/reader.go index 59164ee7..d28d24de 100644 --- a/internal/utils/reader/reader.go +++ b/internal/utils/reader/reader.go @@ -5,18 +5,18 @@ import ( "fmt" ) -func ReadLine(r *bufio.Reader) ([]byte, error) { - var res []byte +func ReadLine(r *bufio.Reader, buf []byte) ([]byte, error) { + buf = buf[:0] for { var line []byte line, isPrefix, err := r.ReadLine() if err != nil { return nil, fmt.Errorf("unable to read line: %w", err) } - res = append(res, line...) + buf = append(buf, line...) if !isPrefix { break } } - return res, nil + return buf, nil } From 3f263987bcb7a53e8fc9ddbe04b03a9cd1b91844 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Tue, 20 Aug 2024 14:47:22 +0300 Subject: [PATCH 4/9] fixed: fixed case when overridden type of column does not work --- internal/db/postgres/context/pg_catalog.go | 4 ++++ internal/db/postgres/context/table.go | 6 ++++++ pkg/toolkit/column.go | 7 +++++++ pkg/toolkit/driver.go | 6 +++--- pkg/toolkit/meta.go | 2 +- pkg/toolkit/static_parameter.go | 2 +- 6 files changed, 22 insertions(+), 5 deletions(-) diff --git a/internal/db/postgres/context/pg_catalog.go b/internal/db/postgres/context/pg_catalog.go index c7a68c6a..bf43aed2 100644 --- a/internal/db/postgres/context/pg_catalog.go +++ b/internal/db/postgres/context/pg_catalog.go @@ -154,6 +154,10 @@ func getTables( // Assigning columns, pk and fk for each table for _, t := range tables { + if len(t.Columns) > 0 { + // Columns were already initialized during the transformer initialization + continue + } columns, err := getColumnsConfig(ctx, tx, t.Oid, version) if err != nil { return nil, nil, fmt.Errorf("unable to collect table columns: %w", err) diff --git a/internal/db/postgres/context/table.go b/internal/db/postgres/context/table.go index 5096dfb0..ad14a4b9 100644 --- a/internal/db/postgres/context/table.go +++ b/internal/db/postgres/context/table.go @@ -89,6 +89,12 @@ func validateAndBuildTablesConfig( } table.Columns = columns + pkColumns, err := getPrimaryKeyColumns(ctx, tx, table.Oid) + if err != nil { + return nil, nil, fmt.Errorf("unable to collect primary key columns: %w", err) + } + table.PrimaryKey = pkColumns + // Assigning overridden column types for driver initialization if tableCfg.ColumnsTypeOverride != nil { for _, c := range table.Columns { diff --git a/pkg/toolkit/column.go b/pkg/toolkit/column.go index 51cae52b..ed017c2c 100644 --- a/pkg/toolkit/column.go +++ b/pkg/toolkit/column.go @@ -47,3 +47,10 @@ func (c *Column) GetType() (string, Oid) { } return c.TypeName, 
c.TypeOid } + +func (c *Column) GetTypeOid() Oid { + if c.OverriddenTypeName != "" { + return c.OverriddenTypeOid + } + return c.TypeOid +} diff --git a/pkg/toolkit/driver.go b/pkg/toolkit/driver.go index c55dfbdf..8e89cd88 100644 --- a/pkg/toolkit/driver.go +++ b/pkg/toolkit/driver.go @@ -130,7 +130,7 @@ func (d *Driver) EncodeValueByColumnIdx(idx int, src any, buf []byte) ([]byte, e return nil, fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx) } c := d.Table.Columns[idx] - oid := uint32(c.TypeOid) + oid := uint32(c.GetTypeOid()) if c.OverriddenTypeOid != 0 { oid = uint32(c.OverriddenTypeOid) } @@ -158,7 +158,7 @@ func (d *Driver) ScanValueByColumnIdx(idx int, src []byte, dest any) error { return fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx) } c := d.Table.Columns[idx] - oid := uint32(c.TypeOid) + oid := uint32(c.GetTypeOid()) if c.OverriddenTypeOid != 0 { oid = uint32(c.OverriddenTypeOid) } @@ -189,7 +189,7 @@ func (d *Driver) DecodeValueByColumnIdx(idx int, src []byte) (any, error) { return nil, fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx) } c := d.Table.Columns[idx] - oid := uint32(c.TypeOid) + oid := uint32(c.GetTypeOid()) if c.OverriddenTypeOid != 0 { oid = uint32(c.OverriddenTypeOid) } diff --git a/pkg/toolkit/meta.go b/pkg/toolkit/meta.go index 4ab19739..c5ce5417 100644 --- a/pkg/toolkit/meta.go +++ b/pkg/toolkit/meta.go @@ -18,7 +18,7 @@ type Meta struct { Table *Table `json:"table"` Parameters *Parameters `json:"parameters"` Types []*Type `json:"types"` - ColumnTypeOverrides map[string]string `json:"column_type_overrides"` + ColumnsTypeOverride map[string]string `json:"columns_type_override"` } type Parameters struct { diff --git a/pkg/toolkit/static_parameter.go b/pkg/toolkit/static_parameter.go index a2b4976d..7bc28eab 100644 --- a/pkg/toolkit/static_parameter.go +++ b/pkg/toolkit/static_parameter.go @@ -301,7 +301,7 @@ func scanValue(driver *Driver, definition *ParameterDefinition, rawValue ParamsV var typeOid uint32 if linkedColumnParameter != nil { - typeOid = uint32(linkedColumnParameter.Column.TypeOid) + typeOid = uint32(linkedColumnParameter.Column.GetTypeOid()) } else { t, ok := driver.GetTypeMap().TypeForName(definition.CastDbType) if !ok { From c65349833f108d1e94cf3c0cf9da98c40b77fb93 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Sat, 24 Aug 2024 12:24:27 +0300 Subject: [PATCH 5/9] fix: added error unused decoder setting Now if user provides the unknown key the error will be thrown Closes #176 --- cmd/greenmask/cmd/root.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/greenmask/cmd/root.go b/cmd/greenmask/cmd/root.go index b7f21f44..7c676879 100644 --- a/cmd/greenmask/cmd/root.go +++ b/cmd/greenmask/cmd/root.go @@ -140,6 +140,7 @@ func initConfig() { mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToSliceHookFunc(","), ) + cfg.ErrorUnused = true } if err := viper.Unmarshal(Config, decoderCfg); err != nil { From 69ac47e3dff51143e70884384ed03c675e026db2 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Sat, 24 Aug 2024 13:12:07 +0300 Subject: [PATCH 6/9] fix: added error unused decoder setting Fixed parameters encoding --- internal/domains/config.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/internal/domains/config.go b/internal/domains/config.go index 7737f51d..ada2d446 100644 --- a/internal/domains/config.go +++ b/internal/domains/config.go @@ -132,7 +132,11 @@ type TransformerConfig struct { 
// This cannot be parsed with mapstructure due to uncontrollable lowercasing // https://github.com/spf13/viper/issues/373 // Instead we have to use workaround and parse it manually - Params toolkit.StaticParameters `mapstructure:"-" yaml:"-" json:"-"` // yaml:"params" json:"params,omitempty"` + // + // Params attribute decoding is dummy. It is replaced in the runtime internal/utils/config/viper_workaround.go + // But it is required to leave mapstruicture tag to avoid errors raised by viper and decoder setting + // ErrorUnused = true. It was set in PR #177 (https://github.com/GreenmaskIO/greenmask/pull/177/files) + Params toolkit.StaticParameters `mapstructure:"params" yaml:"params" json:"params"` // MetadataParams - encoded transformer parameters - uses only for storing into storage // TODO: You need to get rid of it by creating a separate structure for storing metadata in // internal/db/postgres/storage/metadata_json.go From 06d6976acba2f3d280442a5207549c824ba42a20 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Sat, 24 Aug 2024 15:12:11 +0300 Subject: [PATCH 7/9] fix: fixed transformers threshold validation * fixed case when the min and max thresholds provided were ignored --- .../postgres/transformers/noise_date_test.go | 16 +++++ .../db/postgres/transformers/noise_float.go | 45 ++++++++------ .../postgres/transformers/noise_float_test.go | 13 +++++ .../db/postgres/transformers/noise_int.go | 24 ++++++-- .../db/postgres/transformers/noise_numeric.go | 47 ++++++++++----- .../transformers/noise_numeric_test.go | 13 +++++ .../db/postgres/transformers/random_float.go | 20 +++++-- .../transformers/random_float_test.go | 15 +++++ .../db/postgres/transformers/random_int.go | 41 ++++++++----- .../postgres/transformers/random_int_test.go | 58 +++++++++++++++++-- .../postgres/transformers/random_numeric.go | 40 +++++++++---- .../transformers/random_numeric_test.go | 14 +++++ 12 files changed, 274 insertions(+), 72 deletions(-) diff --git a/internal/db/postgres/transformers/noise_date_test.go b/internal/db/postgres/transformers/noise_date_test.go index 0d81fdf0..af84f361 100644 --- a/internal/db/postgres/transformers/noise_date_test.go +++ b/internal/db/postgres/transformers/noise_date_test.go @@ -93,6 +93,22 @@ func TestNoiseDateTransformer_Transform(t *testing.T) { max: time.Date(2024, 8, 29, 1, 1, 1, 1000, loc), }, }, + { + name: "test timestamp type with Truncate till day", + params: map[string]toolkit.ParamsValue{ + "max_ratio": toolkit.ParamsValue("1 year 1 mons 1 day 01:01:01.01"), + "truncate": toolkit.ParamsValue("month"), + "column": toolkit.ParamsValue("date_ts"), + "min": toolkit.ParamsValue("2022-06-01 22:00:00"), + "max": toolkit.ParamsValue("2024-01-29 01:01:01.001"), + }, + original: "2023-06-25 00:00:00", + result: result{ + pattern: `^\d{4}-\d{2}-01 0{2}:0{2}:0{2}$`, + min: time.Date(2022, 3, 1, 22, 00, 0, 0, loc), + max: time.Date(2024, 8, 29, 1, 1, 1, 1000, loc), + }, + }, } for _, tt := range tests { diff --git a/internal/db/postgres/transformers/noise_float.go b/internal/db/postgres/transformers/noise_float.go index d55e5067..9c8be144 100644 --- a/internal/db/postgres/transformers/noise_float.go +++ b/internal/db/postgres/transformers/noise_float.go @@ -17,7 +17,6 @@ package transformers import ( "context" "fmt" - "math" "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" "github.com/greenmaskio/greenmask/internal/generators/transformers" @@ -97,7 +96,8 @@ func NewNoiseFloatTransformer(ctx context.Context, driver *toolkit.Driver, param var 
columnName, engine string var dynamicMode bool - var minValueThreshold, maxValueThreshold, minRatio, maxRatio float64 + var minValueThreshold, maxValueThreshold *float64 + var minRatio, maxRatio float64 var decimal int columnParam := parameters["column"] @@ -129,11 +129,23 @@ func NewNoiseFloatTransformer(ctx context.Context, driver *toolkit.Driver, param floatSize := c.GetColumnSize() if !dynamicMode { - if err := minParam.Scan(&minValueThreshold); err != nil { - return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + minIsEmpty, err := minParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"min\" parameter: %w", err) + } + if !minIsEmpty { + if err = minParam.Scan(&minValueThreshold); err != nil { + return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + } } - if err := maxParam.Scan(&maxValueThreshold); err != nil { - return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + maxIsEmpty, err := maxParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"max\" parameter: %w", err) + } + if !maxIsEmpty { + if err = maxParam.Scan(&maxValueThreshold); err != nil { + return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + } } } @@ -244,15 +256,21 @@ func (nft *NoiseFloatTransformer) Transform(ctx context.Context, r *toolkit.Reco } func validateNoiseFloatTypeAndSetLimit( - size int, requestedMinValue, requestedMaxValue float64, decimal int, + size int, requestedMinValue, requestedMaxValue *float64, decimal int, ) (limiter *transformers.NoiseFloat64Limiter, warns toolkit.ValidationWarnings, err error) { minValue, maxValue, err := getFloatThresholds(size) if err != nil { return nil, nil, err } + if requestedMinValue == nil { + requestedMinValue = &minValue + } + if requestedMaxValue == nil { + requestedMaxValue = &maxValue + } - if !limitIsValid(requestedMinValue, minValue, maxValue) { + if !limitIsValid(*requestedMinValue, minValue, maxValue) { warns = append(warns, toolkit.NewValidationWarning(). SetMsgf("requested min value is out of float%d range", size). SetSeverity(toolkit.ErrorValidationSeverity). @@ -263,7 +281,7 @@ func validateNoiseFloatTypeAndSetLimit( ) } - if !limitIsValid(requestedMaxValue, minValue, maxValue) { + if !limitIsValid(*requestedMaxValue, minValue, maxValue) { warns = append(warns, toolkit.NewValidationWarning(). SetMsgf("requested max value is out of float%d range", size). SetSeverity(toolkit.ErrorValidationSeverity). 
@@ -278,18 +296,11 @@ func validateNoiseFloatTypeAndSetLimit( return nil, warns, nil } - limiter, err = transformers.NewNoiseFloat64Limiter(-math.MaxFloat64, math.MaxFloat64, decimal) + limiter, err = transformers.NewNoiseFloat64Limiter(*requestedMinValue, *requestedMaxValue, decimal) if err != nil { return nil, nil, err } - if requestedMinValue != 0 || requestedMaxValue != 0 { - limiter, err = transformers.NewNoiseFloat64Limiter(requestedMinValue, requestedMaxValue, decimal) - if err != nil { - return nil, nil, err - } - } - return limiter, nil, nil } diff --git a/internal/db/postgres/transformers/noise_float_test.go b/internal/db/postgres/transformers/noise_float_test.go index 443df76a..9024d1a2 100644 --- a/internal/db/postgres/transformers/noise_float_test.go +++ b/internal/db/postgres/transformers/noise_float_test.go @@ -107,6 +107,19 @@ func TestNoiseFloatTransformer_Transform(t *testing.T) { input: "100", result: result{min: 90, max: 110, regexp: `^-*\d+$`}, }, + { + name: "with thresholds min zero", + columnName: "col_float8", + params: map[string]toolkit.ParamsValue{ + "min_ratio": toolkit.ParamsValue("0.2"), + "max_ratio": toolkit.ParamsValue("0.9"), + "min": toolkit.ParamsValue("0"), + "max": toolkit.ParamsValue("110"), + "decimal": toolkit.ParamsValue("0"), + }, + input: "100", + result: result{min: 0, max: 110, regexp: `^-*\d+$`}, + }, } for _, tt := range tests { diff --git a/internal/db/postgres/transformers/noise_int.go b/internal/db/postgres/transformers/noise_int.go index 8cc6474b..91d5cd9c 100644 --- a/internal/db/postgres/transformers/noise_int.go +++ b/internal/db/postgres/transformers/noise_int.go @@ -86,7 +86,7 @@ type NoiseIntTransformer struct { func NewNoiseIntTransformer(ctx context.Context, driver *toolkit.Driver, parameters map[string]toolkit.Parameterizer) (utils.Transformer, toolkit.ValidationWarnings, error) { var columnName, engine string var minRatio, maxRatio float64 - var maxValueThreshold, minValueThreshold int64 + var maxValueThreshold, minValueThreshold *int64 var dynamicMode bool columnParam := parameters["column"] @@ -118,11 +118,23 @@ func NewNoiseIntTransformer(ctx context.Context, driver *toolkit.Driver, paramet } if !dynamicMode { - if err := minParam.Scan(&maxValueThreshold); err != nil { - return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + minIsEmpty, err := minParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"min\" parameter: %w", err) + } + if !minIsEmpty { + if err = minParam.Scan(&minValueThreshold); err != nil { + return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + } + } + maxIsEmpty, err := maxParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"max\" parameter: %w", err) } - if err := maxParam.Scan(&minValueThreshold); err != nil { - return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + if !maxIsEmpty { + if err = maxParam.Scan(&maxValueThreshold); err != nil { + return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + } } } @@ -227,7 +239,7 @@ func (nit *NoiseIntTransformer) Transform(ctx context.Context, r *toolkit.Record } func validateIntTypeAndSetNoiseInt64Limiter( - size int, requestedMinValue, requestedMaxValue int64, + size int, requestedMinValue, requestedMaxValue *int64, ) (limiter *transformers.NoiseInt64Limiter, warns toolkit.ValidationWarnings, err error) { minValue, maxValue, warns, err := validateInt64AndGetLimits(size, requestedMinValue, requestedMaxValue) diff --git 
a/internal/db/postgres/transformers/noise_numeric.go b/internal/db/postgres/transformers/noise_numeric.go index b57e16d9..7fd14f6e 100644 --- a/internal/db/postgres/transformers/noise_numeric.go +++ b/internal/db/postgres/transformers/noise_numeric.go @@ -118,7 +118,7 @@ func NewNumericFloatTransformer(ctx context.Context, driver *toolkit.Driver, par var columnName, engine string var dynamicMode bool var minRatio, maxRatio float64 - var minValueThreshold, maxValueThreshold decimal.Decimal + var minValueThreshold, maxValueThreshold *decimal.Decimal var precision int32 columnParam := parameters["column"] @@ -149,14 +149,34 @@ func NewNumericFloatTransformer(ctx context.Context, driver *toolkit.Driver, par affectedColumns[idx] = columnName if !dynamicMode { - if err := minParam.Scan(&minValueThreshold); err != nil { - return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + minIsEmpty, err := minParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"min\" parameter: %w", err) + } + if !minIsEmpty { + if err = minParam.Scan(&minValueThreshold); err != nil { + return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + } + } + maxIsEmpty, err := maxParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"max\" parameter: %w", err) } - if err := maxParam.Scan(&maxValueThreshold); err != nil { - return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + if !maxIsEmpty { + if err = maxParam.Scan(&maxValueThreshold); err != nil { + return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + } } } + limiter, limitsWarnings, err := validateNoiseNumericTypeAndSetLimit(bigIntegerTransformerGenByteLength, minValueThreshold, maxValueThreshold) + if err != nil { + return nil, nil, err + } + if limitsWarnings.IsFatal() { + return nil, limitsWarnings, nil + } + if err := decimalParam.Scan(&precision); err != nil { return nil, nil, fmt.Errorf(`unable to scan "decimal" param: %w`, err) } @@ -169,13 +189,6 @@ func NewNumericFloatTransformer(ctx context.Context, driver *toolkit.Driver, par return nil, nil, fmt.Errorf("unable to scan \"max_ratio\" param: %w", err) } - limiter, limitsWarnings, err := validateNoiseNumericTypeAndSetLimit(bigIntegerTransformerGenByteLength, minValueThreshold, maxValueThreshold) - if err != nil { - return nil, nil, err - } - if limitsWarnings.IsFatal() { - return nil, limitsWarnings, nil - } limiter.SetPrecision(precision) t := transformers.NewNoiseNumericTransformer(limiter, minRatio, maxRatio) @@ -269,7 +282,7 @@ func (nft *NoiseNumericTransformer) Transform(ctx context.Context, r *toolkit.Re } func validateNoiseNumericTypeAndSetLimit( - size int, requestedMinValue, requestedMaxValue decimal.Decimal, + size int, requestedMinValue, requestedMaxValue *decimal.Decimal, ) (limiter *transformers.NoiseNumericLimiter, warns toolkit.ValidationWarnings, err error) { minVal, maxVal, warns, err := getNumericThresholds(size, requestedMinValue, requestedMaxValue) @@ -279,8 +292,14 @@ func validateNoiseNumericTypeAndSetLimit( if warns.IsFatal() { return nil, warns, nil } + if requestedMinValue == nil { + requestedMinValue = &minVal + } + if requestedMaxValue == nil { + requestedMaxValue = &maxVal + } - limiter, err = transformers.NewNoiseNumericLimiter(minVal, maxVal) + limiter, err = transformers.NewNoiseNumericLimiter(*requestedMinValue, *requestedMaxValue) if err != nil { return nil, nil, fmt.Errorf("error creating limiter by size: %w", err) } diff --git 
a/internal/db/postgres/transformers/noise_numeric_test.go b/internal/db/postgres/transformers/noise_numeric_test.go index 86ba4ced..61600277 100644 --- a/internal/db/postgres/transformers/noise_numeric_test.go +++ b/internal/db/postgres/transformers/noise_numeric_test.go @@ -97,6 +97,19 @@ func TestNoiseNumericTransformer_Transform(t *testing.T) { input: "100", result: result{min: 90, max: 110, regexp: `^-*\d+$`}, }, + { + name: "numeric with thresholds", + columnName: "id_numeric", + params: map[string]toolkit.ParamsValue{ + "min_ratio": toolkit.ParamsValue("0.2"), + "max_ratio": toolkit.ParamsValue("0.9"), + "min": toolkit.ParamsValue("0"), + "max": toolkit.ParamsValue("50"), + "decimal": toolkit.ParamsValue("4"), + }, + input: "100", + result: result{min: 10, max: 190, regexp: `^-*\d+[.]*\d{0,4}$`}, + }, } for _, tt := range tests { diff --git a/internal/db/postgres/transformers/random_float.go b/internal/db/postgres/transformers/random_float.go index 366a6b67..b6cdd693 100644 --- a/internal/db/postgres/transformers/random_float.go +++ b/internal/db/postgres/transformers/random_float.go @@ -138,11 +138,23 @@ func NewFloatTransformer(ctx context.Context, driver *toolkit.Driver, parameters } if !dynamicMode { - if err := minParam.Scan(&minVal); err != nil { - return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + minIsEmpty, err := minParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"min\" parameter: %w", err) + } + if !minIsEmpty { + if err = minParam.Scan(&minVal); err != nil { + return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + } + } + maxIsEmpty, err := maxParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"max\" parameter: %w", err) } - if err := maxParam.Scan(&maxVal); err != nil { - return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + if !maxIsEmpty { + if err = maxParam.Scan(&maxVal); err != nil { + return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + } } } diff --git a/internal/db/postgres/transformers/random_float_test.go b/internal/db/postgres/transformers/random_float_test.go index a519366e..dfd7cb7a 100644 --- a/internal/db/postgres/transformers/random_float_test.go +++ b/internal/db/postgres/transformers/random_float_test.go @@ -121,6 +121,21 @@ func TestRandomFloatTransformer_Transform(t *testing.T) { isNull: true, }, }, + { + name: "keep_null true and NULL seq", + columnName: "col_float8", + originalValue: "\\N", + params: map[string]toolkit.ParamsValue{ + "min": toolkit.ParamsValue("0"), + "max": toolkit.ParamsValue("1000"), + "decimal": toolkit.ParamsValue("0"), + "keep_null": toolkit.ParamsValue("false"), + }, + result: result{ + min: 0, + max: 1000, + }, + }, //{ // name: "text with default float8", // params: map[string]toolkit.ParamsValue{ diff --git a/internal/db/postgres/transformers/random_int.go b/internal/db/postgres/transformers/random_int.go index 9d5b7e30..5982edfb 100644 --- a/internal/db/postgres/transformers/random_int.go +++ b/internal/db/postgres/transformers/random_int.go @@ -93,7 +93,7 @@ type IntegerTransformer struct { func NewIntegerTransformer(ctx context.Context, driver *toolkit.Driver, parameters map[string]toolkit.Parameterizer) (utils.Transformer, toolkit.ValidationWarnings, error) { var columnName, engine string - var minVal, maxVal int64 + var minVal, maxVal *int64 var keepNull, dynamicMode bool columnParam := parameters["column"] @@ -127,11 +127,23 @@ func NewIntegerTransformer(ctx 
context.Context, driver *toolkit.Driver, paramete } if !dynamicMode { - if err := minParam.Scan(&minVal); err != nil { - return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + minIsEmpty, err := minParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"min\" parameter: %w", err) + } + if !minIsEmpty { + if err = minParam.Scan(&minVal); err != nil { + return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + } } - if err := maxParam.Scan(&maxVal); err != nil { - return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + maxIsEmpty, err := maxParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"max\" parameter: %w", err) + } + if !maxIsEmpty { + if err = maxParam.Scan(&maxVal); err != nil { + return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + } } } @@ -281,7 +293,7 @@ func limitIsValid[T int64 | float64](requestedThreshold, minValue, maxValue T) b } func validateIntTypeAndSetRandomInt64Limiter( - size int, requestedMinValue, requestedMaxValue int64, + size int, requestedMinValue, requestedMaxValue *int64, ) (limiter *transformers.Int64Limiter, warns toolkit.ValidationWarnings, err error) { minValue, maxValue, warns, err := validateInt64AndGetLimits(size, requestedMinValue, requestedMaxValue) @@ -299,15 +311,21 @@ func validateIntTypeAndSetRandomInt64Limiter( } func validateInt64AndGetLimits( - size int, requestedMinValue, requestedMaxValue int64, + size int, requestedMinValue, requestedMaxValue *int64, ) (int64, int64, toolkit.ValidationWarnings, error) { var warns toolkit.ValidationWarnings minValue, maxValue, err := getIntThresholds(size) if err != nil { return 0, 0, nil, err } + if requestedMinValue == nil { + requestedMinValue = &minValue + } + if requestedMaxValue == nil { + requestedMaxValue = &maxValue + } - if !limitIsValid(requestedMinValue, minValue, maxValue) { + if !limitIsValid(*requestedMinValue, minValue, maxValue) { warns = append(warns, toolkit.NewValidationWarning(). SetMsgf("requested min value is out of int%d range", size). SetSeverity(toolkit.ErrorValidationSeverity). @@ -318,7 +336,7 @@ func validateInt64AndGetLimits( ) } - if !limitIsValid(requestedMaxValue, minValue, maxValue) { + if !limitIsValid(*requestedMaxValue, minValue, maxValue) { warns = append(warns, toolkit.NewValidationWarning(). SetMsgf("requested max value is out of int%d range", size). SetSeverity(toolkit.ErrorValidationSeverity). 
@@ -332,11 +350,8 @@ func validateInt64AndGetLimits( if warns.IsFatal() { return 0, 0, warns, nil } - if requestedMinValue != 0 || requestedMaxValue != 0 { - return requestedMinValue, requestedMaxValue, nil, nil - } - return minValue, maxValue, nil, nil + return *requestedMinValue, *requestedMaxValue, nil, nil } func init() { diff --git a/internal/db/postgres/transformers/random_int_test.go b/internal/db/postgres/transformers/random_int_test.go index 5352df9d..e009c84e 100644 --- a/internal/db/postgres/transformers/random_int_test.go +++ b/internal/db/postgres/transformers/random_int_test.go @@ -13,12 +13,19 @@ import ( func TestRandomIntTransformer_Transform_random_static(t *testing.T) { + type expected struct { + min int64 + max int64 + isNull bool + } + tests := []struct { name string columnName string originalValue string expectedRegexp string params map[string]toolkit.ParamsValue + expected expected }{ { name: "int2", @@ -29,6 +36,10 @@ func TestRandomIntTransformer_Transform_random_static(t *testing.T) { "min": toolkit.ParamsValue("1"), "max": toolkit.ParamsValue("100"), }, + expected: expected{ + min: 1, + max: 100, + }, }, { name: "int4", @@ -39,6 +50,10 @@ func TestRandomIntTransformer_Transform_random_static(t *testing.T) { "min": toolkit.ParamsValue("1"), "max": toolkit.ParamsValue("100"), }, + expected: expected{ + min: 1, + max: 100, + }, }, { name: "int8", @@ -49,6 +64,10 @@ func TestRandomIntTransformer_Transform_random_static(t *testing.T) { "min": toolkit.ParamsValue("1"), "max": toolkit.ParamsValue("100"), }, + expected: expected{ + min: 1, + max: 100, + }, }, { name: "keep_null false and NULL seq", @@ -60,6 +79,10 @@ func TestRandomIntTransformer_Transform_random_static(t *testing.T) { "max": toolkit.ParamsValue("100"), "keep_null": toolkit.ParamsValue("false"), }, + expected: expected{ + min: 1, + max: 100, + }, }, { name: "keep_null true and NULL seq", @@ -72,6 +95,27 @@ func TestRandomIntTransformer_Transform_random_static(t *testing.T) { "keep_null": toolkit.ParamsValue("true"), "engine": toolkit.ParamsValue("random"), }, + expected: expected{ + min: 1, + max: 100, + isNull: true, + }, + }, + { + name: "test zero min", + columnName: "id8", + originalValue: "\\N", + expectedRegexp: fmt.Sprintf(`^(\%s)$`, "\\N"), + params: map[string]toolkit.ParamsValue{ + "min": toolkit.ParamsValue("0"), + "max": toolkit.ParamsValue("100"), + "engine": toolkit.ParamsValue("random"), + "keep_null": toolkit.ParamsValue("false"), + }, + expected: expected{ + min: 0, + max: 100, + }, }, } @@ -96,12 +140,14 @@ func TestRandomIntTransformer_Transform_random_static(t *testing.T) { record, ) require.NoError(t, err) - - encoded, err := r.Encode() - require.NoError(t, err) - res, err := encoded.Encode() - require.NoError(t, err) - require.Regexp(t, tt.expectedRegexp, string(res)) + rawData, _ := r.GetRawColumnValueByName(tt.columnName) + require.Equal(t, tt.expected.isNull, rawData.IsNull) + if !rawData.IsNull { + var resInt int64 + _, err = r.ScanColumnValueByName(tt.columnName, &resInt) + require.NoError(t, err) + require.True(t, resInt >= tt.expected.min && resInt <= tt.expected.max) + } }) } diff --git a/internal/db/postgres/transformers/random_numeric.go b/internal/db/postgres/transformers/random_numeric.go index 7199c6e8..cb73d95c 100644 --- a/internal/db/postgres/transformers/random_numeric.go +++ b/internal/db/postgres/transformers/random_numeric.go @@ -106,7 +106,7 @@ type NumericTransformer struct { func NewRandomNumericTransformer(ctx context.Context, driver *toolkit.Driver, 
parameters map[string]toolkit.Parameterizer) (utils.Transformer, toolkit.ValidationWarnings, error) { var columnName, engine string - var minVal, maxVal decimal.Decimal + var minVal, maxVal *decimal.Decimal var keepNull, dynamicMode bool var precision int32 @@ -144,12 +144,25 @@ func NewRandomNumericTransformer(ctx context.Context, driver *toolkit.Driver, pa } if !dynamicMode { - if err := minParam.Scan(&minVal); err != nil { - return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + minIsEmpty, err := minParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"min\" parameter: %w", err) } - if err := maxParam.Scan(&maxVal); err != nil { - return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + if !minIsEmpty { + if err = minParam.Scan(&minVal); err != nil { + return nil, nil, fmt.Errorf("error scanning \"min\" parameter: %w", err) + } + } + maxIsEmpty, err := maxParam.IsEmpty() + if err != nil { + return nil, nil, fmt.Errorf("error checking \"max\" parameter: %w", err) } + if !maxIsEmpty { + if err = maxParam.Scan(&maxVal); err != nil { + return nil, nil, fmt.Errorf("error scanning \"max\" parameter: %w", err) + } + } + } limiter, limitsWarnings, err := validateRandomNumericTypeAndSetLimit(bigIntegerTransformerGenByteLength, minVal, maxVal) @@ -250,7 +263,7 @@ func (bit *NumericTransformer) Transform(ctx context.Context, r *toolkit.Record) return r, nil } -func getNumericThresholds(size int, requestedMinValue, requestedMaxValue decimal.Decimal, +func getNumericThresholds(size int, requestedMinValue, requestedMaxValue *decimal.Decimal, ) (decimal.Decimal, decimal.Decimal, toolkit.ValidationWarnings, error) { var warns toolkit.ValidationWarnings minVal, maxVal, err := transformers.GetMinAndMaxNumericValueBySetting(size) @@ -258,11 +271,14 @@ func getNumericThresholds(size int, requestedMinValue, requestedMaxValue decimal return decimal.Decimal{}, decimal.Decimal{}, nil, fmt.Errorf("error creating limiter by size: %w", err) } - if requestedMinValue.Equal(decimal.NewFromInt(0)) && requestedMinValue.Equal(decimal.NewFromInt(0)) { - return minVal, maxVal, nil, nil + if requestedMinValue == nil { + requestedMinValue = &minVal + } + if requestedMaxValue == nil { + requestedMaxValue = &maxVal } - if !numericLimitIsValid(requestedMinValue, minVal, maxVal) { + if !numericLimitIsValid(*requestedMinValue, minVal, maxVal) { warns = append(warns, toolkit.NewValidationWarning(). SetMsgf("requested min value is out of numeric(%d) range", size). SetSeverity(toolkit.ErrorValidationSeverity). @@ -273,7 +289,7 @@ func getNumericThresholds(size int, requestedMinValue, requestedMaxValue decimal ) } - if !numericLimitIsValid(requestedMaxValue, minVal, maxVal) { + if !numericLimitIsValid(*requestedMaxValue, minVal, maxVal) { warns = append(warns, toolkit.NewValidationWarning(). SetMsgf("requested max value is out of NEMERIC(%d) range", size). SetSeverity(toolkit.ErrorValidationSeverity). 
@@ -286,11 +302,11 @@ func getNumericThresholds(size int, requestedMinValue, requestedMaxValue decimal if warns.IsFatal() { return decimal.Decimal{}, decimal.Decimal{}, warns, nil } - return requestedMinValue, requestedMaxValue, nil, nil + return *requestedMinValue, *requestedMaxValue, nil, nil } func validateRandomNumericTypeAndSetLimit( - size int, requestedMinValue, requestedMaxValue decimal.Decimal, + size int, requestedMinValue, requestedMaxValue *decimal.Decimal, ) (limiter *transformers.RandomNumericLimiter, warns toolkit.ValidationWarnings, err error) { minVal, maxVal, warns, err := getNumericThresholds(size, requestedMinValue, requestedMaxValue) diff --git a/internal/db/postgres/transformers/random_numeric_test.go b/internal/db/postgres/transformers/random_numeric_test.go index 429fe86e..56d87d19 100644 --- a/internal/db/postgres/transformers/random_numeric_test.go +++ b/internal/db/postgres/transformers/random_numeric_test.go @@ -67,6 +67,20 @@ func TestBigIntTransformer_Transform_random_static(t *testing.T) { isNull: true, }, }, + { + name: "Regression for implicitly set threshold", + columnName: "id_numeric", + originalValue: "12345", + params: map[string]toolkit.ParamsValue{ + "min": toolkit.ParamsValue("0.0"), + "max": toolkit.ParamsValue("10.0"), + "decimal": toolkit.ParamsValue("2"), + }, + expected: expected{ + min: decimal.RequireFromString("0.0"), + max: decimal.RequireFromString("10.0"), + }, + }, } for _, tt := range tests { From 6717ad20eb54845079c8a0da74c35fecf53eb778 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Sat, 24 Aug 2024 15:23:33 +0300 Subject: [PATCH 8/9] fix: fixed NoiseInt transformer. Added test --- .../db/postgres/transformers/noise_int.go | 16 ++--- .../postgres/transformers/noise_int_test.go | 58 ++++++++++++++----- 2 files changed, 51 insertions(+), 23 deletions(-) diff --git a/internal/db/postgres/transformers/noise_int.go b/internal/db/postgres/transformers/noise_int.go index 91d5cd9c..0339b5d7 100644 --- a/internal/db/postgres/transformers/noise_int.go +++ b/internal/db/postgres/transformers/noise_int.go @@ -138,6 +138,14 @@ func NewNoiseIntTransformer(ctx context.Context, driver *toolkit.Driver, paramet } } + limiter, limitsWarnings, err := validateIntTypeAndSetNoiseInt64Limiter(intSize, minValueThreshold, maxValueThreshold) + if err != nil { + return nil, nil, err + } + if limitsWarnings.IsFatal() { + return nil, limitsWarnings, nil + } + if err := minRatioParam.Scan(&minRatio); err != nil { return nil, nil, fmt.Errorf("unable to scan \"min_ratio\" param: %w", err) } @@ -146,14 +154,6 @@ func NewNoiseIntTransformer(ctx context.Context, driver *toolkit.Driver, paramet return nil, nil, fmt.Errorf("unable to scan \"max_ratio\" param: %w", err) } - limiter, limitsWarnings, err := validateIntTypeAndSetNoiseInt64Limiter(intSize, maxValueThreshold, minValueThreshold) - if err != nil { - return nil, nil, err - } - if limitsWarnings.IsFatal() { - return nil, limitsWarnings, nil - } - t, err := transformers.NewNoiseInt64Transformer(limiter, minRatio, maxRatio) if err != nil { return nil, nil, fmt.Errorf("error initializing common int transformer: %w", err) diff --git a/internal/db/postgres/transformers/noise_int_test.go b/internal/db/postgres/transformers/noise_int_test.go index e9d49840..012ffadd 100644 --- a/internal/db/postgres/transformers/noise_int_test.go +++ b/internal/db/postgres/transformers/noise_int_test.go @@ -16,7 +16,6 @@ package transformers import ( "context" - "fmt" "testing" "github.com/rs/zerolog/log" @@ -38,43 +37,72 @@ func 
TestNoiseIntTransformer_Transform(t *testing.T) { // Positive cases tests := []struct { name string - ratio float64 columnName string + params map[string]toolkit.ParamsValue originalValue string result result }{ { - name: "int2", - columnName: "id2", - ratio: 0.9, + name: "int2", + columnName: "id2", + params: map[string]toolkit.ParamsValue{ + "min_ratio": toolkit.ParamsValue("0.2"), + "max_ratio": toolkit.ParamsValue("0.9"), + }, result: result{min: 12, max: 234}, originalValue: "123", }, { - name: "int4", - columnName: "id4", - ratio: 0.9, + name: "int4", + columnName: "id4", + params: map[string]toolkit.ParamsValue{ + "min_ratio": toolkit.ParamsValue("0.2"), + "max_ratio": toolkit.ParamsValue("0.9"), + }, result: result{min: 12, max: 234}, originalValue: "123", }, { - name: "int8", - columnName: "id8", - ratio: 0.9, + name: "int8", + columnName: "id8", + params: map[string]toolkit.ParamsValue{ + "min_ratio": toolkit.ParamsValue("0.2"), + "max_ratio": toolkit.ParamsValue("0.9"), + }, result: result{min: 12, max: 234}, originalValue: "123", }, + { + name: "int8", + columnName: "id8", + params: map[string]toolkit.ParamsValue{ + "min_ratio": toolkit.ParamsValue("0.2"), + "max_ratio": toolkit.ParamsValue("0.9"), + }, + result: result{min: 12, max: 234}, + originalValue: "123", + }, + { + name: "int8", + columnName: "id8", + params: map[string]toolkit.ParamsValue{ + "min_ratio": toolkit.ParamsValue("0.2"), + "max_ratio": toolkit.ParamsValue("0.9"), + "min": toolkit.ParamsValue("0"), + "max": toolkit.ParamsValue("110"), + }, + result: result{min: 0, max: 110}, + originalValue: "123", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + tt.params["column"] = toolkit.ParamsValue(tt.columnName) driver, record := getDriverAndRecord(tt.columnName, tt.originalValue) transformerCtx, warnings, err := NoiseIntTransformerDefinition.Instance( context.Background(), - driver, map[string]toolkit.ParamsValue{ - "column": toolkit.ParamsValue(tt.columnName), - "min_ratio": toolkit.ParamsValue(fmt.Sprintf("%f", tt.ratio)), - }, + driver, tt.params, nil, ) require.NoError(t, err) From c9a055df44587b656b89a591d03d4bef6eb15cf5 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Sat, 24 Aug 2024 16:49:20 +0300 Subject: [PATCH 9/9] fix: fixed COPY restoration statement string (added semicolon and new line) Closes #179 --- internal/db/postgres/entries/table.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/db/postgres/entries/table.go b/internal/db/postgres/entries/table.go index bc389a40..4eef88ad 100644 --- a/internal/db/postgres/entries/table.go +++ b/internal/db/postgres/entries/table.go @@ -84,7 +84,7 @@ func (t *Table) Entry() (*toc.Entry, error) { } } - var query = `COPY "%s"."%s" (%s) FROM stdin` + var query = "COPY \"%s\".\"%s\" (%s) FROM stdin;\n" var schemaName, tableName string if t.LoadViaPartitionRoot && t.RootPtSchema != "" && t.RootPtName != "" { schemaName = t.RootPtSchema
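As a quick illustration of the fix above, the snippet below (with made-up schema, table and column names) prints the statement the patched format string now produces, including the terminating semicolon and newline:

```go
package main

import "fmt"

func main() {
	// Format string as patched in Entry(): the trailing ";\n" makes the
	// emitted COPY command a complete, terminated SQL statement.
	const query = "COPY \"%s\".\"%s\" (%s) FROM stdin;\n"
	// Hypothetical identifiers, used only to show the rendered output.
	fmt.Printf(query, "public", "users", `"id", "email"`)
	// Prints: COPY "public"."users" ("id", "email") FROM stdin;
}
```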