Skip to content

Commit

Permalink
feat: restoration in topological order
Browse files Browse the repository at this point in the history
* Added reversed condensed graph
* Renamed parameter name
* Implemented dependencies check coordinator when restore command runs with --restore-in-order
  • Loading branch information
wwoytenko committed Aug 16, 2024
1 parent 6a25987 commit 950360a
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 24 deletions.
7 changes: 4 additions & 3 deletions cmd/greenmask/cmd/restore/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ import (
"path"
"slices"

"github.com/greenmaskio/greenmask/internal/storages"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"

"github.com/greenmaskio/greenmask/internal/storages"

cmdInternals "github.com/greenmaskio/greenmask/internal/db/postgres/cmd"
pgDomains "github.com/greenmaskio/greenmask/internal/domains"
"github.com/greenmaskio/greenmask/internal/storages/builder"
Expand Down Expand Up @@ -165,7 +166,7 @@ func init() {
Cmd.Flags().BoolP("use-set-session-authorization", "", false, "use SET SESSION AUTHORIZATION commands instead of ALTER OWNER commands to set ownership")
Cmd.Flags().BoolP("on-conflict-do-nothing", "", false, "add ON CONFLICT DO NOTHING to INSERT commands")
Cmd.Flags().BoolP("inserts", "", false, "restore data as INSERT commands, rather than COPY")
Cmd.Flags().BoolP("topological-sort", "", false, "restore tables in topological order, ensuring that dependent tables are not restored until the tables they depend on have been restored")
Cmd.Flags().BoolP("restore-in-order", "", false, "restore tables in topological order, ensuring that dependent tables are not restored until the tables they depend on have been restored")

// Connection options:
Cmd.Flags().StringP("host", "h", "/var/run/postgres", "database server host or socket directory")
Expand All @@ -179,7 +180,7 @@ func init() {
"no-owner", "function", "schema-only", "superuser", "table", "trigger", "no-privileges", "single-transaction",
"disable-triggers", "enable-row-security", "if-exists", "no-comments", "no-data-for-failed-tables",
"no-security-labels", "no-subscriptions", "no-table-access-method", "no-tablespaces", "section",
"strict-names", "use-set-session-authorization", "inserts", "on-conflict-do-nothing", "topological-sort",
"strict-names", "use-set-session-authorization", "inserts", "on-conflict-do-nothing", "restore-in-order",

"host", "port", "username",
} {
Expand Down
8 changes: 6 additions & 2 deletions internal/db/postgres/cmd/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,14 @@ func (d *Dump) setDumpDependenciesGraph(tables []*entries.Table) {
t := tables[idx]
// Create dependencies graph with DumpId sequence for easier restoration coordination
d.dumpDependenciesGraph[t.DumpId] = []int32{}
for _, dep := range graph[oid] {
for _, depOid := range graph[oid] {
// If dependency table is not in the tables slice, it is likely excluded
if !slices.Contains(sortedOids, depOid) {
continue
}
// Find dependency table in the tables slice by OID
depIdx := slices.IndexFunc(tables, func(depTable *entries.Table) bool {
return depTable.Oid == dep
return depTable.Oid == depOid
})
if depIdx == -1 {
panic("table not found")
Expand Down
80 changes: 68 additions & 12 deletions internal/db/postgres/cmd/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ import (
"regexp"
"slices"
"strconv"
"sync"
"time"

"github.com/greenmaskio/greenmask/internal/domains"
"github.com/jackc/pgx/v5"
"github.com/rs/zerolog/log"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v3"

"github.com/greenmaskio/greenmask/internal/domains"

"github.com/greenmaskio/greenmask/internal/db/postgres/pgrestore"
"github.com/greenmaskio/greenmask/internal/db/postgres/restorers"
"github.com/greenmaskio/greenmask/internal/db/postgres/storage"
Expand Down Expand Up @@ -68,6 +70,8 @@ const (

const metadataObjectName = "metadata.json"

const dependenciesCheckInterval = 15 * time.Millisecond

type Restore struct {
binPath string
dsn string
Expand All @@ -80,24 +84,28 @@ type Restore struct {
tmpDir string
cfg *domains.Restore
metadata *storage.Metadata
mx *sync.RWMutex

preDataClenUpToc string
postDataClenUpToc string
restoredDumpIds map[int32]bool
}

func NewRestore(
binPath string, st storages.Storager, cfg *domains.Restore, s map[string][]pgrestore.Script, tmpDir string,
) *Restore {

return &Restore{
binPath: binPath,
st: st,
pgRestore: pgrestore.NewPgRestore(binPath),
restoreOpt: &cfg.PgRestoreOptions,
scripts: s,
tmpDir: path.Join(tmpDir, fmt.Sprintf("%d", time.Now().UnixNano())),
cfg: cfg,
metadata: &storage.Metadata{},
binPath: binPath,
st: st,
pgRestore: pgrestore.NewPgRestore(binPath),
restoreOpt: &cfg.PgRestoreOptions,
scripts: s,
tmpDir: path.Join(tmpDir, fmt.Sprintf("%d", time.Now().UnixNano())),
cfg: cfg,
metadata: &storage.Metadata{},
restoredDumpIds: make(map[int32]bool),
mx: &sync.RWMutex{},
}
}

Expand Down Expand Up @@ -132,6 +140,27 @@ func (r *Restore) Run(ctx context.Context) error {
return nil
}

func (r *Restore) putDumpId(task restorers.RestoreTask) {
tableTask, ok := task.(*restorers.TableRestorer)
if !ok {
return
}
r.mx.Lock()
r.restoredDumpIds[tableTask.Entry.DumpId] = true
r.mx.Unlock()
}

func (r *Restore) dependenciesAreRestored(deps []int32) bool {
r.mx.RLock()
defer r.mx.RUnlock()
for _, id := range deps {
if !r.restoredDumpIds[id] {
return false
}
}
return true
}

func (r *Restore) readMetadata(ctx context.Context) error {
f, err := r.st.GetObject(ctx, metadataObjectName)
if err != nil {
Expand Down Expand Up @@ -532,7 +561,10 @@ func (r *Restore) sortTocEntriesInTopoOrder() []*toc.Entry {
lastTableIdx := slices.IndexFunc(dataEntries, func(entry *toc.Entry) bool {
return *entry.Desc == toc.SequenceSetDesc || *entry.Desc == toc.BlobsDesc
})
tableEntries := dataEntries[:lastTableIdx]
tableEntries := dataEntries
if lastTableIdx != -1 {
tableEntries = dataEntries[:lastTableIdx]
}
sortedTablesEntries := make([]*toc.Entry, 0, len(tableEntries))
for _, dumpId := range r.metadata.DumpIdsOrder {
idx := slices.IndexFunc(tableEntries, func(entry *toc.Entry) bool {
Expand All @@ -548,16 +580,32 @@ func (r *Restore) sortTocEntriesInTopoOrder() []*toc.Entry {

res = append(res, r.tocObj.Entries[:preDataEnd+1]...)
res = append(res, sortedTablesEntries...)
res = append(res, dataEntries[lastTableIdx:]...)
if lastTableIdx != -1 {
res = append(res, dataEntries[lastTableIdx:]...)
}
res = append(res, r.tocObj.Entries[postDataStart:]...)
return res
}

func (r *Restore) waitDependenciesAreRestore(ctx context.Context, deps []int32) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if r.dependenciesAreRestored(deps) {
return nil
}
time.Sleep(dependenciesCheckInterval)
}
}

func (r *Restore) taskPusher(ctx context.Context, tasks chan restorers.RestoreTask) func() error {
return func() error {
defer close(tasks)
tocEntries := r.tocObj.Entries
if r.restoreOpt.TopologicalSort {
if r.restoreOpt.RestoreInOrder {
tocEntries = r.sortTocEntriesInTopoOrder()
}
for _, entry := range tocEntries {
Expand All @@ -573,6 +621,13 @@ func (r *Restore) taskPusher(ctx context.Context, tasks chan restorers.RestoreTa
continue
}

if r.restoreOpt.RestoreInOrder && r.restoreOpt.Jobs > 1 {
deps := r.metadata.DependenciesGraph[entry.DumpId]
if err := r.waitDependenciesAreRestore(ctx, deps); err != nil {
return fmt.Errorf("cannot wait for dependencies are restored: %w", err)
}
}

var task restorers.RestoreTask
switch *entry.Desc {
case toc.TableDataDesc:
Expand Down Expand Up @@ -640,6 +695,7 @@ func (r *Restore) restoreWorker(ctx context.Context, tasks <-chan restorers.Rest
if err = task.Execute(ctx, conn); err != nil {
return fmt.Errorf("unable to perform restoration task (worker %d restoring %s): %w", id, task.DebugInfo(), err)
}
r.putDumpId(task)
log.Debug().
Int("workerId", id).
Str("objectName", task.DebugInfo()).
Expand Down
5 changes: 3 additions & 2 deletions internal/db/postgres/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ import (
"os"
"slices"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"

"github.com/greenmaskio/greenmask/internal/db/postgres/entries"
"github.com/greenmaskio/greenmask/internal/db/postgres/pgdump"
"github.com/greenmaskio/greenmask/internal/db/postgres/subset"
transformersUtils "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils"
"github.com/greenmaskio/greenmask/internal/domains"
"github.com/greenmaskio/greenmask/pkg/toolkit"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)

const defaultTransformerCostMultiplier = 0.03
Expand Down
4 changes: 3 additions & 1 deletion internal/db/postgres/context/pg_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ func getDumpObjects(
return tables, sequesnces, lo, nil
}

func getTables(ctx context.Context, version int, tx pgx.Tx, options *pgdump.Options, config map[toolkit.Oid]*entries.Table) ([]*entries.Table, []*entries.Sequence, error) {
func getTables(
ctx context.Context, version int, tx pgx.Tx, options *pgdump.Options, config map[toolkit.Oid]*entries.Table,
) ([]*entries.Table, []*entries.Sequence, error) {
// Building relation search query using regexp adaptation rules and pre-defined query templates
// TODO: Refactor it to gotemplate
query, err := BuildTableSearchQuery(options.Table, options.ExcludeTable,
Expand Down
2 changes: 1 addition & 1 deletion internal/db/postgres/pgrestore/pgrestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type Options struct {
// statements on fly if needed
OnConflictDoNothing bool `mapstructure:"on-conflict-do-nothing"`
Inserts bool `mapstructure:"inserts"`
TopologicalSort bool `mapstructure:"topological-sort"`
RestoreInOrder bool `mapstructure:"restore-in-order"`

// Connection options:
Host string `mapstructure:"host"`
Expand Down
18 changes: 15 additions & 3 deletions internal/db/postgres/subset/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import (
"slices"
"strings"

"github.com/greenmaskio/greenmask/pkg/toolkit"
"github.com/jackc/pgx/v5"
"github.com/rs/zerolog/log"

"github.com/greenmaskio/greenmask/pkg/toolkit"

"github.com/greenmaskio/greenmask/internal/db/postgres/entries"
)

Expand Down Expand Up @@ -49,6 +50,8 @@ type Graph struct {
scc []*Component
// condensedGraph - the condensed graph representation of the DB tables
condensedGraph [][]*CondensedEdge
// reversedCondensedGraph - the reversed condensed graph representation of the DB tables
reversedCondensedGraph [][]*CondensedEdge
// componentsToOriginalVertexes - the mapping condensed graph vertexes to the original graph vertexes
componentsToOriginalVertexes map[int][]int
// paths - the subset paths for the tables. The key is the vertex index in the graph and the value is the path for
Expand Down Expand Up @@ -238,6 +241,7 @@ func (g *Graph) buildCondensedGraph() {

// 3. Build condensed graph
g.condensedGraph = make([][]*CondensedEdge, g.sscCount)
g.reversedCondensedGraph = make([][]*CondensedEdge, g.sscCount)
var condensedEdgeIdxSeq int
for _, edge := range g.edges {
if _, ok := condensedEdges[edge.id]; ok {
Expand All @@ -260,6 +264,8 @@ func (g *Graph) buildCondensedGraph() {
)
condensedEdge := NewCondensedEdge(condensedEdgeIdxSeq, fromLink, toLink, edge)
g.condensedGraph[fromLinkIdx] = append(g.condensedGraph[fromLinkIdx], condensedEdge)
reversedEdges := NewCondensedEdge(condensedEdgeIdxSeq, toLink, fromLink, edge)
g.reversedCondensedGraph[toLinkIdx] = append(g.reversedCondensedGraph[toLinkIdx], reversedEdges)
condensedEdgeIdxSeq++
}
}
Expand Down Expand Up @@ -475,7 +481,7 @@ func (g *Graph) generateQueryForTables(path *Path, scopeEdge *ScopeEdge) string
}

func (g *Graph) GetSortedTablesAndDependenciesGraph() ([]toolkit.Oid, map[toolkit.Oid][]toolkit.Oid) {
condensedEdges := sortCondensedEdges(g.condensedGraph)
condensedEdges := sortCondensedEdges(g.reversedCondensedGraph)
var tables []toolkit.Oid
dependenciesGraph := make(map[toolkit.Oid][]toolkit.Oid)
for _, condEdgeIdx := range condensedEdges {
Expand All @@ -487,7 +493,11 @@ func (g *Graph) GetSortedTablesAndDependenciesGraph() ([]toolkit.Oid, map[toolki
tables = append(tables, componentTables...)
}

for _, edge := range g.condensedGraph {
for idx, edge := range g.reversedCondensedGraph {
for _, srcTable := range g.scc[idx].tables {
dependenciesGraph[srcTable.Oid] = make([]toolkit.Oid, 0)
}

for _, e := range edge {
for _, srcTable := range e.to.component.tables {
for _, dstTable := range e.from.component.tables {
Expand All @@ -497,6 +507,8 @@ func (g *Graph) GetSortedTablesAndDependenciesGraph() ([]toolkit.Oid, map[toolki
}
}

slices.Reverse(tables)

return tables, dependenciesGraph
}

Expand Down

0 comments on commit 950360a

Please sign in to comment.