From de3064bc194e18377a9e0b1a865e8dfb0fe429f4 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Sun, 10 Mar 2024 22:22:25 +0200 Subject: [PATCH 1/3] Previous dump schema diff Implemented a schema diff against the previous dump, which is performed when the `--schema` option is set --- cmd/greenmask/cmd/validate/validate.go | 23 ++- internal/db/postgres/cmd/dump.go | 1 + internal/db/postgres/cmd/validate.go | 175 +++++++++++++--- internal/db/postgres/context/context.go | 8 + internal/db/postgres/context/pg_catalog.go | 78 +++++++ .../db/postgres/context/pg_catalog_test.go | 25 ++- internal/db/postgres/context/schema.go | 52 +++++ internal/db/postgres/storage/metadata_json.go | 20 +- internal/domains/config.go | 1 + pkg/toolkit/database_schema.go | 191 ++++++++++++++++++ pkg/toolkit/table.go | 3 + 11 files changed, 531 insertions(+), 46 deletions(-) create mode 100644 internal/db/postgres/context/schema.go create mode 100644 pkg/toolkit/database_schema.go diff --git a/cmd/greenmask/cmd/validate/validate.go b/cmd/greenmask/cmd/validate/validate.go index 2b9fbcd2..aa73f63e 100644 --- a/cmd/greenmask/cmd/validate/validate.go +++ b/cmd/greenmask/cmd/validate/validate.go @@ -16,6 +16,7 @@ package validate import ( "context" + "os" "github.com/rs/zerolog/log" "github.com/spf13/cobra" @@ -24,6 +25,7 @@ import ( cmdInternals "github.com/greenmaskio/greenmask/internal/db/postgres/cmd" "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" "github.com/greenmaskio/greenmask/internal/domains" + "github.com/greenmaskio/greenmask/internal/storages/builder" "github.com/greenmaskio/greenmask/internal/utils/logger" ) @@ -66,15 +68,23 @@ func run(cmd *cobra.Command, args []string) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + st, err := builder.GetStorage(ctx, &Config.Storage, &Config.Log) + if err != nil { + log.Fatal().Err(err).Msg("fatal") + } - validate, err := cmdInternals.NewValidate(Config, utils.DefaultTransformerRegistry) + validate, err := 
cmdInternals.NewValidate(Config, utils.DefaultTransformerRegistry, st) if err != nil { log.Fatal().Err(err).Msg("") } - if err := validate.Run(ctx); err != nil { + exitCode, err := validate.Run(ctx) + if err != nil { log.Fatal().Err(err).Msg("") } + if exitCode != 0 { + os.Exit(exitCode) + } } func init() { @@ -150,4 +160,13 @@ func init() { log.Fatal().Err(err).Msg("fatal") } + schemaFlagName := "schema" + Cmd.Flags().Bool( + schemaFlagName, false, "Make a schema diff between previous dump and the current state", + ) + flag = Cmd.Flags().Lookup(schemaFlagName) + if err := viper.BindPFlag("validate.schema", flag); err != nil { + log.Fatal().Err(err).Msg("fatal") + } + } diff --git a/internal/db/postgres/cmd/dump.go b/internal/db/postgres/cmd/dump.go index 1e0ad292..855516dc 100644 --- a/internal/db/postgres/cmd/dump.go +++ b/internal/db/postgres/cmd/dump.go @@ -366,6 +366,7 @@ func (d *Dump) mergeAndWriteToc(ctx context.Context, tx pgx.Tx) error { func (d *Dump) writeMetaData(ctx context.Context, startedAt, completedAt time.Time) error { metadata, err := storageDto.NewMetadata( d.resultToc, d.tocFileSize, startedAt, completedAt, d.config.Dump.Transformation, d.dumpedObjectSizes, + d.context.DatabaseSchema, ) if err != nil { return fmt.Errorf("unable build metadata: %w", err) diff --git a/internal/db/postgres/cmd/validate.go b/internal/db/postgres/cmd/validate.go index 55eab12b..dc60d5ca 100644 --- a/internal/db/postgres/cmd/validate.go +++ b/internal/db/postgres/cmd/validate.go @@ -4,11 +4,11 @@ import ( "bufio" "compress/gzip" "context" + "encoding/json" "errors" "fmt" "io" "os" - "path" "slices" "strconv" "strings" @@ -20,12 +20,12 @@ import ( runtimeContext "github.com/greenmaskio/greenmask/internal/db/postgres/context" "github.com/greenmaskio/greenmask/internal/db/postgres/dump_objects" "github.com/greenmaskio/greenmask/internal/db/postgres/pgcopy" + storageDto "github.com/greenmaskio/greenmask/internal/db/postgres/storage" 
"github.com/greenmaskio/greenmask/internal/db/postgres/toc" "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/custom" "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" "github.com/greenmaskio/greenmask/internal/domains" "github.com/greenmaskio/greenmask/internal/storages" - "github.com/greenmaskio/greenmask/internal/storages/directory" "github.com/greenmaskio/greenmask/internal/utils/reader" "github.com/greenmaskio/greenmask/pkg/toolkit" ) @@ -40,51 +40,59 @@ const ( HorizontalTableFormat = "horizontal" ) +const ( + nonZeroExitCode = 1 + zeroExitCode = 0 +) + type closeFunc func() type Validate struct { *Dump - tmpDir string + tmpDir string + mainSt storages.Storager + exitCode int } -func NewValidate(cfg *domains.Config, registry *utils.TransformerRegistry) (*Validate, error) { - var st storages.Storager - st, err := directory.NewStorage(&directory.Config{Path: cfg.Common.TempDirectory}) - if err != nil { - return nil, fmt.Errorf("error initializing storage") - } - tmpDir := strconv.FormatInt(time.Now().UnixMilli(), 10) - st = st.SubStorage(tmpDir, true) +func NewValidate(cfg *domains.Config, registry *utils.TransformerRegistry, st storages.Storager) (*Validate, error) { + mainSt := st + tmpDirName := strconv.FormatInt(time.Now().UnixMilli(), 10) + st = st.SubStorage(tmpDirName, true) d := NewDump(cfg, st, registry) d.dumpIdSequence = toc.NewDumpSequence(0) + d.validate = true return &Validate{ - Dump: d, - tmpDir: path.Join(cfg.Common.TempDirectory, tmpDir), + Dump: d, + tmpDir: tmpDirName, + exitCode: zeroExitCode, + mainSt: mainSt, }, nil } -func (v *Validate) Run(ctx context.Context) error { +func (v *Validate) Run(ctx context.Context) (int, error) { defer func() { - // Deleting temp dir after closing it - if err := os.RemoveAll(v.tmpDir); err != nil { - log.Warn().Err(err).Msgf("unable to delete temp directory") + if !v.config.Validate.Diff { + return + } + if err := v.mainSt.Delete(ctx, v.tmpDir); err != nil { + 
log.Warn().Err(err).Msg("error deleting temporary directory") } }() if err := custom.BootstrapCustomTransformers(ctx, v.registry, v.config.CustomTransformers); err != nil { - return fmt.Errorf("error bootstraping custom transformers: %w", err) + return nonZeroExitCode, fmt.Errorf("error bootstraping custom transformers: %w", err) } dsn, err := v.pgDumpOptions.GetPgDSN() if err != nil { - return fmt.Errorf("cannot build connection string: %w", err) + return nonZeroExitCode, fmt.Errorf("cannot build connection string: %w", err) } conn, err := v.connect(ctx, dsn) if err != nil { - return err + return nonZeroExitCode, err } defer func() { if err := conn.Close(ctx); err != nil { @@ -94,7 +102,7 @@ func (v *Validate) Run(ctx context.Context) error { tx, err := v.startMainTx(ctx, conn) if err != nil { - return fmt.Errorf("cannot prepare backup transaction: %w", err) + return nonZeroExitCode, fmt.Errorf("cannot prepare backup transaction: %w", err) } defer func() { if err := tx.Rollback(ctx); err != nil { @@ -103,39 +111,43 @@ func (v *Validate) Run(ctx context.Context) error { }() if err = v.gatherPgFacts(ctx, tx); err != nil { - return fmt.Errorf("error gathering facts: %w", err) + return nonZeroExitCode, fmt.Errorf("error gathering facts: %w", err) } // Get list of tables to validate tablesToValidate, err := v.getTablesToValidate() if err != nil { - return err + return nonZeroExitCode, err } v.config.Dump.Transformation = tablesToValidate v.context, err = runtimeContext.NewRuntimeContext(ctx, tx, v.config.Dump.Transformation, v.registry, v.pgDumpOptions, v.version) if err != nil { - return fmt.Errorf("unable to build runtime context: %w", err) + return nonZeroExitCode, fmt.Errorf("unable to build runtime context: %w", err) } if err = v.printValidationWarnings(); err != nil { - return err + return nonZeroExitCode, err + } + + if err = v.diffWithPreviousSchema(ctx); err != nil { + return nonZeroExitCode, err } if !v.config.Validate.Data { - return nil + return v.exitCode, 
nil } if err = v.dumpTables(ctx); err != nil { - return err + return nonZeroExitCode, err } if err = v.print(ctx); err != nil { - return err + return nonZeroExitCode, err } - return nil + return v.exitCode, nil } func (v *Validate) print(ctx context.Context) error { @@ -328,6 +340,111 @@ func (v *Validate) getTablesToValidate() ([]*domains.Table, error) { return tablesToValidate, nil } +func (v *Validate) diffWithPreviousSchema(ctx context.Context) error { + if !v.config.Validate.Schema { + return nil + } + + dumpId, err := v.getPreviousDumpId(ctx) + if err != nil { + return fmt.Errorf("cannot get previous dump id: %w", err) + } + if dumpId == "" { + return nil + } + + md, err := v.getPreviousMetadata(ctx, dumpId) + if err != nil { + return fmt.Errorf("cannot get previous metadata: %w", err) + } + + diff := md.DatabaseSchema.Diff(v.context.DatabaseSchema) + if len(diff) > 0 { + v.exitCode = nonZeroExitCode + + err = v.printSchemaDiff(diff, dumpId) + if err != nil { + return fmt.Errorf("cannot print schema diff: %w", err) + } + } + + return nil +} + +func (v *Validate) printSchemaDiff(diff []*toolkit.DiffNode, previousDumpId string) error { + + if v.config.Validate.Format == JsonFormat { + data, err := json.Marshal(diff) + if err != nil { + return fmt.Errorf("cannot encode diff node: %w", err) + } + log.Warn(). + Str("PreviousDumpId", previousDumpId). + RawJSON("Diff", data). + Str("Hint", "Check schema changes before making new dump"). + Msg("Database schema has been changed") + return nil + } + log.Warn(). + Str("PreviousDumpId", previousDumpId). + Str("Hint", "Check schema changes before making new dump"). 
+ Msg("Database schema has been changed") + for _, node := range diff { + log.Warn().Str("Event", node.Event).Any("Signature", node.Signature).Msg("") + } + + return nil +} + +func (v *Validate) getPreviousDumpId(ctx context.Context) (string, error) { + var backupNames []string + + _, dirs, err := v.mainSt.ListDir(ctx) + if err != nil { + return "", fmt.Errorf("cannot walk through directory: %w", err) + } + for _, dir := range dirs { + exists, err := dir.Exists(ctx, "metadata.json") + if err != nil { + return "", fmt.Errorf("cannot check file existence: %w", err) + } + if exists { + backupNames = append(backupNames, dir.Dirname()) + } + } + + slices.SortFunc( + backupNames, func(a, b string) int { + if a > b { + return -1 + } + return 1 + }, + ) + if len(backupNames) > 0 { + return backupNames[0], nil + } + return "", nil +} + +func (v *Validate) getPreviousMetadata(ctx context.Context, dumpId string) (*storageDto.Metadata, error) { + + st := v.mainSt.SubStorage(dumpId, true) + + f, err := st.GetObject(ctx, MetadataJsonFileName) + if err != nil { + return nil, fmt.Errorf("cannot open metadata file: %w", err) + } + defer f.Close() + + previousMetadata := &storageDto.Metadata{} + + if err = json.NewDecoder(f).Decode(&previousMetadata); err != nil { + return nil, fmt.Errorf("cannot decode metadata file: %w", err) + } + return previousMetadata, nil +} + func findTableBySchemaAndName(Transformations []*domains.Table, schemaName, tableName string) (*domains.Table, error) { var foundTable *domains.Table for _, t := range Transformations { diff --git a/internal/db/postgres/context/context.go b/internal/db/postgres/context/context.go index 7f72bdc0..64fb3957 100644 --- a/internal/db/postgres/context/context.go +++ b/internal/db/postgres/context/context.go @@ -42,6 +42,8 @@ type RuntimeContext struct { Registry *transformersUtils.TransformerRegistry // TypeMap - map of registered types including custom types. 
It's common for the whole runtime TypeMap *pgtype.Map + // DatabaseSchema - list of tables with columns - required for schema diff checking + DatabaseSchema toolkit.DatabaseSchema } // NewRuntimeContext - creating new runtime context. @@ -71,12 +73,18 @@ func NewRuntimeContext( return nil, fmt.Errorf("cannot build dump object list: %w", err) } + schema, err := getDatabaseSchema(ctx, tx, opt) + if err != nil { + return nil, fmt.Errorf("cannot get database schema: %w", err) + } + return &RuntimeContext{ Tables: tables, Types: types, DataSectionObjects: dataSectionObjects, Warnings: warnings, Registry: r, + DatabaseSchema: schema, }, nil } diff --git a/internal/db/postgres/context/pg_catalog.go b/internal/db/postgres/context/pg_catalog.go index 67af4a8d..5154d1a7 100644 --- a/internal/db/postgres/context/pg_catalog.go +++ b/internal/db/postgres/context/pg_catalog.go @@ -394,3 +394,81 @@ func BuildTableSearchQuery( return fmt.Sprintf(totalQuery, tableDataExclusionCond, tableInclusionCond, tableExclusionCond, schemaInclusionCond, schemaExclusionCond, foreignDataInclusionCond), nil } + +func BuildSchemaIntrospectionQuery(includeTable, excludeTable, includeForeignData, + includeSchema, excludeSchema []string, +) (string, error) { + + tableInclusionCond, err := renderRelationCond(includeTable, trueCond) + if err != nil { + return "", err + } + tableExclusionCond, err := renderRelationCond(excludeTable, falseCond) + if err != nil { + return "", err + } + schemaInclusionCond, err := renderNamespaceCond(includeSchema, trueCond) + if err != nil { + return "", err + } + schemaExclusionCond, err := renderNamespaceCond(excludeSchema, falseCond) + if err != nil { + return "", err + } + + foreignDataInclusionCond, err := renderForeignDataCond(includeForeignData, falseCond) + if err != nil { + return "", err + } + + totalQuery := ` + SELECT c.oid::TEXT::INT, + n.nspname as "Schema", + c.relname as "Name", + c.relkind::TEXT as "RelKind", + (coalesce(pc.oid::INT, 0)) as "RootPtOid", + 
(WITH RECURSIVE part_tables AS (SELECT pg_inherits.inhrelid AS parent_oid, + nmsp_child.nspname AS child_schema, + child.oid AS child_oid, + child.relname AS child, + child.relkind as kind + FROM pg_inherits + JOIN pg_class child ON pg_inherits.inhrelid = child.oid + JOIN pg_namespace nmsp_child ON nmsp_child.oid = child.relnamespace + WHERE pg_inherits.inhparent = c.oid + UNION + SELECT pt.parent_oid, + nmsp_child.nspname AS child_schema, + child.oid AS child_oid, + child.relname AS child, + child.relkind as kind + FROM part_tables pt + JOIN pg_inherits inh ON pt.child_oid = inh.inhparent + JOIN pg_class child ON inh.inhrelid = child.oid + JOIN pg_namespace nmsp_child ON nmsp_child.oid = child.relnamespace + WHERE pt.kind = 'p') + SELECT array_agg(child_oid::INT) AS oid + FROM part_tables + WHERE kind != 'p') as "ChildrenPtOids" + FROM pg_catalog.pg_class c + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + LEFT JOIN pg_catalog.pg_inherits i ON i.inhrelid = c.oid + LEFT JOIN pg_catalog.pg_class pc ON i.inhparent = pc.oid AND pc.relkind = 'p' + LEFT JOIN pg_catalog.pg_namespace pn ON pc.relnamespace = pn.oid + LEFT JOIN pg_catalog.pg_foreign_table ft ON c.oid = ft.ftrelid + LEFT JOIN pg_catalog.pg_foreign_server s ON s.oid = ft.ftserver + WHERE c.relkind IN ('r', 'f', 'p') + AND %s -- relname inclusion + AND NOT %s -- relname exclusion + AND %s -- schema inclusion + AND NOT %s -- schema exclusion + AND (s.srvname ISNULL OR %s) -- include foreign data + AND n.nspname <> 'pg_catalog' + AND n.nspname !~ '^pg_toast' + AND n.nspname <> 'information_schema' + ` + + return fmt.Sprintf(totalQuery, tableInclusionCond, tableExclusionCond, + schemaInclusionCond, schemaExclusionCond, foreignDataInclusionCond), nil + +} diff --git a/internal/db/postgres/context/pg_catalog_test.go b/internal/db/postgres/context/pg_catalog_test.go index f2479b57..801afe10 100644 --- a/internal/db/postgres/context/pg_catalog_test.go +++ b/internal/db/postgres/context/pg_catalog_test.go 
@@ -22,15 +22,26 @@ import ( ) func TestBuildTableSearchQuery(t *testing.T) { - var includeTable, excludeTable, excludeTableData, includeForeignData, includeSchema, excludeSchema []string - includeTable = []string{"bookings.*"} - excludeTable = []string{"booki*.boarding_pas*", "b?*.seats"} - includeSchema = []string{"booki*"} - excludeSchema = []string{"public*[[:digit:]]*1"} - excludeTableData = []string{"bookings.flights"} - includeForeignData = []string{"myserver"} + includeTable := []string{"bookings.*"} + excludeTable := []string{"booki*.boarding_pas*", "b?*.seats"} + includeSchema := []string{"booki*"} + excludeSchema := []string{"public*[[:digit:]]*1"} + excludeTableData := []string{"bookings.flights"} + includeForeignData := []string{"myserver"} res, err := BuildTableSearchQuery(includeTable, excludeTable, excludeTableData, includeForeignData, includeSchema, excludeSchema) assert.NoError(t, err) fmt.Println(res) } + +func TestBuildSchemaIntrospectionQuery(t *testing.T) { + includeTable := []string{"bookings.*"} + excludeTable := []string{"booki*.boarding_pas*", "b?*.seats"} + includeSchema := []string{"booki*"} + excludeSchema := []string{"public*[[:digit:]]*1"} + includeForeignData := []string{"myserver"} + res, err := BuildSchemaIntrospectionQuery(includeTable, excludeTable, + includeForeignData, includeSchema, excludeSchema) + assert.NoError(t, err) + fmt.Println(res) +} diff --git a/internal/db/postgres/context/schema.go b/internal/db/postgres/context/schema.go new file mode 100644 index 00000000..893d436d --- /dev/null +++ b/internal/db/postgres/context/schema.go @@ -0,0 +1,52 @@ +package context + +import ( + "context" + + "github.com/jackc/pgx/v5" + + "github.com/greenmaskio/greenmask/internal/db/postgres/pgdump" + "github.com/greenmaskio/greenmask/pkg/toolkit" +) + +func getDatabaseSchema( + ctx context.Context, tx pgx.Tx, options *pgdump.Options, +) ([]*toolkit.Table, error) { + var res []*toolkit.Table + query, err := 
BuildSchemaIntrospectionQuery( + options.Table, options.ExcludeTable, + options.IncludeForeignData, options.Schema, + options.ExcludeSchema, + ) + if err != nil { + return nil, err + } + rows, err := tx.Query(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + table := &toolkit.Table{} + err = rows.Scan( + &table.Oid, &table.Schema, &table.Name, &table.Kind, + &table.Parent, &table.Children, + ) + if err != nil { + return nil, err + } + res = append(res, table) + } + + // fill columns + for _, table := range res { + columns, err := getColumnsConfig(ctx, tx, table.Oid) + if err != nil { + return nil, err + } + table.Columns = columns + } + + return res, nil +} diff --git a/internal/db/postgres/storage/metadata_json.go b/internal/db/postgres/storage/metadata_json.go index 7a7aa7be..42fc2189 100644 --- a/internal/db/postgres/storage/metadata_json.go +++ b/internal/db/postgres/storage/metadata_json.go @@ -20,6 +20,8 @@ import ( "github.com/rs/zerolog/log" + "github.com/greenmaskio/greenmask/pkg/toolkit" + "github.com/greenmaskio/greenmask/internal/db/postgres/toc" "github.com/greenmaskio/greenmask/internal/domains" ) @@ -59,19 +61,20 @@ type Entry struct { } type Metadata struct { - StartedAt time.Time `yaml:"startedAt" json:"startedAt"` - CompletedAt time.Time `yaml:"completedAt" json:"completedAt"` - OriginalSize int64 `yaml:"originalSize" json:"originalSize"` - CompressedSize int64 `yaml:"compressedSize" json:"compressedSize"` - Transformers []*domains.Table `yaml:"transformers" json:"transformers"` - Header Header `yaml:"header" json:"header"` - Entries []*Entry `yaml:"entries" json:"entries"` + StartedAt time.Time `yaml:"startedAt" json:"startedAt"` + CompletedAt time.Time `yaml:"completedAt" json:"completedAt"` + OriginalSize int64 `yaml:"originalSize" json:"originalSize"` + CompressedSize int64 `yaml:"compressedSize" json:"compressedSize"` + Transformers []*domains.Table `yaml:"transformers" json:"transformers"` + 
DatabaseSchema toolkit.DatabaseSchema `yaml:"database_schema" json:"database_schema"` + Header Header `yaml:"header" json:"header"` + Entries []*Entry `yaml:"entries" json:"entries"` } func NewMetadata( tocObj *toc.Toc, tocFileSize int64, startedAt, completedAt time.Time, transformers []*domains.Table, - stats map[int32]ObjectSizeStat, + stats map[int32]ObjectSizeStat, databaseSchema []*toolkit.Table, ) (*Metadata, error) { var format string @@ -159,6 +162,7 @@ func NewMetadata( StartedAt: startedAt, CompletedAt: completedAt, Transformers: transformers, + DatabaseSchema: databaseSchema, Header: Header{ CreationDate: tocObj.Header.CrtmDateTime.Time(), DbName: *tocObj.Header.ArchDbName, diff --git a/internal/domains/config.go b/internal/domains/config.go index c179c457..56436ea0 100644 --- a/internal/domains/config.go +++ b/internal/domains/config.go @@ -57,6 +57,7 @@ type Validate struct { Tables []string `mapstructure:"tables" yaml:"tables" json:"tables,omitempty"` Data bool `mapstructure:"data" yaml:"data" json:"data,omitempty"` Diff bool `mapstructure:"diff" yaml:"diff" json:"diff,omitempty"` + Schema bool `mapstructure:"schema" yaml:"schema" json:"schema,omitempty"` RowsLimit uint64 `mapstructure:"rows_limit" yaml:"rows_limit" json:"rows_limit,omitempty"` ResolvedWarnings []string `mapstructure:"resolved_warnings" yaml:"resolved_warnings" json:"resolved_warnings,omitempty"` TableFormat string `mapstructure:"table_format" yaml:"table_format" json:"table_format,omitempty"` diff --git a/pkg/toolkit/database_schema.go b/pkg/toolkit/database_schema.go new file mode 100644 index 00000000..d7c7a845 --- /dev/null +++ b/pkg/toolkit/database_schema.go @@ -0,0 +1,191 @@ +package toolkit + +import ( + "fmt" + "slices" + + "github.com/rs/zerolog/log" +) + +const ( + TableRemovedDiffEvent = "TableRemoved" + TableMovedToAnotherSchemaDiffEvent = "TableMovedToAnotherSchema" + TableRenamedDiffEvent = "TableRenamed" + TableCreatedDiffEvent = "TableCreated" + ColumnCreatedDiffEvent 
= "ColumnCreated" + ColumnRenamedDiffEvent = "ColumnRenamed" + ColumnTypeChangedDiffEvent = "ColumnTypeChanged" +) + +type DiffNode struct { + Event string `json:"event,omitempty"` + Signature map[string]string `json:"signature,omitempty"` +} + +type DatabaseSchema []*Table + +func (ds DatabaseSchema) Diff(current DatabaseSchema) (res []*DiffNode) { + + for _, currentState := range current { + if currentState.Kind == "r" && currentState.Parent != 0 { + continue + } + + previousState, ok := ds.getTableByOid(currentState.Oid) + if !ok { + // if the table was not found by oid it was likely re-created + // Consider: Should we notify about table re-creation? + previousState, ok = ds.getTableByName(currentState.Schema, currentState.Name) + } + + if currentState.Name == "test" { + log.Debug() + } + + if !ok { + signature := map[string]string{ + "SchemaName": currentState.Schema, + "TableName": currentState.Name, + "TableOid": fmt.Sprintf("%d", currentState.Oid), + } + res = append(res, &DiffNode{Event: TableCreatedDiffEvent, Signature: signature}) + continue + } + + res = append(res, diffTables(previousState, currentState)...) 
+ + } + return res +} + +func (ds DatabaseSchema) getTableByOid(oid Oid) (*Table, bool) { + idx := slices.IndexFunc(ds, func(table *Table) bool { + return table.Oid == oid + }) + if idx == -1 { + return nil, false + } + return ds[idx], true +} + +func (ds DatabaseSchema) getTableByName(schemaName, tableName string) (*Table, bool) { + idx := slices.IndexFunc(ds, func(table *Table) bool { + return table.Schema == schemaName && table.Name == tableName + }) + if idx == -1 { + return nil, false + } + return ds[idx], true +} + +func diffTables(previous, current *Table) (res []*DiffNode) { + if previous.Schema != current.Schema { + node := &DiffNode{ + Event: TableMovedToAnotherSchemaDiffEvent, + + Signature: map[string]string{ + "PreviousSchemaName": previous.Schema, + "CurrentSchemaName": current.Schema, + "TableName": current.Name, + "TableOid": fmt.Sprintf("%d", previous.Oid), + }, + } + res = append(res, node) + } + + if previous.Name != current.Name { + node := &DiffNode{ + Event: TableRenamedDiffEvent, + + Signature: map[string]string{ + "PreviousTableName": previous.Name, + "CurrentTableName": current.Name, + "SchemaName": current.Schema, + "TableOid": fmt.Sprintf("%d", previous.Oid), + }, + } + res = append(res, node) + } + + res = append(res, diffTableColumns(previous, current)...) + + return +} + +func diffTableColumns(previous, current *Table) (res []*DiffNode) { + for _, currentStateColumn := range current.Columns { + + previousStateColumn, ok := findColumnByAttNum(previous, currentStateColumn.Num) + if !ok { + previousStateColumn, ok = findColumnByName(previous, currentStateColumn.Name) + } + + if !ok { + node := &DiffNode{ + Event: ColumnCreatedDiffEvent, + + Signature: map[string]string{ + "TableSchema": previous.Schema, + "TableName": previous.Name, + "ColumnName": currentStateColumn.Name, + // TODO: Replace it with type def such as NUMERIC(10, 2) VARCHAR(128), etc. 
+ "ColumnType": currentStateColumn.TypeName, + }, + } + res = append(res, node) + continue + } + + if currentStateColumn.Name != previousStateColumn.Name { + node := &DiffNode{ + Event: ColumnRenamedDiffEvent, + + Signature: map[string]string{ + "TableSchema": previous.Schema, + "TableName": previous.Name, + "PreviousColumnName": previousStateColumn.Name, + "CurrentColumnName": currentStateColumn.Name, + }, + } + res = append(res, node) + } + + if currentStateColumn.TypeOid != previousStateColumn.TypeOid { + node := &DiffNode{ + Event: ColumnTypeChangedDiffEvent, + + Signature: map[string]string{ + "TableSchema": previous.Schema, + "TableName": previous.Name, + "ColumnName": previousStateColumn.Name, + "PreviousColumnType": previousStateColumn.TypeName, + "PreviousColumnTypeOid": fmt.Sprintf("%d", previousStateColumn.TypeOid), + "CurrentColumnType": currentStateColumn.TypeName, + "CurrentColumnTypeOid": fmt.Sprintf("%d", currentStateColumn.TypeOid), + }, + } + res = append(res, node) + } + } + return +} + +func findColumnByAttNum(t *Table, num AttNum) (*Column, bool) { + idx := slices.IndexFunc(t.Columns, func(column *Column) bool { + return column.Num == num + }) + if idx == -1 { + return nil, false + } + return t.Columns[idx], true +} + +func findColumnByName(t *Table, name string) (*Column, bool) { + idx := slices.IndexFunc(t.Columns, func(column *Column) bool { + return column.Name == name + }) + if idx == -1 { + return nil, false + } + return t.Columns[idx], true +} diff --git a/pkg/toolkit/table.go b/pkg/toolkit/table.go index a1c15910..7536dc5f 100644 --- a/pkg/toolkit/table.go +++ b/pkg/toolkit/table.go @@ -21,6 +21,9 @@ type Table struct { Name string `json:"name"` Oid Oid `json:"oid"` Columns []*Column `json:"columns"` + Kind string `json:"kind"` + Parent Oid `json:"parent"` + Children []Oid `json:"children"` Constraints []Constraint `json:"-"` } From 356475018860b26edc497cab21ccb762cc02398f Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Fri, 15 Mar 
2024 19:02:33 +0200 Subject: [PATCH 2/3] Added diff event messages --- internal/db/postgres/cmd/validate.go | 6 +++++- pkg/toolkit/database_schema.go | 10 +++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/internal/db/postgres/cmd/validate.go b/internal/db/postgres/cmd/validate.go index dc60d5ca..a775ad54 100644 --- a/internal/db/postgres/cmd/validate.go +++ b/internal/db/postgres/cmd/validate.go @@ -389,8 +389,12 @@ func (v *Validate) printSchemaDiff(diff []*toolkit.DiffNode, previousDumpId stri Str("PreviousDumpId", previousDumpId). Str("Hint", "Check schema changes before making new dump"). Msg("Database schema has been changed") + for _, node := range diff { - log.Warn().Str("Event", node.Event).Any("Signature", node.Signature).Msg("") + log.Warn(). + Str("Event", node.Event). + Any("Signature", node.Signature). + Msg(toolkit.DiffEventMsgs[node.Event]) } return nil diff --git a/pkg/toolkit/database_schema.go b/pkg/toolkit/database_schema.go index d7c7a845..ede7e298 100644 --- a/pkg/toolkit/database_schema.go +++ b/pkg/toolkit/database_schema.go @@ -8,7 +8,6 @@ import ( ) const ( - TableRemovedDiffEvent = "TableRemoved" TableMovedToAnotherSchemaDiffEvent = "TableMovedToAnotherSchema" TableRenamedDiffEvent = "TableRenamed" TableCreatedDiffEvent = "TableCreated" @@ -17,6 +16,15 @@ const ( ColumnTypeChangedDiffEvent = "ColumnTypeChanged" ) +var DiffEventMsgs = map[string]string{ + TableMovedToAnotherSchemaDiffEvent: "Table moved to another schema", + TableRenamedDiffEvent: "Table renamed", + TableCreatedDiffEvent: "Table created", + ColumnCreatedDiffEvent: "Column created", + ColumnRenamedDiffEvent: "Column renamed", + ColumnTypeChangedDiffEvent: "Column type changed", +} + type DiffNode struct { Event string `json:"event,omitempty"` Signature map[string]string `json:"signature,omitempty"` From 770d3e134bbcf812d82a9ea41e5feb505154e430 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Fri, 15 Mar 2024 19:22:14 +0200 Subject: [PATCH 3/3] 
Removed debug artifacts --- pkg/toolkit/database_schema.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pkg/toolkit/database_schema.go b/pkg/toolkit/database_schema.go index ede7e298..df58cd94 100644 --- a/pkg/toolkit/database_schema.go +++ b/pkg/toolkit/database_schema.go @@ -3,8 +3,6 @@ package toolkit import ( "fmt" "slices" - - "github.com/rs/zerolog/log" ) const ( @@ -46,10 +44,6 @@ func (ds DatabaseSchema) Diff(current DatabaseSchema) (res []*DiffNode) { previousState, ok = ds.getTableByName(currentState.Schema, currentState.Name) } - if currentState.Name == "test" { - log.Debug() - } - if !ok { signature := map[string]string{ "SchemaName": currentState.Schema,