diff --git a/cmd/greenmask/cmd/restore/restore.go b/cmd/greenmask/cmd/restore/restore.go index af642a04..c6c489d9 100644 --- a/cmd/greenmask/cmd/restore/restore.go +++ b/cmd/greenmask/cmd/restore/restore.go @@ -20,11 +20,12 @@ import ( "path" "slices" - "github.com/greenmaskio/greenmask/internal/storages" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "github.com/spf13/viper" + "github.com/greenmaskio/greenmask/internal/storages" + cmdInternals "github.com/greenmaskio/greenmask/internal/db/postgres/cmd" pgDomains "github.com/greenmaskio/greenmask/internal/domains" "github.com/greenmaskio/greenmask/internal/storages/builder" @@ -165,7 +166,7 @@ func init() { Cmd.Flags().BoolP("use-set-session-authorization", "", false, "use SET SESSION AUTHORIZATION commands instead of ALTER OWNER commands to set ownership") Cmd.Flags().BoolP("on-conflict-do-nothing", "", false, "add ON CONFLICT DO NOTHING to INSERT commands") Cmd.Flags().BoolP("inserts", "", false, "restore data as INSERT commands, rather than COPY") - Cmd.Flags().BoolP("topological-sort", "", false, "restore tables in topological order, ensuring that dependent tables are not restored until the tables they depend on have been restored") + Cmd.Flags().BoolP("restore-in-order", "", false, "restore tables in topological order, ensuring that dependent tables are not restored until the tables they depend on have been restored") // Connection options: Cmd.Flags().StringP("host", "h", "/var/run/postgres", "database server host or socket directory") @@ -179,7 +180,7 @@ func init() { "no-owner", "function", "schema-only", "superuser", "table", "trigger", "no-privileges", "single-transaction", "disable-triggers", "enable-row-security", "if-exists", "no-comments", "no-data-for-failed-tables", "no-security-labels", "no-subscriptions", "no-table-access-method", "no-tablespaces", "section", - "strict-names", "use-set-session-authorization", "inserts", "on-conflict-do-nothing", "topological-sort", + "strict-names", "use-set-session-authorization", "inserts", "on-conflict-do-nothing", "restore-in-order", "host", "port", "username", } { diff --git a/internal/db/postgres/cmd/dump.go b/internal/db/postgres/cmd/dump.go index 9b3f2b66..a2765f99 100644 --- a/internal/db/postgres/cmd/dump.go +++ b/internal/db/postgres/cmd/dump.go @@ -337,10 +337,14 @@ func (d *Dump) setDumpDependenciesGraph(tables []*entries.Table) { t := tables[idx] // Create dependencies graph with DumpId sequence for easier restoration coordination d.dumpDependenciesGraph[t.DumpId] = []int32{} - for _, dep := range graph[oid] { + for _, depOid := range graph[oid] { + // If dependency table is not in the tables slice, it is likely excluded + if !slices.Contains(sortedOids, depOid) { + continue + } // Find dependency table in the tables slice by OID depIdx := slices.IndexFunc(tables, func(depTable *entries.Table) bool { - return depTable.Oid == dep + return depTable.Oid == depOid }) if depIdx == -1 { panic("table not found") diff --git a/internal/db/postgres/cmd/restore.go b/internal/db/postgres/cmd/restore.go index a3592886..e7cdce45 100644 --- a/internal/db/postgres/cmd/restore.go +++ b/internal/db/postgres/cmd/restore.go @@ -28,14 +28,16 @@ import ( "regexp" "slices" "strconv" + "sync" "time" - "github.com/greenmaskio/greenmask/internal/domains" "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v3" + "github.com/greenmaskio/greenmask/internal/domains" + "github.com/greenmaskio/greenmask/internal/db/postgres/pgrestore" "github.com/greenmaskio/greenmask/internal/db/postgres/restorers" "github.com/greenmaskio/greenmask/internal/db/postgres/storage" @@ -68,6 +70,8 @@ const ( const metadataObjectName = "metadata.json" +const dependenciesCheckInterval = 15 * time.Millisecond + type Restore struct { binPath string dsn string @@ -80,9 +84,11 @@ type Restore struct { tmpDir string cfg *domains.Restore metadata *storage.Metadata + mx *sync.RWMutex preDataClenUpToc string postDataClenUpToc string + restoredDumpIds map[int32]bool } func NewRestore( @@ -90,14 +96,16 @@ func NewRestore( ) *Restore { return &Restore{ - binPath: binPath, - st: st, - pgRestore: pgrestore.NewPgRestore(binPath), - restoreOpt: &cfg.PgRestoreOptions, - scripts: s, - tmpDir: path.Join(tmpDir, fmt.Sprintf("%d", time.Now().UnixNano())), - cfg: cfg, - metadata: &storage.Metadata{}, + binPath: binPath, + st: st, + pgRestore: pgrestore.NewPgRestore(binPath), + restoreOpt: &cfg.PgRestoreOptions, + scripts: s, + tmpDir: path.Join(tmpDir, fmt.Sprintf("%d", time.Now().UnixNano())), + cfg: cfg, + metadata: &storage.Metadata{}, + restoredDumpIds: make(map[int32]bool), + mx: &sync.RWMutex{}, } } @@ -132,6 +140,27 @@ func (r *Restore) Run(ctx context.Context) error { return nil } +func (r *Restore) putDumpId(task restorers.RestoreTask) { + tableTask, ok := task.(*restorers.TableRestorer) + if !ok { + return + } + r.mx.Lock() + r.restoredDumpIds[tableTask.Entry.DumpId] = true + r.mx.Unlock() +} + +func (r *Restore) dependenciesAreRestored(deps []int32) bool { + r.mx.RLock() + defer r.mx.RUnlock() + for _, id := range deps { + if !r.restoredDumpIds[id] { + return false + } + } + return true +} + func (r *Restore) readMetadata(ctx context.Context) error { f, err := r.st.GetObject(ctx, metadataObjectName) if err != nil { @@ -532,7 +561,10 @@ func (r *Restore) sortTocEntriesInTopoOrder() []*toc.Entry { lastTableIdx := slices.IndexFunc(dataEntries, func(entry *toc.Entry) bool { return *entry.Desc == toc.SequenceSetDesc || *entry.Desc == toc.BlobsDesc }) - tableEntries := dataEntries[:lastTableIdx] + tableEntries := dataEntries + if lastTableIdx != -1 { + tableEntries = dataEntries[:lastTableIdx] + } sortedTablesEntries := make([]*toc.Entry, 0, len(tableEntries)) for _, dumpId := range r.metadata.DumpIdsOrder { idx := slices.IndexFunc(tableEntries, func(entry *toc.Entry) bool { @@ -548,16 +580,32 @@ func (r *Restore) sortTocEntriesInTopoOrder() []*toc.Entry { res = append(res, r.tocObj.Entries[:preDataEnd+1]...) res = append(res, sortedTablesEntries...) - res = append(res, dataEntries[lastTableIdx:]...) + if lastTableIdx != -1 { + res = append(res, dataEntries[lastTableIdx:]...) + } res = append(res, r.tocObj.Entries[postDataStart:]...) return res } +func (r *Restore) waitDependenciesAreRestore(ctx context.Context, deps []int32) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if r.dependenciesAreRestored(deps) { + return nil + } + time.Sleep(dependenciesCheckInterval) + } +} + func (r *Restore) taskPusher(ctx context.Context, tasks chan restorers.RestoreTask) func() error { return func() error { defer close(tasks) tocEntries := r.tocObj.Entries - if r.restoreOpt.TopologicalSort { + if r.restoreOpt.RestoreInOrder { tocEntries = r.sortTocEntriesInTopoOrder() } for _, entry := range tocEntries { @@ -573,6 +621,13 @@ func (r *Restore) taskPusher(ctx context.Context, tasks chan restorers.RestoreTa continue } + if r.restoreOpt.RestoreInOrder && r.restoreOpt.Jobs > 1 { + deps := r.metadata.DependenciesGraph[entry.DumpId] + if err := r.waitDependenciesAreRestore(ctx, deps); err != nil { + return fmt.Errorf("cannot wait for dependencies are restored: %w", err) + } + } + var task restorers.RestoreTask switch *entry.Desc { case toc.TableDataDesc: @@ -640,6 +695,7 @@ func (r *Restore) restoreWorker(ctx context.Context, tasks <-chan restorers.Rest if err = task.Execute(ctx, conn); err != nil { return fmt.Errorf("unable to perform restoration task (worker %d restoring %s): %w", id, task.DebugInfo(), err) } + r.putDumpId(task) log.Debug(). Int("workerId", id). Str("objectName", task.DebugInfo()). diff --git a/internal/db/postgres/context/context.go b/internal/db/postgres/context/context.go index abf0a785..c534a813 100644 --- a/internal/db/postgres/context/context.go +++ b/internal/db/postgres/context/context.go @@ -21,14 +21,15 @@ import ( "os" "slices" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/greenmaskio/greenmask/internal/db/postgres/entries" "github.com/greenmaskio/greenmask/internal/db/postgres/pgdump" "github.com/greenmaskio/greenmask/internal/db/postgres/subset" transformersUtils "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" "github.com/greenmaskio/greenmask/internal/domains" "github.com/greenmaskio/greenmask/pkg/toolkit" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" ) const defaultTransformerCostMultiplier = 0.03 diff --git a/internal/db/postgres/context/pg_catalog.go b/internal/db/postgres/context/pg_catalog.go index 094a3742..c7a68c6a 100644 --- a/internal/db/postgres/context/pg_catalog.go +++ b/internal/db/postgres/context/pg_catalog.go @@ -54,7 +54,9 @@ func getDumpObjects( return tables, sequesnces, lo, nil } -func getTables(ctx context.Context, version int, tx pgx.Tx, options *pgdump.Options, config map[toolkit.Oid]*entries.Table) ([]*entries.Table, []*entries.Sequence, error) { +func getTables( + ctx context.Context, version int, tx pgx.Tx, options *pgdump.Options, config map[toolkit.Oid]*entries.Table, +) ([]*entries.Table, []*entries.Sequence, error) { // Building relation search query using regexp adaptation rules and pre-defined query templates // TODO: Refactor it to gotemplate query, err := BuildTableSearchQuery(options.Table, options.ExcludeTable, diff --git a/internal/db/postgres/pgrestore/pgrestore.go b/internal/db/postgres/pgrestore/pgrestore.go index 9e547128..ad01128b 100644 --- a/internal/db/postgres/pgrestore/pgrestore.go +++ b/internal/db/postgres/pgrestore/pgrestore.go @@ -94,7 +94,7 @@ type Options struct { // statements on fly if needed OnConflictDoNothing bool `mapstructure:"on-conflict-do-nothing"` Inserts bool `mapstructure:"inserts"` - TopologicalSort bool `mapstructure:"topological-sort"` + RestoreInOrder bool `mapstructure:"restore-in-order"` // Connection options: Host string `mapstructure:"host"` diff --git a/internal/db/postgres/subset/graph.go b/internal/db/postgres/subset/graph.go index c005db1d..1b41f281 100644 --- a/internal/db/postgres/subset/graph.go +++ b/internal/db/postgres/subset/graph.go @@ -7,10 +7,11 @@ import ( "slices" "strings" - "github.com/greenmaskio/greenmask/pkg/toolkit" "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" + "github.com/greenmaskio/greenmask/pkg/toolkit" + "github.com/greenmaskio/greenmask/internal/db/postgres/entries" ) @@ -49,6 +50,8 @@ type Graph struct { scc []*Component // condensedGraph - the condensed graph representation of the DB tables condensedGraph [][]*CondensedEdge + // reversedCondensedGraph - the reversed condensed graph representation of the DB tables + reversedCondensedGraph [][]*CondensedEdge // componentsToOriginalVertexes - the mapping condensed graph vertexes to the original graph vertexes componentsToOriginalVertexes map[int][]int // paths - the subset paths for the tables. The key is the vertex index in the graph and the value is the path for @@ -238,6 +241,7 @@ func (g *Graph) buildCondensedGraph() { // 3. Build condensed graph g.condensedGraph = make([][]*CondensedEdge, g.sscCount) + g.reversedCondensedGraph = make([][]*CondensedEdge, g.sscCount) var condensedEdgeIdxSeq int for _, edge := range g.edges { if _, ok := condensedEdges[edge.id]; ok { @@ -260,6 +264,8 @@ func (g *Graph) buildCondensedGraph() { ) condensedEdge := NewCondensedEdge(condensedEdgeIdxSeq, fromLink, toLink, edge) g.condensedGraph[fromLinkIdx] = append(g.condensedGraph[fromLinkIdx], condensedEdge) + reversedEdges := NewCondensedEdge(condensedEdgeIdxSeq, toLink, fromLink, edge) + g.reversedCondensedGraph[toLinkIdx] = append(g.reversedCondensedGraph[toLinkIdx], reversedEdges) condensedEdgeIdxSeq++ } } @@ -475,7 +481,7 @@ func (g *Graph) generateQueryForTables(path *Path, scopeEdge *ScopeEdge) string } func (g *Graph) GetSortedTablesAndDependenciesGraph() ([]toolkit.Oid, map[toolkit.Oid][]toolkit.Oid) { - condensedEdges := sortCondensedEdges(g.condensedGraph) + condensedEdges := sortCondensedEdges(g.reversedCondensedGraph) var tables []toolkit.Oid dependenciesGraph := make(map[toolkit.Oid][]toolkit.Oid) for _, condEdgeIdx := range condensedEdges { @@ -487,7 +493,11 @@ func (g *Graph) GetSortedTablesAndDependenciesGraph() ([]toolkit.Oid, map[toolki tables = append(tables, componentTables...) } - for _, edge := range g.condensedGraph { + for idx, edge := range g.reversedCondensedGraph { + for _, srcTable := range g.scc[idx].tables { + dependenciesGraph[srcTable.Oid] = make([]toolkit.Oid, 0) + } + for _, e := range edge { for _, srcTable := range e.to.component.tables { for _, dstTable := range e.from.component.tables { @@ -497,6 +507,8 @@ func (g *Graph) GetSortedTablesAndDependenciesGraph() ([]toolkit.Oid, map[toolki } } + slices.Reverse(tables) + return tables, dependenciesGraph }