diff --git a/internal/db/postgres/cmd/dump.go b/internal/db/postgres/cmd/dump.go index bc4ccd0b..243e6656 100644 --- a/internal/db/postgres/cmd/dump.go +++ b/internal/db/postgres/cmd/dump.go @@ -175,8 +175,10 @@ func (d *Dump) startMainTx(ctx context.Context, conn *pgx.Conn) (pgx.Tx, error) } func (d *Dump) buildContextAndValidate(ctx context.Context, tx pgx.Tx) (err error) { - d.context, err = runtimeContext.NewRuntimeContext(ctx, tx, d.config.Dump.Transformation, d.registry, - d.pgDumpOptions, d.version) + d.context, err = runtimeContext.NewRuntimeContext( + ctx, tx, d.config.Dump.Transformation, d.registry, d.pgDumpOptions, + d.config.Dump.VirtualReferences, d.version, + ) if err != nil { return fmt.Errorf("unable to build runtime context: %w", err) } diff --git a/internal/db/postgres/cmd/validate.go b/internal/db/postgres/cmd/validate.go index 0c93e0e8..570ef628 100644 --- a/internal/db/postgres/cmd/validate.go +++ b/internal/db/postgres/cmd/validate.go @@ -122,8 +122,10 @@ func (v *Validate) Run(ctx context.Context) (int, error) { } v.config.Dump.Transformation = tablesToValidate - v.context, err = runtimeContext.NewRuntimeContext(ctx, tx, v.config.Dump.Transformation, v.registry, - v.pgDumpOptions, v.version) + v.context, err = runtimeContext.NewRuntimeContext( + ctx, tx, v.config.Dump.Transformation, v.registry, + v.pgDumpOptions, v.config.Dump.VirtualReferences, v.version, + ) if err != nil { return nonZeroExitCode, fmt.Errorf("unable to build runtime context: %w", err) } diff --git a/internal/db/postgres/context/context.go b/internal/db/postgres/context/context.go index a8234152..e6697609 100644 --- a/internal/db/postgres/context/context.go +++ b/internal/db/postgres/context/context.go @@ -23,6 +23,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/rs/zerolog/log" "github.com/greenmaskio/greenmask/internal/db/postgres/entries" "github.com/greenmaskio/greenmask/internal/db/postgres/pgdump" @@ -61,8 +62,9 @@ type 
RuntimeContext struct { // // warnings are fatal procedure must be terminated immediately due to lack of objects required on the next step func NewRuntimeContext( - ctx context.Context, tx pgx.Tx, cfg []*domains.Table, r *transformersUtils.TransformerRegistry, opt *pgdump.Options, - version int, + ctx context.Context, tx pgx.Tx, cfg []*domains.Table, + r *transformersUtils.TransformerRegistry, opt *pgdump.Options, + vr []*domains.VirtualReference, version int, ) (*RuntimeContext, error) { var salt []byte saltHex := os.Getenv("GREENMASK_GLOBAL_SALT") @@ -98,8 +100,14 @@ func NewRuntimeContext( if err != nil { return nil, fmt.Errorf("cannot get database schema: %w", err) } + vrWarns := validateVirtualReferences(vr, tablesEntries) + warnings = append(warnings, vrWarns...) + if len(vrWarns) > 0 { + // if there are any validation warnings, drop the virtual references so the graph build does not use invalid definitions + vr = nil + } - graph, err := subset.NewGraph(ctx, tx, tablesEntries) + graph, err := subset.NewGraph(ctx, tx, tablesEntries, vr) if err != nil { return nil, fmt.Errorf("error creating graph: %w", err) } @@ -109,7 +117,7 @@ func NewRuntimeContext( if err = subset.SetSubsetQueries(graph); err != nil { return nil, fmt.Errorf("cannot set subset queries: %w", err) } - + debugQueries(tablesEntries) } else { // if there are no subset tables, we can sort them by size and transformation costs // TODO: Implement tables ordering for subsetted tables as well @@ -182,3 +190,16 @@ func hasSubset(tables []*entries.Table) bool { return len(table.SubsetConds) > 0 }) } + +func debugQueries(tables []*entries.Table) { + for _, t := range tables { + if t.Query == "" { + continue + } + log.Debug(). + Str("Schema", t.Schema). + Str("Table", t.Name). 
+ Msg("Debug query") + log.Logger.Println(t.Query) + } +} diff --git a/internal/db/postgres/context/virtual_references.go b/internal/db/postgres/context/virtual_references.go new file mode 100644 index 00000000..c4505ee7 --- /dev/null +++ b/internal/db/postgres/context/virtual_references.go @@ -0,0 +1,180 @@ +package context + +import ( + "slices" + + "github.com/greenmaskio/greenmask/internal/db/postgres/entries" + "github.com/greenmaskio/greenmask/internal/domains" + "github.com/greenmaskio/greenmask/pkg/toolkit" +) + +func getReferencedKeys(r *domains.Reference) (res []string) { + for _, ref := range r.Columns { + if ref.Name != "" { + res = append(res, ref.Name) + } else if ref.Expression != "" { + res = append(res, ref.Expression) + } + } + return +} + +func validateVirtualReferences(vrs []*domains.VirtualReference, tables []*entries.Table) (res toolkit.ValidationWarnings) { + for idx, vr := range vrs { + res = append(res, validateVirtualReference(idx, vr, tables)...) + } + return +} + +func validateVirtualReference(tableIdx int, vr *domains.VirtualReference, tables []*entries.Table) (res toolkit.ValidationWarnings) { + if vr.Schema == "" { + w := toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + SetMsg("schema is required"). + AddMeta("TableIdx", tableIdx) + res = append(res, w) + } + if vr.Name == "" { + w := toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + SetMsg("table name is required"). + AddMeta("TableIdx", tableIdx) + res = append(res, w) + } + if len(vr.References) == 0 { + w := toolkit.NewValidationWarning(). + SetMsg("virtual reference error: references are required: received empty"). + SetSeverity(toolkit.ErrorValidationSeverity). + AddMeta("TableIdx", tableIdx). + AddMeta("TableName", vr.Name). 
+ AddMeta("TableSchema", vr.Schema) + res = append(res, w) + } + + referencedTableIdx := slices.IndexFunc(tables, func(t *entries.Table) bool { + return t.Name == vr.Name && t.Schema == vr.Schema + }) + + if referencedTableIdx == -1 { + w := toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + SetMsg("virtual reference error: table not found"). + AddMeta("TableIdx", tableIdx). + AddMeta("TableName", vr.Name). + AddMeta("TableSchema", vr.Schema) + res = append(res, w) + return + } + + fkT := tables[referencedTableIdx] + + for idx, v := range vr.References { + var vrWarns toolkit.ValidationWarnings + + primaryKeyTableIdx := slices.IndexFunc(tables, func(t *entries.Table) bool { + return t.Name == v.Name && t.Schema == v.Schema + }) + if primaryKeyTableIdx == -1 { + w := toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + SetMsg("virtual reference error: table not found"). + AddMeta("ReferenceIdx", idx). + AddMeta("ReferenceName", v.Name). + AddMeta("ReferenceSchema", v.Schema) + vrWarns = append(vrWarns, w) + continue + } + pkT := tables[primaryKeyTableIdx] + + for _, w := range validateReference(idx, v, fkT, pkT) { + w.AddMeta("TableIdx", tableIdx). + SetSeverity(toolkit.ErrorValidationSeverity). + AddMeta("TableName", vr.Name). + AddMeta("TableSchema", vr.Schema) + vrWarns = append(vrWarns, w) + } + res = append(res, vrWarns...) + } + return res +} + +func validateReference(vrIdx int, v *domains.Reference, fkT, pkT *entries.Table) (res toolkit.ValidationWarnings) { + if v.Schema == "" { + w := toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + SetMsg("virtual reference error: schema is required"). + AddMeta("ReferenceIdx", vrIdx) + res = append(res, w) + } + if v.Name == "" { + w := toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + SetMsg("virtual reference error: table name is required"). 
+ AddMeta("ReferenceIdx", vrIdx) + res = append(res, w) + } + if len(v.Columns) == 0 { + w := toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + SetMsg("columns are required: received empty"). + AddMeta("ReferenceIdx", vrIdx). + AddMeta("ReferenceName", v.Name). + AddMeta("ReferenceSchema", v.Schema) + res = append(res, w) + } + refCols := getReferencedKeys(v) + if len(refCols) != len(pkT.PrimaryKey) { + w := toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + SetMsg("virtual reference error: number of columns in reference does not match primary key"). + AddMeta("ReferenceIdx", vrIdx). + AddMeta("ReferencedTableColumns", refCols). + AddMeta("PrimaryTableColumns", pkT.PrimaryKey). + AddMeta("ReferenceName", v.Name). + AddMeta("ReferenceSchema", v.Schema) + res = append(res, w) + } + + for idx, c := range v.Columns { + var vrWarns toolkit.ValidationWarnings + for _, w := range validateColumn(idx, c, fkT) { + w.AddMeta("ReferenceIdx", vrIdx). + SetSeverity(toolkit.ErrorValidationSeverity). + AddMeta("ReferenceName", v.Name). + AddMeta("ReferenceSchema", v.Schema) + vrWarns = append(vrWarns, w) + } + res = append(res, vrWarns...) + } + + return res +} + +func validateColumn(colIdx int, c *domains.ReferencedColumn, fkT *entries.Table) (res toolkit.ValidationWarnings) { + if c.Name == "" && c.Expression == "" { + w := toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + SetMsg("virtual reference error: name or expression is required"). + AddMeta("ColumnIdx", colIdx) + res = append(res, w) + } + if c.Name != "" && c.Expression != "" { + w := toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + SetMsg("virtual reference error: name and expression are mutually exclusive"). 
+ AddMeta("ColumnIdx", colIdx) + res = append(res, w) + } + if c.Name != "" && !slices.ContainsFunc(fkT.Columns, func(column *toolkit.Column) bool { + return column.Name == c.Name + }) { + w := toolkit.NewValidationWarning(). + SetSeverity(toolkit.ErrorValidationSeverity). + SetMsg("virtual reference error: column not found"). + AddMeta("ColumnIdx", colIdx). + AddMeta("ColumnName", c.Name) + res = append(res, w) + } + + return res +} diff --git a/internal/db/postgres/subset/component.go b/internal/db/postgres/subset/component.go index c6817b7f..98967640 100644 --- a/internal/db/postgres/subset/component.go +++ b/internal/db/postgres/subset/component.go @@ -20,7 +20,6 @@ type Component struct { // Cycles cycles [][]*Edge cyclesIdents map[string]struct{} - keys []string // groupedCycles - cycles grouped by the vertexes groupedCycles map[string][]int // groupedCyclesGraph - contains the mapping of the vertexes in the component to the edges in the original graph @@ -36,11 +35,6 @@ func NewComponent(id int, componentGraph map[int][]*Edge, tables map[int]*entrie cyclesIdents: make(map[string]struct{}), } c.findCycles() - if c.hasCycle() { - c.keys = c.getComponentKeys() - } else { - c.keys = c.getOneTable().PrimaryKey - } c.groupCycles() c.buildCyclesGraph() diff --git a/internal/db/postgres/subset/component_link.go b/internal/db/postgres/subset/component_link.go index f2560027..7c6905dc 100644 --- a/internal/db/postgres/subset/component_link.go +++ b/internal/db/postgres/subset/component_link.go @@ -5,7 +5,7 @@ type ComponentLink struct { component *Component } -func NewComponentLink(idx int, c *Component, keys, overriddenKeys []string) *ComponentLink { +func NewComponentLink(idx int, c *Component) *ComponentLink { return &ComponentLink{ idx: idx, component: c, diff --git a/internal/db/postgres/subset/graph.go b/internal/db/postgres/subset/graph.go index f03e7ade..2f1388e1 100644 --- a/internal/db/postgres/subset/graph.go +++ b/internal/db/postgres/subset/graph.go @@ 
-7,11 +7,11 @@ import ( "slices" "strings" + "github.com/greenmaskio/greenmask/internal/domains" + "github.com/greenmaskio/greenmask/pkg/toolkit" "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" - "github.com/greenmaskio/greenmask/pkg/toolkit" - "github.com/greenmaskio/greenmask/internal/db/postgres/entries" ) @@ -64,7 +64,9 @@ type Graph struct { } // NewGraph creates a new graph based on the provided tables by finding the references in DB between them -func NewGraph(ctx context.Context, tx pgx.Tx, tables []*entries.Table) (*Graph, error) { +func NewGraph( + ctx context.Context, tx pgx.Tx, tables []*entries.Table, vr []*domains.VirtualReference, +) (*Graph, error) { graph := make([][]*Edge, len(tables)) reversedGraph := make([][]int, len(tables)) edges := make([]*Edge, 0) @@ -76,36 +78,72 @@ func NewGraph(ctx context.Context, tx pgx.Tx, tables []*entries.Table) (*Graph, return nil, fmt.Errorf("error getting references: %w", err) } for _, ref := range refs { - foreignTableIdx := slices.IndexFunc(tables, func(t *entries.Table) bool { + referenceTableIdx := slices.IndexFunc(tables, func(t *entries.Table) bool { return t.Name == ref.Name && t.Schema == ref.Schema }) - if foreignTableIdx == -1 { + if referenceTableIdx == -1 { log.Debug(). Str("Schema", ref.Schema). Str("Table", ref.Name). 
- Msg("unable to find foreign table: it might be excluded from the dump") + Msg("unable to find reference table (primary): it might be excluded from the dump") continue } edge := NewEdge( edgeIdSequence, - foreignTableIdx, + referenceTableIdx, ref.IsNullable, - NewTableLink(idx, table, ref.ReferencedKeys), - NewTableLink(foreignTableIdx, tables[foreignTableIdx], tables[foreignTableIdx].PrimaryKey), + NewTableLink(idx, table, NewKeysByColumn(ref.ReferencedKeys)), + NewTableLink(referenceTableIdx, tables[referenceTableIdx], NewKeysByColumn(tables[referenceTableIdx].PrimaryKey)), + ) + graph[idx] = append( + graph[idx], + edge, + ) + + reversedGraph[referenceTableIdx] = append( + reversedGraph[referenceTableIdx], + idx, + ) + edges = append(edges, edge) + + edgeIdSequence++ + } + + for _, ref := range getVirtualReferences(vr, table) { + + referenceTableIdx := slices.IndexFunc(tables, func(t *entries.Table) bool { + return t.Name == ref.Name && t.Schema == ref.Schema + }) + + if referenceTableIdx == -1 { + log.Debug(). + Str("Schema", ref.Schema). + Str("Table", ref.Name). 
+ Msg("unable to find reference table (primary): it might be excluded from the dump") + continue + } + + edge := NewEdge( + edgeIdSequence, + referenceTableIdx, + !ref.NotNull, + NewTableLink(idx, table, NewKeysByReferencedColumn(ref.Columns)), + NewTableLink(referenceTableIdx, tables[referenceTableIdx], NewKeysByColumn(tables[referenceTableIdx].PrimaryKey)), ) graph[idx] = append( graph[idx], edge, ) - reversedGraph[foreignTableIdx] = append( - reversedGraph[foreignTableIdx], + reversedGraph[referenceTableIdx] = append( + reversedGraph[referenceTableIdx], idx, ) edges = append(edges, edge) edgeIdSequence++ + } } g := &Graph{ @@ -277,15 +315,11 @@ func (g *Graph) buildCondensedGraph() { fromLink := NewComponentLink( fromLinkIdx, ssc[fromLinkIdx], - edge.from.keys, - overrideKeys(edge.from.table, edge.from.keys), ) toLinkIdx := originalVertexesToComponents[edge.to.idx] toLink := NewComponentLink( toLinkIdx, ssc[toLinkIdx], - edge.to.keys, - overrideKeys(edge.to.table, edge.to.keys), ) condensedEdge := NewCondensedEdge(condensedEdgeIdxSeq, fromLink, toLink, edge) g.condensedGraph[fromLinkIdx] = append(g.condensedGraph[fromLinkIdx], condensedEdge) @@ -310,8 +344,6 @@ func (g *Graph) generateAndSetQueryForScc(path *Path) { g.generateQueriesSccDfs(cq, path, nil) for _, t := range rootVertex.tables { query := cq.generateQuery(t) - fmt.Printf("********************%s.%s*****************\n", t.Schema, t.Name) - println(query) t.Query = query } } @@ -521,14 +553,14 @@ func (g *Graph) generateQueryForTables(path *Path, scopeEdge *ScopeEdge) string originalEdge := scopeEdge.originalCondensedEdge.originalEdge for _, k := range originalEdge.from.keys { leftTable := originalEdge.from.table - leftTableConds = append(leftTableConds, fmt.Sprintf(`"%s"."%s"."%s"`, leftTable.Schema, leftTable.Name, k)) + leftTableConds = append(leftTableConds, k.GetKeyReference(leftTable)) } query = fmt.Sprintf("((%s) IN (%s))", strings.Join(leftTableConds, ", "), query) if scopeEdge.isNullable { var 
nullableChecks []string for _, k := range originalEdge.from.keys { - nullableCheck := fmt.Sprintf(`"%s"."%s"."%s" IS NULL`, originalEdge.from.table.Schema, originalEdge.from.table.Name, k) + nullableCheck := fmt.Sprintf(`%s IS NULL`, k.GetKeyReference(originalEdge.from.table)) nullableChecks = append(nullableChecks, nullableCheck) } query = fmt.Sprintf( @@ -592,14 +624,6 @@ func getReferences(ctx context.Context, tx pgx.Tx, tableOid toolkit.Oid) ([]*too return refs, nil } -func overrideKeys(table *entries.Table, keys []string) []string { - var res []string - for _, k := range keys { - res = append(res, fmt.Sprintf(`"%s.%s.%s"`, table.Schema, table.Name, k)) - } - return res -} - func isPathForScc(path *Path, graph *Graph) bool { return graph.scc[path.rootVertex].hasCycle() } @@ -634,7 +658,10 @@ func generateQuery( var droppedKeysWithAliases []string for _, k := range droppedEdge.from.keys { t := droppedEdge.from.table - droppedKeysWithAliases = append(droppedKeysWithAliases, fmt.Sprintf(`"%s"."%s"."%s" as "%s__%s__%s"`, t.Schema, t.Name, k, t.Schema, t.Name, k)) + droppedKeysWithAliases = append( + droppedKeysWithAliases, + fmt.Sprintf(`%s as "%s__%s__%s"`, k.GetKeyReference(t), t.Schema, t.Name, k.Name), + ) } selectKeys = append(selectKeys, droppedKeysWithAliases...) 
@@ -684,8 +711,6 @@ func generateQuery( recursiveKeys := slices.Clone(selectKeys) for _, k := range cycle[0].from.table.PrimaryKey { t := cycle[0].from.table - //recursivePathSelectionKeys = append(recursivePathSelectionKeys, fmt.Sprintf(`coalesce("%s"."%s"."%s"::TEXT, 'NULL')`, t.Schema, t.Name, k)) - pathName := fmt.Sprintf( `"%s__%s__%s__path" || ARRAY["%s"."%s"."%s"]`, t.Schema, t.Name, k, @@ -766,7 +791,10 @@ func generateOverlapQuery( var droppedKeysWithAliases []string for _, k := range droppedEdge.from.keys { t := droppedEdge.from.table - droppedKeysWithAliases = append(droppedKeysWithAliases, fmt.Sprintf(`"%s"."%s"."%s" as "%s__%s__%s"`, t.Schema, t.Name, k, t.Schema, t.Name, k)) + droppedKeysWithAliases = append( + droppedKeysWithAliases, + fmt.Sprintf(`%s as "%s__%s__%s"`, k.GetKeyReference(t), t.Schema, t.Name, k.Name), + ) } selectKeys = append(selectKeys, droppedKeysWithAliases...) @@ -919,22 +947,16 @@ func generateIntegrityChecksForNullableEdges(nullabilityMap map[int]bool, edges leftTableKey := e.from.keys rightTableKey := e.to.keys k := fmt.Sprintf( - `("%s"."%s"."%s" IS NULL OR "%s"."%s"."%s" IS NOT NULL)`, - e.from.table.Schema, - e.from.table.Name, - leftTableKey[idx], - e.to.table.Schema, - e.to.table.Name, - rightTableKey[idx], + `(%s IS NULL OR %s IS NOT NULL)`, + leftTableKey[idx].GetKeyReference(e.from.table), + rightTableKey[idx].GetKeyReference(e.to.table), ) if _, ok := overriddenTables[e.to.table.Oid]; ok { k = fmt.Sprintf( - `("%s"."%s"."%s" IS NULL OR "%s"."%s" IS NOT NULL)`, - e.from.table.Schema, - e.from.table.Name, - leftTableKey[idx], + `(%s IS NULL OR "%s"."%s" IS NOT NULL)`, + leftTableKey[idx].GetKeyReference(e.from.table), overriddenTables[e.to.table.Oid], - rightTableKey[idx], + rightTableKey[idx].Name, ) } keys = append(keys, k) @@ -1085,3 +1107,49 @@ func shiftUntilVertexWillBeFirst(v *Edge, c []*Edge) []*Edge { } return res } + +func validateVirtualReference(r *domains.Reference, pkT, fkT *entries.Table) error { + // 
TODO: Create ValidationWarning for it + keys := getReferencedKeys(r) + if len(keys) == 0 { + return fmt.Errorf("no keys found in reference %s.%s", r.Schema, r.Name) + } + if len(keys) != len(pkT.PrimaryKey) { + return fmt.Errorf("number of keys in reference %s.%s does not match primary key of %s.%s", r.Schema, r.Name, pkT.Schema, pkT.Name) + } + for _, col := range r.Columns { + if col.Name == "" && col.Expression == "" { + return fmt.Errorf("empty column name and expression in reference %s.%s", r.Schema, r.Name) + } + if col.Name != "" && col.Expression != "" { + return fmt.Errorf("only name or expression should be set in reference item at the same time %s.%s", r.Schema, r.Name) + } + if col.Name != "" && !slices.ContainsFunc(fkT.Columns, func(column *toolkit.Column) bool { + return column.Name == col.Name + }) { + return fmt.Errorf("column %s not found in table %s.%s", col.Name, fkT.Schema, fkT.Name) + } + } + return nil +} + +func getVirtualReferences(vr []*domains.VirtualReference, t *entries.Table) []*domains.Reference { + idx := slices.IndexFunc(vr, func(r *domains.VirtualReference) bool { + return r.Schema == t.Schema && r.Name == t.Name + }) + if idx == -1 { + return nil + } + return vr[idx].References +} + +func getReferencedKeys(r *domains.Reference) (res []string) { + for _, ref := range r.Columns { + if ref.Name != "" { + res = append(res, ref.Name) + } else if ref.Expression != "" { + res = append(res, ref.Expression) + } + } + return +} diff --git a/internal/db/postgres/subset/query.go b/internal/db/postgres/subset/query.go index 334a951d..13e4e6b7 100644 --- a/internal/db/postgres/subset/query.go +++ b/internal/db/postgres/subset/query.go @@ -17,10 +17,9 @@ func generateJoinClauseForDroppedEdge(edge *Edge, initTableName string) string { var conds []string var leftTableKeys []string - keys := edge.from.keys table := edge.from.table - for _, key := range keys { - leftTableKeys = append(leftTableKeys, fmt.Sprintf(`%s__%s__%s`, table.Schema, table.Name, 
key)) + for _, key := range edge.from.keys { + leftTableKeys = append(leftTableKeys, fmt.Sprintf(`%s__%s__%s`, table.Schema, table.Name, key.Name)) } rightTable := edge.to @@ -32,13 +31,7 @@ func generateJoinClauseForDroppedEdge(edge *Edge, initTableName string) string { leftTableKeys[idx], ) - rightPart := fmt.Sprintf( - `"%s"."%s"."%s"`, - rightTable.table.Schema, - rightTable.table.Name, - edge.to.keys[idx], - ) - + rightPart := edge.to.keys[idx].GetKeyReference(rightTable.table) conds = append(conds, fmt.Sprintf(`%s = %s`, leftPart, rightPart)) } @@ -62,24 +55,14 @@ func generateJoinClauseV2(edge *Edge, joinType string, overriddenTables map[tool leftTable, rightTable := edge.from.table, edge.to.table for idx := 0; idx < len(edge.from.keys); idx++ { - leftPart := fmt.Sprintf( - `"%s"."%s"."%s"`, - leftTable.Table.Schema, - leftTable.Table.Name, - edge.from.keys[idx], - ) + leftPart := edge.from.keys[idx].GetKeyReference(leftTable) + rightPart := edge.to.keys[idx].GetKeyReference(rightTable) - rightPart := fmt.Sprintf( - `"%s"."%s"."%s"`, - rightTable.Table.Schema, - rightTable.Table.Name, - edge.to.keys[idx], - ) if override, ok := overriddenTables[rightTable.Table.Oid]; ok { rightPart = fmt.Sprintf( `"%s"."%s"`, override, - edge.to.keys[idx], + edge.to.keys[idx].Name, ) } @@ -91,7 +74,7 @@ func generateJoinClauseV2(edge *Edge, joinType string, overriddenTables map[tool rightTableName := fmt.Sprintf(`"%s"."%s"`, rightTable.Table.Schema, rightTable.Table.Name) if override, ok := overriddenTables[rightTable.Table.Oid]; ok { - rightTableName = fmt.Sprintf(`"%s"`, override) + rightTableName = override } joinClause := fmt.Sprintf( diff --git a/internal/db/postgres/subset/table_link.go b/internal/db/postgres/subset/table_link.go index b96dc276..108cf9c4 100644 --- a/internal/db/postgres/subset/table_link.go +++ b/internal/db/postgres/subset/table_link.go @@ -1,14 +1,47 @@ package subset -import "github.com/greenmaskio/greenmask/internal/db/postgres/entries" +import ( 
+ "fmt" + + "github.com/greenmaskio/greenmask/internal/db/postgres/entries" + "github.com/greenmaskio/greenmask/internal/domains" +) + +type Key struct { + Name string + Expression string +} + +func (k *Key) GetKeyReference(t *entries.Table) string { + if k.Expression != "" { + return k.Expression + } + return fmt.Sprintf(`"%s"."%s"."%s"`, t.Schema, t.Name, k.Name) +} + +func NewKeysByColumn(cols []string) []*Key { + keys := make([]*Key, 0, len(cols)) + for _, col := range cols { + keys = append(keys, &Key{Name: col}) + } + return keys +} + +func NewKeysByReferencedColumn(cols []*domains.ReferencedColumn) []*Key { + keys := make([]*Key, 0, len(cols)) + for _, col := range cols { + keys = append(keys, &Key{Name: col.Name, Expression: col.Expression}) + } + return keys +} type TableLink struct { idx int table *entries.Table - keys []string + keys []*Key } -func NewTableLink(idx int, t *entries.Table, keys []string) *TableLink { +func NewTableLink(idx int, t *entries.Table, keys []*Key) *TableLink { return &TableLink{ idx: idx, table: t, diff --git a/internal/domains/config.go b/internal/domains/config.go index ada2d446..84d9a5c3 100644 --- a/internal/domains/config.go +++ b/internal/domains/config.go @@ -93,8 +93,9 @@ type LogConfig struct { } type Dump struct { - PgDumpOptions pgdump.Options `mapstructure:"pg_dump_options" yaml:"pg_dump_options" json:"pg_dump_options"` - Transformation []*Table `mapstructure:"transformation" yaml:"transformation" json:"transformation,omitempty"` + PgDumpOptions pgdump.Options `mapstructure:"pg_dump_options" yaml:"pg_dump_options" json:"pg_dump_options"` + Transformation []*Table `mapstructure:"transformation" yaml:"transformation" json:"transformation,omitempty"` + VirtualReferences []*VirtualReference `mapstructure:"virtual_references" yaml:"virtual_references" json:"virtual_references,omitempty"` } type Restore struct { diff --git a/internal/domains/virtual_references.go b/internal/domains/virtual_references.go new file mode 
100644 index 00000000..3a0be71f --- /dev/null +++ b/internal/domains/virtual_references.go @@ -0,0 +1,19 @@ +package domains + +type ReferencedColumn struct { + Name string `mapstructure:"name" json:"name" yaml:"name"` + Expression string `mapstructure:"expression" json:"expression" yaml:"expression"` +} + +type Reference struct { + Schema string `mapstructure:"schema" json:"schema" yaml:"schema"` + Name string `mapstructure:"name" json:"name" yaml:"name"` + NotNull bool `mapstructure:"not_null" json:"not_null" yaml:"not_null"` + Columns []*ReferencedColumn `mapstructure:"columns" json:"columns" yaml:"columns"` +} + +type VirtualReference struct { + Schema string `mapstructure:"schema" json:"schema" yaml:"schema"` + Name string `mapstructure:"name" json:"name" yaml:"name"` + References []*Reference `mapstructure:"references" json:"references" yaml:"references"` +}