diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 17db0451..122c36ed 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,7 +54,7 @@ jobs: - name: Run integration tests run: | - docker-compose -f docker-compose-integration.yml -p greenmask up \ + docker compose -f docker-compose-integration.yml -p greenmask up \ --renew-anon-volumes --force-recreate --build --exit-code-from greenmask \ --abort-on-container-exit greenmask diff --git a/internal/db/postgres/context/table.go b/internal/db/postgres/context/table.go index 74229ed0..5096dfb0 100644 --- a/internal/db/postgres/context/table.go +++ b/internal/db/postgres/context/table.go @@ -173,7 +173,7 @@ func getTable(ctx context.Context, tx pgx.Tx, t *domains.Table) ([]*entries.Tabl return nil, nil, fmt.Errorf("cannot apply custom query on partitioned table \"%s\".\"%s\": is not supported", table.Schema, table.Name) } table.Query = t.Query - table.SubsetConds = t.SubsetConds + table.SubsetConds = escapeSubsetConds(t.SubsetConds) if table.RelKind == 'p' { if !t.ApplyForInherited { @@ -406,3 +406,11 @@ func getTableConstraints(ctx context.Context, tx pgx.Tx, tableOid toolkit.Oid, v return constraints, nil } + +func escapeSubsetConds(conds []string) []string { + var res []string + for _, c := range conds { + res = append(res, fmt.Sprintf(`( %s )`, c)) + } + return res +} diff --git a/internal/db/postgres/dumpers/table.go b/internal/db/postgres/dumpers/table.go index bd021c75..97707e45 100644 --- a/internal/db/postgres/dumpers/table.go +++ b/internal/db/postgres/dumpers/table.go @@ -93,7 +93,7 @@ func (td *TableDumper) Execute(ctx context.Context, tx pgx.Tx, st storages.Stora if doneErr != nil { log.Warn().Err(doneErr).Msg("error terminating transformation pipeline") } - return fmt.Errorf("error processing table dump: %w", err) + return fmt.Errorf("error processing table dump %s.%s: %w", td.table.Schema, td.table.Name, err) } log.Debug().Msg("transformation pipeline executed successfully") return pipeline.Done(gtx) diff --git a/internal/db/postgres/dumpers/transformation_pipeline.go b/internal/db/postgres/dumpers/transformation_pipeline.go index 6c6b029e..f083253e 100644 --- a/internal/db/postgres/dumpers/transformation_pipeline.go +++ b/internal/db/postgres/dumpers/transformation_pipeline.go @@ -18,8 +18,6 @@ import ( "context" "fmt" "io" - "os" - "path" "slices" "github.com/rs/zerolog/log" @@ -48,7 +46,6 @@ type TransformationPipeline struct { Transform TransformationFunc isAsync bool record *toolkit.Record - cycleResolutionFiles []io.ReadWriteCloser } func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *entries.Table, w io.Writer) (*TransformationPipeline, error) { @@ -130,17 +127,6 @@ func (tp *TransformationPipeline) Init(ctx context.Context) error { } } - // Initialize cycle resolution store files - tp.cycleResolutionFiles = make([]io.ReadWriteCloser, len(tp.table.CycleResolutionOps)) - for cycleResOpIdx, op := range tp.table.CycleResolutionOps { - file, err := os.Create(path.Join(tmpFilePath, op.FileName)) - if err != nil { - closeAllOpenFiles(tp.cycleResolutionFiles, tp.table.CycleResolutionOps[:cycleResOpIdx], true) - return fmt.Errorf("error creating cycle resolution store file: %w", err) - } - tp.cycleResolutionFiles[cycleResOpIdx] = file - } - return nil } @@ -172,9 +158,6 @@ func (tp *TransformationPipeline) Dump(ctx context.Context, data []byte) (err er return fmt.Errorf("error decoding copy line: %w", err) } tp.record.SetRow(tp.row) - if err =
storeCycleResolutionOps(tp.record, tp.table.CycleResolutionOps, tp.cycleResolutionFiles); err != nil { - return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error storing cycle resolution ops: %w", err)) - } _, err = tp.Transform(ctx, tp.record) if err != nil { @@ -225,8 +208,6 @@ func (tp *TransformationPipeline) Done(ctx context.Context) error { } } - closeAllOpenFiles(tp.cycleResolutionFiles, tp.table.CycleResolutionOps, false) - if lastErr != nil { return fmt.Errorf("error terminating initialized transformer: %w", lastErr) } diff --git a/internal/db/postgres/dumpers/utils.go b/internal/db/postgres/dumpers/utils.go deleted file mode 100644 index 8ad06167..00000000 --- a/internal/db/postgres/dumpers/utils.go +++ /dev/null @@ -1,71 +0,0 @@ -package dumpers - -import ( - "fmt" - "io" - "os" - "path" - - "github.com/greenmaskio/greenmask/internal/db/postgres/entries" - "github.com/greenmaskio/greenmask/internal/db/postgres/pgcopy" - "github.com/greenmaskio/greenmask/pkg/toolkit" - "github.com/rs/zerolog/log" -) - -func storeCycleResolutionOps(r *toolkit.Record, storeOps []*entries.CycleResolutionOp, files []io.ReadWriteCloser) error { - for idx := 0; idx < len(storeOps); idx++ { - storeOp := storeOps[idx] - file := files[idx] - row := pgcopy.NewRow(len(storeOp.Columns)) - var hasNull bool - for storeColIdx, col := range storeOp.Columns { - columnIdx, _, ok := r.Driver.GetColumnByName(col) - if !ok { - return fmt.Errorf("column %s not found in record", col) - } - rawValue, err := r.Row.GetColumn(columnIdx) - if err != nil { - return fmt.Errorf("error getting column value: %w", err) - } - if rawValue.IsNull { - hasNull = true - } - if err = row.SetColumn(storeColIdx, rawValue); err != nil { - return fmt.Errorf("error setting column value: %w", err) - } - } - - if hasNull { - continue - } - - res, err := row.Encode() - if err != nil { - return fmt.Errorf("error encoding row: %w", err) - } - if _, err = file.Write(res); err != nil { - return fmt.Errorf("error writing row: %w", err) - } - if _, err = file.Write([]byte{'\n'}); err != nil { - return fmt.Errorf("error writing row: %w", err) - } - } - return nil -} - -func closeAllOpenFiles(files []io.ReadWriteCloser, cycleResolutionOps []*entries.CycleResolutionOp, remove bool) { - for cleanIdx, cleanOp := range cycleResolutionOps { - f := files[cleanIdx] - if f != nil { - log.Debug().Str("file", cleanOp.FileName).Msg("closing cycle resolution store file") - if err := f.Close(); err != nil { - log.Warn().Err(err).Msg("error closing cycle resolution store file") - } - if remove { - if err := os.Remove(path.Join(tmpFilePath, cleanOp.FileName)); err != nil { - log.Warn().Err(err).Msg("error removing cycle resolution store file") - } - } - } - } -} diff --git a/internal/db/postgres/entries/cycle_resolution_op.go b/internal/db/postgres/entries/cycle_resolution_op.go deleted file mode 100644 index 16e058a3..00000000 --- a/internal/db/postgres/entries/cycle_resolution_op.go +++ /dev/null @@ -1,17 +0,0 @@ -package entries - -import "github.com/greenmaskio/greenmask/internal/db/postgres/pgcopy" - -type CycleResolutionOp struct { - Columns []string - FileName string - Row *pgcopy.Row -} - -func NewCycleResolutionOp(fileName string, columns []string) *CycleResolutionOp { - return &CycleResolutionOp{ - FileName: fileName, - Columns: columns, - Row: pgcopy.NewRow(len(columns)), - } -} diff --git a/internal/db/postgres/entries/table.go b/internal/db/postgres/entries/table.go index c171aa9a..42109663 100644 --- 
a/internal/db/postgres/entries/table.go +++ b/internal/db/postgres/entries/table.go @@ -30,9 +30,7 @@ import ( // TODO: Deduplicate SubsetQueries and SubsetInQueries by path type Table struct { *toolkit.Table - Query string - // CycleResolutionOps - list of columns and file to store that must be dumped for future cycles resolution - CycleResolutionOps []*CycleResolutionOp + Query string Owner string RelKind rune RootPtSchema string diff --git a/internal/db/postgres/subset/component.go b/internal/db/postgres/subset/component.go new file mode 100644 index 00000000..fa6b38e1 --- /dev/null +++ b/internal/db/postgres/subset/component.go @@ -0,0 +1,155 @@ +package subset + +import ( + "fmt" + "slices" + "sort" + "strings" + + "github.com/greenmaskio/greenmask/internal/db/postgres/entries" +) + +type Component struct { + id int + // componentGraph - contains the mapping of the vertexes in the component to the edges in the original graph + // if the component contains one vertex and no edges, then there is only one vertex with no cycles + componentGraph map[int][]*Edge + // tables - the vertexes in the component + tables map[int]*entries.Table + // Cycles + cycles [][]*Edge + cyclesIdents map[string]struct{} + keys []string +} + +func NewComponent(id int, componentGraph map[int][]*Edge, tables map[int]*entries.Table) *Component { + c := &Component{ + id: id, + componentGraph: componentGraph, + tables: tables, + cyclesIdents: make(map[string]struct{}), + } + c.findCycles() + if c.hasCycle() { + c.keys = c.getComponentKeys() + } else { + c.keys = c.getOneTable().PrimaryKey + } + + return c +} + +func (c *Component) getSubsetConds() []string { + var subsetConds []string + for _, table := range c.tables { + if len(table.SubsetConds) > 0 { + subsetConds = append(subsetConds, table.SubsetConds...) + } + } + return subsetConds +} + +func (c *Component) getOneTable() *entries.Table { + if !c.hasCycle() { + for _, table := range c.tables { + return table + } + } + panic("cannot call get one table method for cycled scc") +} + +func (c *Component) hasCycle() bool { + return len(c.cycles) > 0 +} + +// findCycles - finds all cycles in the component +func (c *Component) findCycles() { + visited := make(map[int]bool) + var path []*Edge + recStack := make(map[int]bool) + + // Collect and sort all vertices + var vertices []int + for v := range c.componentGraph { + vertices = append(vertices, v) + } + sort.Ints(vertices) // Ensure deterministic order + + for _, v := range vertices { + if !visited[v] { + c.findAllCyclesDfs(v, visited, recStack, path) + } + } +} + +// findAllCyclesDfs - the basic DFS algorithm adapted to find all cycles in the graph and collect the cycle vertices +func (c *Component) findAllCyclesDfs(v int, visited map[int]bool, recStack map[int]bool, path []*Edge) { + visited[v] = true + recStack[v] = true + + // Sort edges to ensure deterministic order + var edges []*Edge + edges = append(edges, c.componentGraph[v]...) 
+ sort.Slice(edges, func(i, j int) bool { + return edges[i].to.idx < edges[j].to.idx + }) + + for _, to := range edges { + + path = append(path, to) + if !visited[to.idx] { + c.findAllCyclesDfs(to.idx, visited, recStack, path) + } else if recStack[to.idx] { + // Cycle detected + var cycle []*Edge + for idx := len(path) - 1; idx >= 0; idx-- { + cycle = append(cycle, path[idx]) + if path[idx].from.idx == to.to.idx { + break + } + } + cycleId := getCycleIdent(cycle) + if _, ok := c.cyclesIdents[cycleId]; !ok { + res := slices.Clone(cycle) + slices.Reverse(res) + c.cycles = append(c.cycles, res) + c.cyclesIdents[cycleId] = struct{}{} + } + } + path = path[:len(path)-1] + } + + recStack[v] = false +} + +func getCycleIdent(cycle []*Edge) string { + ids := make([]string, 0, len(cycle)) + for _, edge := range cycle { + ids = append(ids, fmt.Sprintf("%d", edge.id)) + } + slices.Sort(ids) + return strings.Join(ids, "_") +} + +func (c *Component) getComponentKeys() []string { + if len(c.cycles) > 1 { + panic("IMPLEMENT ME: multiple cycles in the component") + } + if !c.hasCycle() { + return c.getOneTable().PrimaryKey + } + + var vertexes []int + for _, edge := range c.cycles[0] { + vertexes = append(vertexes, edge.to.idx) + } + + var keys []string + for _, v := range vertexes { + table := c.tables[v] + for _, key := range table.PrimaryKey { + keys = append(keys, fmt.Sprintf(`%s__%s__%s`, table.Schema, table.Name, key)) + } + } + return keys +} diff --git a/internal/db/postgres/subset/component_link.go b/internal/db/postgres/subset/component_link.go new file mode 100644 index 00000000..f2560027 --- /dev/null +++ b/internal/db/postgres/subset/component_link.go @@ -0,0 +1,13 @@ +package subset + +type ComponentLink struct { + idx int + component *Component +} + +func NewComponentLink(idx int, c *Component, keys, overriddenKeys []string) *ComponentLink { + return &ComponentLink{ + idx: idx, + component: c, + } +} diff --git a/internal/db/postgres/subset/component_test.go b/internal/db/postgres/subset/component_test.go new file mode 100644 index 00000000..cdf99662 --- /dev/null +++ b/internal/db/postgres/subset/component_test.go @@ -0,0 +1,195 @@ +package subset + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/greenmaskio/greenmask/internal/db/postgres/entries" +) + +func TestComponent_findCycles(t *testing.T) { + c := &Component{ + cyclesIdents: make(map[string]struct{}), + componentGraph: map[int][]*Edge{ + 1: { + { + id: 1, + idx: 2, + from: &TableLink{ + idx: 1, + }, + to: &TableLink{ + idx: 2, + }, + }, + }, + 2: { + { + id: 2, + idx: 3, + from: &TableLink{ + idx: 2, + }, + to: &TableLink{ + idx: 3, + }, + }, + }, + 3: { + { + id: 3, + idx: 1, + from: &TableLink{ + idx: 3, + }, + to: &TableLink{ + idx: 1, + }, + }, + { + id: 4, + idx: 1, + from: &TableLink{ + idx: 3, + }, + to: &TableLink{ + idx: 1, + }, + }, + { + id: 5, + idx: 4, + from: &TableLink{ + idx: 3, + }, + to: &TableLink{ + idx: 4, + }, + }, + }, + 4: { + { + id: 6, + idx: 3, + from: &TableLink{ + idx: 4, + }, + to: &TableLink{ + idx: 3, + }, + }, + { + id: 7, + idx: 1, + from: &TableLink{ + idx: 4, + }, + to: &TableLink{ + idx: 1, + }, + }, + }, + }, + tables: map[int]*entries.Table{}, + } + + c.findCycles() + require.Len(t, c.cycles, 4) +} + +func TestComponent_findCycles_pt2(t *testing.T) { + c := &Component{ + componentGraph: map[int][]*Edge{ + 1: { + { + id: 1, + idx: 2, + from: &TableLink{ + idx: 1, + }, + to: &TableLink{ + idx: 2, + }, + }, + }, + 2: { + { + id: 2, + idx: 1, + from: &TableLink{ + 
idx: 2, + }, + to: &TableLink{ + idx: 1, + }, + }, + { + id: 3, + idx: 1, + from: &TableLink{ + idx: 2, + }, + to: &TableLink{ + idx: 1, + }, + }, + }, + }, + tables: map[int]*entries.Table{}, + cyclesIdents: make(map[string]struct{}), + } + + c.findCycles() + require.Len(t, c.cycles, 2) +} + +func BenchmarkComponent_findCycles(b *testing.B) { + c := &Component{ + cyclesIdents: make(map[string]struct{}), + componentGraph: map[int][]*Edge{ + 1: { + { + id: 1, + idx: 2, + from: &TableLink{ + idx: 1, + }, + to: &TableLink{ + idx: 2, + }, + }, + }, + 2: { + { + id: 2, + idx: 1, + from: &TableLink{ + idx: 2, + }, + to: &TableLink{ + idx: 1, + }, + }, + { + id: 3, + idx: 1, + from: &TableLink{ + idx: 2, + }, + to: &TableLink{ + idx: 1, + }, + }, + }, + }, + tables: map[int]*entries.Table{}, + } + + // Reset the timer to exclude the setup time from the benchmark + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.findCycles() + } +} diff --git a/internal/db/postgres/subset/condenced_edge.go b/internal/db/postgres/subset/condenced_edge.go new file mode 100644 index 00000000..0d697f62 --- /dev/null +++ b/internal/db/postgres/subset/condenced_edge.go @@ -0,0 +1,17 @@ +package subset + +type CondensedEdge struct { + id int + from *ComponentLink + to *ComponentLink + originalEdge *Edge +} + +func NewCondensedEdge(id int, from, to *ComponentLink, originalEdge *Edge) *CondensedEdge { + return &CondensedEdge{ + id: id, + from: from, + to: to, + originalEdge: originalEdge, + } +} diff --git a/internal/db/postgres/subset/cte.go b/internal/db/postgres/subset/cte.go new file mode 100644 index 00000000..fa13f660 --- /dev/null +++ b/internal/db/postgres/subset/cte.go @@ -0,0 +1,71 @@ +package subset + +import ( + "fmt" + "slices" + "strings" + + "github.com/greenmaskio/greenmask/internal/db/postgres/entries" +) + +type cteQuery struct { + items []*cteItem + c *Component +} + +func newCteQuery(c *Component) *cteQuery { + return &cteQuery{ + c: c, + } +} + +func (c *cteQuery) addItem(name, query string) { + c.items = append(c.items, &cteItem{ + name: name, + query: query, + }) +} + +func (c *cteQuery) generateQuery(targetTable *entries.Table) string { + var queries []string + var excludedCteQueries []string + if len(c.c.cycles) > 1 { + panic("IMPLEMENT ME") + } + for _, edge := range c.c.cycles[0] { + if edge.from.table.Oid == targetTable.Oid { + continue + } + excludedCteQuery := fmt.Sprintf("%s__%s__ids", edge.from.table.Schema, edge.from.table.Name) + excludedCteQueries = append(excludedCteQueries, excludedCteQuery) + } + + for _, item := range c.items { + if slices.Contains(excludedCteQueries, item.name) { + continue + } + queries = append(queries, fmt.Sprintf(" %s AS (%s)", item.name, item.query)) + } + var leftTableKeys, rightTableKeys []string + rightTableName := fmt.Sprintf("%s__%s__ids", targetTable.Schema, targetTable.Name) + for _, key := range targetTable.PrimaryKey { + leftTableKeys = append(leftTableKeys, fmt.Sprintf(`"%s"."%s"."%s"`, targetTable.Schema, targetTable.Name, key)) + rightTableKeys = append(rightTableKeys, fmt.Sprintf(`"%s"."%s"`, rightTableName, key)) + } + + resultingQuery := fmt.Sprintf( + `SELECT * FROM "%s"."%s" WHERE %s IN (SELECT %s FROM "%s")`, + targetTable.Schema, + targetTable.Name, + fmt.Sprintf("(%s)", strings.Join(leftTableKeys, ",")), + strings.Join(rightTableKeys, ","), + rightTableName, + ) + res := fmt.Sprintf("WITH RECURSIVE %s %s", strings.Join(queries, ","), resultingQuery) + return res +} + +type cteItem struct { + name string + query string +} diff --git 
a/internal/db/postgres/subset/dfs.go b/internal/db/postgres/subset/dfs.go deleted file mode 100644 index 6344d995..00000000 --- a/internal/db/postgres/subset/dfs.go +++ /dev/null @@ -1,57 +0,0 @@ -package subset - -import "slices" - -const ( - vertexIsNotVisited = iota - vertexIsVisitedAndPrecessing - vertexIsVisitedAndCompleted -) - -const ( - emptyFromValue = -1 -) - -// getCycle returns the cycle in the graph provided based on the from slice gathered in findCycleDfs function -func getCycle(from []int, lastEdge int) []int { - var cycle []int - for e := from[lastEdge]; e != lastEdge; e = from[e] { - cycle = append(cycle, e) - } - cycle = append(cycle, lastEdge) - slices.Reverse(cycle) - return cycle -} - -// FindAllCycles returns all cycles in the graph provided -// The result contains a slice of cycles, where each cycle is a slice of vertices in the order they appear in the cycle -func FindAllCycles(graph [][]*Edge) [][]int { - var allCycles [][]int - visited := make([]int, len(graph)) - from := make([]int, len(graph)) - for idx := range from { - from[idx] = emptyFromValue - } - for v := range graph { - if visited[v] == vertexIsNotVisited { - findAllCyclesDfs(graph, v, visited, from, &allCycles) - } - } - return allCycles -} - -// findAllCyclesDfs - the basic DFS algorithm adapted to find all cycles in the graph and collect the cycle vertices -func findAllCyclesDfs(graph [][]*Edge, v int, visited []int, from []int, allCycles *[][]int) { - visited[v] = vertexIsVisitedAndPrecessing - for _, to := range graph[v] { - if visited[to.Idx] == vertexIsNotVisited { - from[to.Idx] = v - findAllCyclesDfs(graph, to.Idx, visited, from, allCycles) - } else if visited[to.Idx] == vertexIsVisitedAndPrecessing { - from[to.Idx] = v - cycle := getCycle(from, to.Idx) - *allCycles = append(*allCycles, cycle) - } - } - visited[v] = vertexIsVisitedAndCompleted -} diff --git a/internal/db/postgres/subset/edge.go b/internal/db/postgres/subset/edge.go index 0aea4980..fc7c1626 100644 --- a/internal/db/postgres/subset/edge.go +++ b/internal/db/postgres/subset/edge.go @@ -1,24 +1,19 @@ package subset type Edge struct { - Id int - Idx int - A *TableLink - B *TableLink + id int + idx int + isNullable bool + from *TableLink + to *TableLink } -func NewEdge(id, idx int, a *TableLink, b *TableLink) *Edge { +func NewEdge(id, idx int, isNullable bool, a *TableLink, b *TableLink) *Edge { return &Edge{ - Id: id, - Idx: idx, - A: a, - B: b, + id: id, + idx: idx, + isNullable: isNullable, + from: a, + to: b, } } - -func (e *Edge) GetLeftAndRightTable(idx int) (*TableLink, *TableLink) { - if e.A.Idx == idx { - return e.A, e.B - } - return e.B, e.A -} diff --git a/internal/db/postgres/subset/graph.go b/internal/db/postgres/subset/graph.go index 2b0c6186..a6903cf3 100644 --- a/internal/db/postgres/subset/graph.go +++ b/internal/db/postgres/subset/graph.go @@ -1,9 +1,11 @@ package subset import ( + "cmp" "context" "fmt" "slices" + "strings" "github.com/greenmaskio/greenmask/pkg/toolkit" "github.com/jackc/pgx/v5" @@ -12,11 +14,17 @@ import ( "github.com/greenmaskio/greenmask/internal/db/postgres/entries" ) +const ( + sscVertexIsVisited = 1 + sscVertexIsNotVisited = -1 +) + var ( foreignKeyColumnsQuery = ` - SELECT n.nspname as fk_table_schema, - fk_ref_table.relname as fk_table_name, - array_agg(curr_table_attrs.attname) curr_table_columns + SELECT n.nspname as fk_table_schema, + fk_ref_table.relname as fk_table_name, + array_agg(curr_table_attrs.attname) curr_table_columns, + bool_or(NOT attnotnull) as is_nullable FROM 
pg_catalog.pg_constraint curr_table_con join pg_catalog.pg_class fk_ref_table on curr_table_con.confrelid = fk_ref_table.oid join pg_catalog.pg_namespace n on fk_ref_table.relnamespace = n.oid @@ -35,19 +43,27 @@ type Graph struct { tables []*entries.Table // graph - the oriented graph representation of the DB tables graph [][]*Edge - // cycledVertexes - it shows last vertex before cycled edge - cycledVertexes map[int][]*Edge - // cycles - the cycles in the graph with topological order - cycles [][]int - // Paths - the subset Paths for the tables. The key is the vertex index in the graph and the value is the path for + // reversedGraph - the reversed representation of the graph + reversedGraph [][]int + // scc - the strongly connected components in the graph + scc []*Component + // condensedGraph - the condensed graph representation of the DB tables + condensedGraph [][]*CondensedEdge + // componentsToOriginalVertexes - the mapping of the condensed graph vertexes to the original graph vertexes + componentsToOriginalVertexes map[int][]int + // paths - the subset paths for the tables. The key is the vertex index in the graph and the value is the path for // creating the subset query - Paths map[int]*Path - edges []*Edge + paths map[int]*Path + edges []*Edge + visited []int + order []int + sscCount int } // NewGraph creates a new graph based on the provided tables by finding the references in DB between them func NewGraph(ctx context.Context, tx pgx.Tx, tables []*entries.Table) (*Graph, error) { - orientedGraph := make([][]*Edge, len(tables)) + graph := make([][]*Edge, len(tables)) + reversedGraph := make([][]int, len(tables)) edges := make([]*Edge, 0) var edgeIdSequence int @@ -71,133 +87,391 @@ func NewGraph(ctx context.Context, tx pgx.Tx, tables []*entries.Table) (*Graph, edge := NewEdge( edgeIdSequence, foreignTableIdx, + ref.IsNullable, NewTableLink(idx, table, ref.ReferencedKeys), NewTableLink(foreignTableIdx, tables[foreignTableIdx], tables[foreignTableIdx].PrimaryKey), ) - orientedGraph[idx] = append( - orientedGraph[idx], + graph[idx] = append( + graph[idx], edge, ) + + reversedGraph[foreignTableIdx] = append( + reversedGraph[foreignTableIdx], + idx, + ) edges = append(edges, edge) edgeIdSequence++ } } return &Graph{ - tables: tables, - graph: orientedGraph, - cycledVertexes: make(map[int][]*Edge), - Paths: make(map[int]*Path), - edges: edges, + tables: tables, + graph: graph, + paths: make(map[int]*Path), + edges: edges, + visited: make([]int, len(tables)), + order: make([]int, 0), + reversedGraph: reversedGraph, }, nil } -// findCycles - finds all cycles in the graph -func (g *Graph) findCycles() { - visited := make([]int, len(g.graph)) - from := make([]int, len(g.graph)) - for idx := range from { - from[idx] = emptyFromValue - } - for v := range g.graph { - if visited[v] == vertexIsNotVisited { - g.findAllCyclesDfs(v, visited, from) +// findSubsetVertexes - finds the subset vertexes in the graph +func (g *Graph) findSubsetVertexes() { + for v := range g.condensedGraph { + path := NewPath(v) + var from, fullFrom []*CondensedEdge + if len(g.scc[v].getSubsetConds()) > 0 { + path.AddVertex(v) + } + g.subsetDfs(path, v, &fullFrom, &from, rootScopeId) + + if path.Len() > 0 { + g.paths[v] = path } } - g.debugCycles() } -// debugCycles - debugs the cycles in the graph -func (g *Graph) debugCycles() { - if len(g.cycles) == 0 { - return +func (g *Graph) subsetDfs(path *Path, v int, fullFrom, from *[]*CondensedEdge, scopeId int) { + for _, to := range g.condensedGraph[v] { + *fullFrom =
append(*fullFrom, to) + *from = append(*from, to) + currentScopeId := scopeId + if len(g.scc[to.to.idx].getSubsetConds()) > 0 { + for _, e := range *from { + currentScopeId = path.AddEdge(e, currentScopeId) + } + *from = (*from)[:0] + } + g.subsetDfs(path, to.to.idx, fullFrom, from, currentScopeId) + *fullFrom = (*fullFrom)[:len(*fullFrom)-1] + if len(*from) > 0 { + *from = (*from)[:len(*from)-1] + } } +} - for _, foundCycle := range g.cycles { - var cycle []string - for _, v := range foundCycle { - cycle = append(cycle, fmt.Sprintf("%s.%s", g.tables[v].Schema, g.tables[v].Name)) +// findScc - finds the strongly connected components in the graph +func (g *Graph) findScc() []int { + g.order = g.order[:0] + g.eraseVisited() + for v := range g.graph { + if g.visited[v] == sscVertexIsNotVisited { + g.topologicalSortDfs(v) } - cycle = append(cycle, fmt.Sprintf("%s.%s", g.tables[foundCycle[0]].Schema, g.tables[foundCycle[0]].Name)) - if slices.ContainsFunc(foundCycle, func(i int) bool { - return len(g.tables[i].SubsetConds) > 0 - }) { - log.Warn().Strs("cycle", cycle).Msg("cycle detected") - panic("IMPLEMENT ME: cycle detected: implement cycles resolution") + } + slices.Reverse(g.order) + + g.eraseVisited() + var sscCount int + for _, v := range g.order { + if g.visited[v] == sscVertexIsNotVisited { + g.markComponentDfs(v, sscCount) + sscCount++ } } + g.sscCount = sscCount + return g.visited +} +func (g *Graph) eraseVisited() { + for idx := range g.visited { + g.visited[idx] = sscVertexIsNotVisited + } } -// findAllCyclesDfs - the basic DFS algorithm adapted to find all cycles in the graph and collect the cycle vertices -func (g *Graph) findAllCyclesDfs(v int, visited []int, from []int) { - visited[v] = vertexIsVisitedAndPrecessing +func (g *Graph) topologicalSortDfs(v int) { + g.visited[v] = sscVertexIsVisited for _, to := range g.graph[v] { - if visited[to.Idx] == vertexIsNotVisited { - from[to.Idx] = v - g.findAllCyclesDfs(to.Idx, visited, from) - } else if visited[to.Idx] == vertexIsVisitedAndPrecessing { - from[to.Idx] = v - g.cycles = append(g.cycles, g.getCycle(from, to.Idx)) + if g.visited[to.idx] == sscVertexIsNotVisited { + g.topologicalSortDfs(to.idx) } } - visited[v] = vertexIsVisitedAndCompleted + g.order = append(g.order, v) } -// getCycle returns the cycle in the graph provided based on the "from" slice -func (g *Graph) getCycle(from []int, lastVertex int) []int { - var cycle []int - for v := from[lastVertex]; v != lastVertex; v = from[v] { - cycle = append(cycle, v) +func (g *Graph) markComponentDfs(v, component int) { + g.visited[v] = component + for _, to := range g.reversedGraph[v] { + if g.visited[to] == sscVertexIsNotVisited { + g.markComponentDfs(to, component) + } } - cycle = append(cycle, lastVertex) - slices.Reverse(cycle) - return cycle } -// findSubsetVertexes - finds the subset vertexes in the graph -func (g *Graph) findSubsetVertexes() { - for v := range g.graph { - path := NewPath(v) - visited := make([]int, len(g.graph)) - var from, fullFrom []*Edge - if len(g.tables[v].SubsetConds) > 0 { - path.AddVertex(v) +func (g *Graph) buildCondensedGraph() { + g.findScc() + + originalVertexesToComponents := g.visited + componentsToOriginalVertexes := make(map[int][]int, g.sscCount) + for vertexIdx, componentIdx := range originalVertexesToComponents { + componentsToOriginalVertexes[componentIdx] = append(componentsToOriginalVertexes[componentIdx], vertexIdx) + } + g.componentsToOriginalVertexes = componentsToOriginalVertexes + + // 1. 
Collect all tables for the component + // 2. Find all edges within the component + condensedEdges := make(map[int]struct{}) + var ssc []*Component + for componentIdx := 0; componentIdx < g.sscCount; componentIdx++ { + + tables := make(map[int]*entries.Table) + for _, vertexIdx := range componentsToOriginalVertexes[componentIdx] { + tables[vertexIdx] = g.tables[vertexIdx] } - g.subsetDfs(path, v, &fullFrom, &from, visited, rootScopeId) - if path.Len() > 0 { - g.Paths[v] = path + componentGraph := make(map[int][]*Edge) + for _, vertexIdx := range componentsToOriginalVertexes[componentIdx] { + var edges []*Edge + for _, e := range g.graph[vertexIdx] { + if slices.Contains(componentsToOriginalVertexes[componentIdx], e.to.idx) { + edges = append(edges, e) + condensedEdges[e.id] = struct{}{} + } + } + componentGraph[vertexIdx] = edges } + + ssc = append(ssc, NewComponent(componentIdx, componentGraph, tables)) + } + g.scc = ssc + + // 3. Build condensed graph + g.condensedGraph = make([][]*CondensedEdge, g.sscCount) + var condensedEdgeIdxSeq int + for _, edge := range g.edges { + if _, ok := condensedEdges[edge.id]; ok { + continue + } + + fromLinkIdx := originalVertexesToComponents[edge.from.idx] + fromLink := NewComponentLink( + fromLinkIdx, + ssc[fromLinkIdx], + edge.from.keys, + overrideKeys(edge.from.table, edge.from.keys), + ) + toLinkIdx := originalVertexesToComponents[edge.to.idx] + toLink := NewComponentLink( + toLinkIdx, + ssc[toLinkIdx], + edge.to.keys, + overrideKeys(edge.to.table, edge.to.keys), + ) + condensedEdge := NewCondensedEdge(condensedEdgeIdxSeq, fromLink, toLink, edge) + g.condensedGraph[fromLinkIdx] = append(g.condensedGraph[fromLinkIdx], condensedEdge) + condensedEdgeIdxSeq++ } } -func (g *Graph) subsetDfs(path *Path, v int, fullFrom, from *[]*Edge, visited []int, scopeId int) { - visited[v] = vertexIsVisitedAndPrecessing - for _, to := range g.graph[v] { - *fullFrom = append(*fullFrom, to) - *from = append(*from, to) - currentScopeId := scopeId - if visited[to.Idx] == vertexIsNotVisited { - if len(g.tables[to.Idx].SubsetConds) > 0 { - for _, e := range *from { - currentScopeId = path.AddEdge(e, currentScopeId) - } - *from = (*from)[:0] - } - g.subsetDfs(path, to.Idx, fullFrom, from, visited, currentScopeId) - } else if visited[to.Idx] == vertexIsVisitedAndPrecessing { - // if the vertex is visited and processing, it means that we found a cycle, and we need to mark the edge - // as cycled and collect the cycle. 
This data will be used later for cycle resolution - log.Debug().Msg("cycle detected") - g.cycledVertexes[to.Id] = slices.Clone(*fullFrom) +func (g *Graph) generateAndSetQueryForTable(path *Path) { + // We start DFS from the root scope + rootVertex := g.scc[path.rootVertex] + table := rootVertex.getOneTable() + query := g.generateQueriesDfs(path, nil) + table.Query = query +} + +func (g *Graph) generateAndSetQueryForScc(path *Path) { + // We start DFS from the root scope + rootVertex := g.scc[path.rootVertex] + cq := newCteQuery(rootVertex) + g.generateQueriesSccDfs(cq, path, nil) + for _, t := range rootVertex.tables { + query := cq.generateQuery(t) + t.Query = query + } +} + +func (g *Graph) generateQueriesSccDfs(cq *cteQuery, path *Path, scopeEdge *ScopeEdge) { + scopeId := rootScopeId + if scopeEdge != nil { + scopeId = scopeEdge.scopeId + } + if len(path.scopeEdges[scopeId]) == 0 && scopeEdge != nil { + return + } + + g.generateQueryForScc(cq, scopeId, path, scopeEdge) + for _, nextScopeEdge := range path.scopeGraph[scopeId] { + g.generateQueriesSccDfs(cq, path, nextScopeEdge) + } +} + +func (g *Graph) generateQueryForScc(cq *cteQuery, scopeId int, path *Path, prevScopeEdge *ScopeEdge) { + edges := path.scopeEdges[scopeId] + nextScopeEdges := path.scopeGraph[scopeId] + rootVertex := g.scc[path.rootVertex] + if prevScopeEdge != nil { + // If prevScopeEdge != nil then we have a subquery + edges = edges[1:] + rootVertex = prevScopeEdge.originalCondensedEdge.to.component + } + if len(rootVertex.cycles) > 1 { + panic("IMPLEMENT ME: more than one cycle found in SCC") + } + + cycle := orderCycle(rootVertex.cycles[0], edges, path.scopeGraph[scopeId]) + g.generateRecursiveQueriesForCycle(cq, scopeId, cycle, edges, nextScopeEdges) + g.generateQueriesForVertexesInCycle(cq, scopeId, cycle) +} + +func (g *Graph) generateQueriesForVertexesInCycle(cq *cteQuery, scopeId int, cycle []*Edge) { + for _, t := range getTablesFromCycle(cycle) { + queryName := fmt.Sprintf("%s__%s__ids", t.Schema, t.Name) + query := generateAllTablesValidPkSelection(cycle, scopeId, t) + cq.addItem(queryName, query) + } +} + +func (g *Graph) generateRecursiveQueriesForCycle( + cq *cteQuery, scopeId int, cycle []*Edge, rest []*CondensedEdge, nextScopeEdges []*ScopeEdge, +) { + var ( + cycleId = getCycleIdent(cycle) + overriddenTableNames = make(map[toolkit.Oid]string) + ) + + rest = slices.Clone(rest) + for _, se := range nextScopeEdges { + t := se.originalCondensedEdge.originalEdge.to.table + overriddenTableNames[t.Oid] = fmt.Sprintf("%s__%s__ids", t.Schema, t.Name) + rest = append(rest, se.originalCondensedEdge) + } + + shiftedCycle := slices.Clone(cycle) + for idx := 1; idx <= len(cycle); idx++ { + var ( + mainTable = shiftedCycle[0].from.table + // queryName - name of a query in the recursive CTE, built as + // __s<scope id>__c<cycle id>__<table schema>__<table name> + queryName = fmt.Sprintf("__s%d__c%s__%s__%s", scopeId, cycleId, mainTable.Schema, mainTable.Name) + filteredQueryName = fmt.Sprintf("%s__filtered", queryName) + ) + + query := generateQuery(queryName, shiftedCycle, rest, overriddenTableNames) + cq.addItem(queryName, query) + filteredQuery := generateIntegrityCheckJoinConds(shiftedCycle, mainTable, queryName) + cq.addItem(filteredQueryName, filteredQuery) + shiftedCycle = shiftCycle(shiftedCycle) + } +} + +func (g *Graph) generateQueriesDfs(path *Path, scopeEdge *ScopeEdge) string { + // TODO: + // 1. Add scopeEdges support and LEFT JOIN + // 2.
Consider how to implement LEFT JOIN for WHERE IN clause (maybe use cond ISNULL OR IN) + scopeId := rootScopeId + if scopeEdge != nil { + scopeId = scopeEdge.scopeId + } + if len(path.scopeEdges[scopeId]) == 0 && scopeEdge != nil { + return "" + } + + currentScopeQuery := g.generateQueryForTables(path, scopeEdge) + var subQueries []string + for _, nextScope := range path.scopeGraph[scopeId] { + subQuery := g.generateQueriesDfs(path, nextScope) + if subQuery != "" { + subQueries = append(subQueries, subQuery) } - *fullFrom = (*fullFrom)[:len(*fullFrom)-1] - if len(*from) > 0 { - *from = (*from)[:len(*from)-1] + } + + if len(subQueries) == 0 { + return currentScopeQuery + } + + totalQuery := fmt.Sprintf( + "%s AND %s", currentScopeQuery, + strings.Join(subQueries, " AND "), + ) + return totalQuery +} + +func (g *Graph) generateQueryForTables(path *Path, scopeEdge *ScopeEdge) string { + scopeId := rootScopeId + if scopeEdge != nil { + scopeId = scopeEdge.scopeId + } + var edges []*Edge + for _, se := range path.scopeEdges[scopeId] { + edges = append(edges, se.originalEdge) + } + + // Use root table as a root table from path + rootVertex := g.scc[path.rootVertex] + rootTable := rootVertex.getOneTable() + if scopeEdge != nil { + // If it is not a root scope use the right table from the first edge as a root table + // And left table from the first edge as a left table for the subquery. It will be used for where in clause + rootTable = scopeEdge.originalCondensedEdge.originalEdge.to.table + edges = edges[1:] + } + + whereConds := slices.Clone(rootTable.SubsetConds) + selectClause := fmt.Sprintf(`SELECT "%s"."%s".*`, rootTable.Schema, rootTable.Name) + if scopeEdge != nil { + selectClause = generateSelectByPrimaryKey(rootTable, rootTable.PrimaryKey) + } + fromClause := fmt.Sprintf(`FROM "%s"."%s" `, rootTable.Schema, rootTable.Name) + + var joinClauses []string + + nullabilityMap := make(map[int]bool) + for _, e := range edges { + isNullable := e.isNullable + if !isNullable { + isNullable = nullabilityMap[e.from.idx] + } + nullabilityMap[e.to.idx] = isNullable + joinType := joinTypeInner + if isNullable { + joinType = joinTypeLeft } + joinClause := generateJoinClauseV2(e, joinType, make(map[toolkit.Oid]string)) + joinClauses = append(joinClauses, joinClause) + } + integrityChecks := generateIntegrityChecksForNullableEdges(nullabilityMap, edges, make(map[toolkit.Oid]string)) + whereConds = append(whereConds, integrityChecks...) 
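+ // Assemble the scope query: the SELECT and FROM clauses, the JOIN chain built above and the combined WHERE conditions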
+ + query := fmt.Sprintf( + `%s %s %s %s`, + selectClause, + fromClause, + strings.Join(joinClauses, " "), + generateWhereClause(whereConds), + ) + + if scopeEdge != nil { + var leftTableConds []string + originalEdge := scopeEdge.originalCondensedEdge.originalEdge + for _, k := range originalEdge.from.keys { + leftTable := originalEdge.from.table + leftTableConds = append(leftTableConds, fmt.Sprintf(`"%s"."%s"."%s"`, leftTable.Schema, leftTable.Name, k)) + } + query = fmt.Sprintf("((%s) IN (%s))", strings.Join(leftTableConds, ", "), query) + + if scopeEdge.isNullable { + var nullableChecks []string + for _, k := range originalEdge.from.keys { + nullableCheck := fmt.Sprintf(`"%s"."%s"."%s" IS NULL`, originalEdge.from.table.Schema, originalEdge.from.table.Name, k) + nullableChecks = append(nullableChecks, nullableCheck) + } + query = fmt.Sprintf( + "((%s) OR %s)", + strings.Join(nullableChecks, " AND "), + query, + ) + } + } - visited[v] = vertexIsNotVisited + + return query } func getReferences(ctx context.Context, tx pgx.Tx, tableOid toolkit.Oid) ([]*toolkit.Reference, error) { @@ -209,10 +483,311 @@ func getReferences(ctx context.Context, tx pgx.Tx, tableOid toolkit.Oid) ([]*too defer rows.Close() for rows.Next() { ref := &toolkit.Reference{} - if err = rows.Scan(&ref.Schema, &ref.Name, &ref.ReferencedKeys); err != nil { + if err = rows.Scan(&ref.Schema, &ref.Name, &ref.ReferencedKeys, &ref.IsNullable); err != nil { return nil, fmt.Errorf("error scanning ForeignKeyColumnsQuery: %w", err) } refs = append(refs, ref) } return refs, nil } + +func overrideKeys(table *entries.Table, keys []string) []string { + var res []string + for _, k := range keys { + res = append(res, fmt.Sprintf(`"%s.%s.%s"`, table.Schema, table.Name, k)) + } + return res +} + +func isPathForScc(path *Path, graph *Graph) bool { + return graph.scc[path.rootVertex].hasCycle() +} + +func orderCycle(cycle []*Edge, subsetJoins []*CondensedEdge, scopeEdges []*ScopeEdge) []*Edge { + var ( + vertexes []int + valuableEdgesIdx int + ) + + for _, e := range cycle { + vertexes = append(vertexes, e.from.idx) + } + + for _, sj := range subsetJoins { + if slices.Contains(vertexes, sj.originalEdge.from.idx) { + valuableEdgesIdx = slices.IndexFunc(cycle, func(e *Edge) bool { + return sj.originalEdge.from.idx == e.from.idx + }) + if !sj.originalEdge.isNullable { + break + } + } + } + + for _, se := range scopeEdges { + if slices.Contains(vertexes, se.originalCondensedEdge.from.idx) { + valuableEdgesIdx = slices.IndexFunc(cycle, func(e *Edge) bool { + return se.originalCondensedEdge.originalEdge.from.idx == e.from.idx + }) + if !se.originalCondensedEdge.originalEdge.isNullable { + break + } + } + } + + if valuableEdgesIdx == -1 { + panic("is not found") + } + + resCycle := slices.Clone(cycle[valuableEdgesIdx:]) + resCycle = append(resCycle, cycle[:valuableEdgesIdx]...) 
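+ // resCycle is the same cycle rotated so that traversal starts from the edge selected above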
+ return resCycle +} + +func generateQuery(queryName string, cycle []*Edge, rest []*CondensedEdge, overriddenTables map[toolkit.Oid]string) string { + var ( + selectKeys []string + initialJoins, recursiveJoins []string + initialWhereConds, recursiveWhereConds []string + integrityCheck string + cycleSubsetConds []string + edges = slices.Clone(cycle[:len(cycle)-1]) + droppedEdge = cycle[len(cycle)-1] + ) + for _, ce := range rest { + edges = append(edges, ce.originalEdge) + } + + for _, t := range getTablesFromCycle(cycle) { + var keysWithAliases []string + for _, k := range t.PrimaryKey { + keysWithAliases = append(keysWithAliases, fmt.Sprintf(`"%s"."%s"."%s" as "%s__%s__%s"`, t.Schema, t.Name, k, t.Schema, t.Name, k)) + } + selectKeys = append(selectKeys, keysWithAliases...) + if len(t.SubsetConds) > 0 { + cycleSubsetConds = append(cycleSubsetConds, t.SubsetConds...) + } + } + + var droppedKeysWithAliases []string + for _, k := range droppedEdge.from.keys { + t := droppedEdge.from.table + droppedKeysWithAliases = append(droppedKeysWithAliases, fmt.Sprintf(`"%s"."%s"."%s" as "%s__%s__%s"`, t.Schema, t.Name, k, t.Schema, t.Name, k)) + } + selectKeys = append(selectKeys, droppedKeysWithAliases...) + + var initialPathSelectionKeys []string + for _, k := range edges[0].from.table.PrimaryKey { + t := edges[0].from.table + pathName := fmt.Sprintf( + `ARRAY["%s"."%s"."%s"] AS %s__%s__%s__path`, + t.Schema, t.Name, k, + t.Schema, t.Name, k, + ) + initialPathSelectionKeys = append(initialPathSelectionKeys, pathName) + } + + initialKeys := slices.Clone(selectKeys) + initialKeys = append(initialKeys, initialPathSelectionKeys...) + initFromClause := fmt.Sprintf(`FROM "%s"."%s" `, edges[0].from.table.Schema, edges[0].from.table.Name) + integrityCheck = "TRUE AS valid" + initialKeys = append(initialKeys, integrityCheck) + initialWhereConds = append(initialWhereConds, cycleSubsetConds...) + + initialSelect := fmt.Sprintf("SELECT %s", strings.Join(initialKeys, ", ")) + nullabilityMap := make(map[int]bool) + for _, e := range edges { + isNullable := e.isNullable + if !isNullable { + isNullable = nullabilityMap[e.from.idx] + } + nullabilityMap[e.to.idx] = isNullable + joinType := joinTypeInner + if isNullable { + joinType = joinTypeLeft + } + initialJoins = append(initialJoins, generateJoinClauseV2(e, joinType, overriddenTables)) + } + + integrityChecks := generateIntegrityChecksForNullableEdges(nullabilityMap, edges, overriddenTables) + initialWhereConds = append(initialWhereConds, integrityChecks...) + initialWhereClause := generateWhereClause(initialWhereConds) + initialQuery := fmt.Sprintf(`%s %s %s %s`, + initialSelect, initFromClause, strings.Join(initialJoins, " "), initialWhereClause, + ) + + recursiveIntegrityChecks := slices.Clone(cycleSubsetConds) + recursiveIntegrityChecks = append(recursiveIntegrityChecks, integrityChecks...) 
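+ // Fold the cycle subset and nullability conditions into a single "valid" flag: rows that violate them stop the recursion and are later excluded by the __filtered CTEs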
+ recursiveIntegrityCheck := fmt.Sprintf("(%s) AS valid", strings.Join(recursiveIntegrityChecks, " AND ")) + recursiveKeys := slices.Clone(selectKeys) + for _, k := range edges[0].from.table.PrimaryKey { + t := edges[0].from.table + //recursivePathSelectionKeys = append(recursivePathSelectionKeys, fmt.Sprintf(`coalesce("%s"."%s"."%s"::TEXT, 'NULL')`, t.Schema, t.Name, k)) + + pathName := fmt.Sprintf( + `"%s__%s__%s__path" || ARRAY["%s"."%s"."%s"]`, + t.Schema, t.Name, k, + t.Schema, t.Name, k, + ) + recursiveKeys = append(recursiveKeys, pathName) + } + recursiveKeys = append(recursiveKeys, recursiveIntegrityCheck) + + recursiveSelect := fmt.Sprintf("SELECT %s", strings.Join(recursiveKeys, ", ")) + recursiveFromClause := fmt.Sprintf(`FROM "%s" `, queryName) + recursiveJoins = append(recursiveJoins, generateJoinClauseForDroppedEdge(droppedEdge, queryName)) + nullabilityMap = make(map[int]bool) + for _, e := range edges { + isNullable := e.isNullable + if !isNullable { + isNullable = nullabilityMap[e.from.idx] + } + nullabilityMap[e.to.idx] = isNullable + joinType := joinTypeInner + if isNullable { + joinType = joinTypeLeft + } + recursiveJoins = append(recursiveJoins, generateJoinClauseV2(e, joinType, overriddenTables)) + } + + recursiveValidCond := fmt.Sprintf(`"%s"."%s"`, queryName, "valid") + recursiveWhereConds = append(recursiveWhereConds, recursiveValidCond) + for _, k := range edges[0].from.table.PrimaryKey { + t := edges[0].from.table + + recursivePathCheck := fmt.Sprintf( + `NOT "%s"."%s"."%s" = ANY("%s"."%s__%s__%s__%s")`, + t.Schema, t.Name, k, + queryName, t.Schema, t.Name, k, "path", + ) + + recursiveWhereConds = append(recursiveWhereConds, recursivePathCheck) + } + recursiveWhereClause := generateWhereClause(recursiveWhereConds) + + recursiveQuery := fmt.Sprintf(`%s %s %s %s`, + recursiveSelect, recursiveFromClause, strings.Join(recursiveJoins, " "), recursiveWhereClause, + ) + + query := fmt.Sprintf("( %s ) UNION ( %s )", initialQuery, recursiveQuery) + return query +} + +func getTablesFromCycle(cycle []*Edge) (res []*entries.Table) { + for _, e := range cycle { + res = append(res, e.to.table) + } + slices.SortFunc(res, func(a, b *entries.Table) int { + return cmp.Compare(a.Oid, b.Oid) + }) + return res +} + +func generateIntegrityChecksForNullableEdges(nullabilityMap map[int]bool, edges []*Edge, overriddenTables map[toolkit.Oid]string) (res []string) { + // generate conditional checks for foreign tables that has left joins + + for _, e := range edges { + if isNullable := nullabilityMap[e.to.idx]; !isNullable { + continue + } + var keys []string + for idx := range e.from.keys { + leftTableKey := e.from.keys + rightTableKey := e.to.keys + k := fmt.Sprintf( + `("%s"."%s"."%s" IS NULL OR "%s"."%s"."%s" IS NOT NULL)`, + e.from.table.Schema, + e.from.table.Name, + leftTableKey[idx], + e.to.table.Schema, + e.to.table.Name, + rightTableKey[idx], + ) + if _, ok := overriddenTables[e.to.table.Oid]; ok { + k = fmt.Sprintf( + `("%s"."%s"."%s" IS NULL OR "%s"."%s" IS NOT NULL)`, + e.from.table.Schema, + e.from.table.Name, + leftTableKey[idx], + overriddenTables[e.to.table.Oid], + rightTableKey[idx], + ) + } + keys = append(keys, k) + } + res = append(res, fmt.Sprintf("(%s)", strings.Join(keys, " AND "))) + } + return +} + +func generateIntegrityCheckJoinConds(cycle []*Edge, table *entries.Table, tableName string) string { + + var ( + allPks []string + mainTablePks []string + unnestSelections []string + ) + + for _, t := range getTablesFromCycle(cycle) { + for _, k := range t.PrimaryKey { 
+ key := fmt.Sprintf(`"%s"."%s__%s__%s"`, tableName, t.Schema, t.Name, k) + allPks = append(allPks, key) + if t.Oid == table.Oid { + pathName := fmt.Sprintf(`"%s"."%s__%s__%s__path"`, tableName, t.Schema, t.Name, k) + mainTablePks = append(mainTablePks, key) + unnestSelection := fmt.Sprintf(`unnest(%s) AS "%s"`, pathName, k) + unnestSelections = append(unnestSelections, unnestSelection) + } + } + } + + unnestQuery := fmt.Sprintf( + `SELECT %s FROM "%s" WHERE NOT "%s"."valid"`, + strings.Join(unnestSelections, ", "), + tableName, + tableName, + ) + + filteredQuery := fmt.Sprintf( + `SELECT DISTINCT %s FROM "%s" WHERE (%s) NOT IN (%s)`, + strings.Join(allPks, ", "), + tableName, + strings.Join(mainTablePks, ", "), + unnestQuery, + ) + + return filteredQuery +} + +func generateAllTablesValidPkSelection(cycle []*Edge, scopeId int, forTable *entries.Table) string { + + var unionParts []string + + for _, t := range getTablesFromCycle(cycle) { + var ( + selectionKeys []string + cycleId = getCycleIdent(cycle) + filteredQueryName = fmt.Sprintf("__s%d__c%s__%s__%s__filtered", scopeId, cycleId, t.Schema, t.Name) + ) + + for _, k := range forTable.PrimaryKey { + key := fmt.Sprintf(`"%s"."%s__%s__%s" AS "%s"`, filteredQueryName, forTable.Schema, forTable.Name, k, k) + selectionKeys = append(selectionKeys, key) + } + + query := fmt.Sprintf( + `SELECT DISTINCT %s FROM "%s"`, + strings.Join(selectionKeys, ", "), + filteredQueryName, + ) + unionParts = append(unionParts, query) + } + res := strings.Join(unionParts, " UNION ") + return res +} + +func shiftCycle(cycle []*Edge) (res []*Edge) { + res = append(res, cycle[len(cycle)-1]) + res = append(res, cycle[:len(cycle)-1]...) + return +} diff --git a/internal/db/postgres/subset/graph_test.go b/internal/db/postgres/subset/graph_test.go deleted file mode 100644 index b87e3129..00000000 --- a/internal/db/postgres/subset/graph_test.go +++ /dev/null @@ -1 +0,0 @@ -package subset diff --git a/internal/db/postgres/subset/path.go b/internal/db/postgres/subset/path.go index 85f7b856..58736233 100644 --- a/internal/db/postgres/subset/path.go +++ b/internal/db/postgres/subset/path.go @@ -1,89 +1,123 @@ package subset -import "slices" +import ( + "slices" +) const rootScopeId = 0 type Path struct { - RootVertex int - // Vertexes contains all the vertexes that are in the subset of the RootVertex vertex - Vertexes []int - // ScopeEdges - edges that are in the same scope with proper order - ScopeEdges map[int][]*Edge - // ScopeGraph - graph scope to scope connections - ScopeGraph map[int][]int - Edges []*Edge - CycledEdges map[int][]int + rootVertex int + // vertexes contains all the vertexes that are in the subset of the rootVertex vertex + vertexes []int + // scopeEdges - edges that are in the same scope with proper order + scopeEdges map[int][]*CondensedEdge + scopeEdgesNullable map[int]map[int]bool + // scopeGraph - graph scope to scope connections + scopeGraph map[int][]*ScopeEdge + edges []*CondensedEdge + graph map[int][]*CondensedEdge scopeIdSequence int } func NewPath(rootVertex int) *Path { return &Path{ - RootVertex: rootVertex, - CycledEdges: make(map[int][]int), - ScopeGraph: make(map[int][]int), - ScopeEdges: make(map[int][]*Edge), - scopeIdSequence: rootScopeId, + rootVertex: rootVertex, + scopeGraph: make(map[int][]*ScopeEdge), + scopeEdges: make(map[int][]*CondensedEdge), + scopeEdgesNullable: make(map[int]map[int]bool), + scopeIdSequence: rootScopeId, + graph: make(map[int][]*CondensedEdge), } } func (p *Path) AddVertex(v int) { - p.Vertexes = 
append(p.Vertexes, v) + p.vertexes = append(p.vertexes, v) } // AddEdge adds the edge to the path and returns its scope -func (p *Path) AddEdge(e *Edge, scopeId int) int { - if len(p.Vertexes) == 0 { +func (p *Path) AddEdge(e *CondensedEdge, scopeId int) int { + p.addEdgeToGraph(e) + if len(p.vertexes) == 0 { // if there are no vertexes in the path, add the first (root) vertex - p.AddVertex(e.A.Idx) + p.AddVertex(e.from.idx) } return p.addEdge(e, scopeId) } -func (p *Path) MarkEdgeCycled(id int) { - p.CycledEdges[id] = []int{} +func (p *Path) addEdgeToGraph(e *CondensedEdge) { + p.graph[e.from.idx] = append(p.graph[e.from.idx], e) + if _, ok := p.graph[e.to.idx]; !ok { + p.graph[e.to.idx] = nil + } } func (p *Path) Len() int { - return len(p.Vertexes) + return len(p.vertexes) } -func (p *Path) addEdge(e *Edge, scopeId int) int { +func (p *Path) addEdge(e *CondensedEdge, scopeId int) int { if scopeId > p.scopeIdSequence { panic("scopeId is greater than the sequence") } p.createScopeIfNotExist(scopeId) - // If the vertex is already in the scope then fork the scope and put the edge in the new scope - if slices.ContainsFunc(p.ScopeEdges[scopeId], func(edge *Edge) bool { - return edge.A.Idx == e.B.Idx || edge.B.Idx == e.B.Idx - }) { - p.scopeIdSequence++ - parestScopeId := scopeId - scopeId = p.scopeIdSequence - p.createScopeWithParent(parestScopeId, scopeId) + // If the vertex is already in the scope (or its destination component has a cycle) then fork the scope and put the edge in the new scope + if e.to.component.hasCycle() || vertexIsInScope(p.scopeEdges[scopeId], e) { + scopeId = p.createScopeWithParent(scopeId, e) + } else { + isNullable := e.originalEdge.isNullable + if !isNullable { + isNullable = slices.ContainsFunc(p.scopeEdges[scopeId], func(edge *CondensedEdge) bool { + return edge.originalEdge.to.idx == e.from.idx && edge.originalEdge.isNullable + }) + } + p.scopeEdgesNullable[scopeId][e.to.idx] = isNullable + p.scopeEdges[scopeId] = append(p.scopeEdges[scopeId], e) } - - p.ScopeEdges[scopeId] = append(p.ScopeEdges[scopeId], e) - p.Edges = append(p.Edges, e) - p.Vertexes = append(p.Vertexes, e.B.Idx) + p.edges = append(p.edges, e) + p.vertexes = append(p.vertexes, e.to.idx) return scopeId } func (p *Path) createScopeIfNotExist(scopeId int) { - if _, ok := p.ScopeEdges[scopeId]; !ok { - p.ScopeEdges[scopeId] = nil - p.ScopeGraph[scopeId] = nil + if _, ok := p.scopeEdges[scopeId]; !ok { + p.scopeEdges[scopeId] = nil + p.scopeGraph[scopeId] = nil + p.scopeEdgesNullable[scopeId] = make(map[int]bool) } } -func (p *Path) createScopeWithParent(parentScopeId, scopeId int) { - if _, ok := p.ScopeEdges[scopeId]; ok { +func (p *Path) createScopeWithParent(parentScopeId int, e *CondensedEdge) int { + p.scopeIdSequence++ + scopeId := p.scopeIdSequence + + if _, ok := p.scopeEdges[scopeId]; ok { panic("scope already exists") } // Create empty new scope - p.ScopeEdges[scopeId] = nil - p.ScopeGraph[scopeId] = nil + p.scopeEdges[scopeId] = nil + p.scopeGraph[scopeId] = nil + p.scopeEdgesNullable[scopeId] = make(map[int]bool) + p.scopeEdges[scopeId] = append(p.scopeEdges[scopeId], e) + // Add the new scope to the parent scope - p.ScopeGraph[parentScopeId] = append(p.ScopeGraph[parentScopeId], scopeId) + isNullable := e.originalEdge.isNullable + if !isNullable { + isNullable = p.scopeEdgesNullable[parentScopeId][e.from.idx] + } + + scopeEdge := &ScopeEdge{ + scopeId: scopeId, + originalCondensedEdge: e, + isNullable: isNullable, + } + p.scopeGraph[parentScopeId] = append(p.scopeGraph[parentScopeId], scopeEdge) + return scopeId +}
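+// vertexIsInScope reports whether the destination vertex of e already appears in one of the edges of the given scope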
+ +func vertexIsInScope(scopeEdges []*CondensedEdge, e *CondensedEdge) bool { + return slices.ContainsFunc(scopeEdges, func(edge *CondensedEdge) bool { + return edge.from.idx == e.to.idx || edge.to.idx == e.to.idx + }) } diff --git a/internal/db/postgres/subset/query.go b/internal/db/postgres/subset/query.go index a7d67f45..334a951d 100644 --- a/internal/db/postgres/subset/query.go +++ b/internal/db/postgres/subset/query.go @@ -2,129 +2,101 @@ package subset import ( "fmt" - "slices" "strings" "github.com/greenmaskio/greenmask/internal/db/postgres/entries" - "github.com/rs/zerolog/log" + "github.com/greenmaskio/greenmask/pkg/toolkit" ) -func generateAndSetQuery(path *Path, tables []*entries.Table) { - // We start DFS from the root scope - table := tables[path.RootVertex] - if table.Name == "businessentity" { - log.Debug() - } - query := generateQueriesDfs(path, tables, rootScopeId, false) - fmt.Printf("%s.%s\n", table.Schema, table.Name) - fmt.Println(query) - table.Query = query -} +const ( + joinTypeInner = "INNER" + joinTypeLeft = "LEFT" +) -func generateQueriesDfs(path *Path, tables []*entries.Table, scopeId int, isSubQuery bool) string { +func generateJoinClauseForDroppedEdge(edge *Edge, initTableName string) string { + var conds []string - if len(path.ScopeEdges[scopeId]) == 0 && isSubQuery { - return "" - } - currentScopeQuery := generateQuery(tables, path.RootVertex, path.ScopeEdges[scopeId], isSubQuery) - var subQueries []string - for _, nextScopeId := range path.ScopeGraph[scopeId] { - subQuery := generateQueriesDfs(path, tables, nextScopeId, true) - if subQuery != "" { - subQueries = append(subQueries, subQuery) - } + var leftTableKeys []string + keys := edge.from.keys + table := edge.from.table + for _, key := range keys { + leftTableKeys = append(leftTableKeys, fmt.Sprintf(`%s__%s__%s`, table.Schema, table.Name, key)) } - if len(subQueries) == 0 { - return currentScopeQuery - } + rightTable := edge.to + for idx := 0; idx < len(edge.to.keys); idx++ { - totalQuery := fmt.Sprintf( - "%s AND %s", currentScopeQuery, - strings.Join(subQueries, " AND "), - ) - return totalQuery -} + leftPart := fmt.Sprintf( + `"%s"."%s"`, + initTableName, + leftTableKeys[idx], + ) -// TODO: Start always WHERE TRUE AND ... -func generateQuery(tables []*entries.Table, rootTableIdx int, edges []*Edge, isSubQuery bool) string { - - // Use root table as a root table from path - rootTable := tables[rootTableIdx] - var leftTableEdge *Edge - if isSubQuery { - // If it is not a root scope use the right table from the first edge as a root table - // And left table from the first edge as a left table for the subquery. It will be used for where in clause - leftTableEdge = edges[0] - rootTable = tables[edges[0].B.Idx] - edges = edges[1:] - } + rightPart := fmt.Sprintf( + `"%s"."%s"."%s"`, + rightTable.table.Schema, + rightTable.table.Name, + edge.to.keys[idx], + ) - subsetConds := slices.Clone(rootTable.SubsetConds) - selectClause := fmt.Sprintf(`SELECT "%s"."%s".*`, rootTable.Schema, rootTable.Name) - if isSubQuery { - selectClause = generateSelectByPrimaryKey(rootTable) + conds = append(conds, fmt.Sprintf(`%s = %s`, leftPart, rightPart)) } - fromClause := fmt.Sprintf(`FROM "%s"."%s" `, rootTable.Schema, rootTable.Name) - var joinClauses []string - for _, e := range edges { - rightTable := e.B - if len(rightTable.Table.SubsetConds) > 0 { - subsetConds = append(subsetConds, rightTable.Table.SubsetConds...) 
- } - joinClause := generateJoinClause(e) - joinClauses = append(joinClauses, joinClause) - } + rightTableName := fmt.Sprintf(`"%s"."%s"`, edge.to.table.Schema, edge.to.table.Name) - query := fmt.Sprintf( - `%s %s %s %s`, - selectClause, - fromClause, - strings.Join(joinClauses, " "), - generateWhereClause(subsetConds), + joinClause := fmt.Sprintf( + `JOIN %s ON %s`, + rightTableName, + strings.Join(conds, " AND "), ) + return joinClause +} - if isSubQuery { - if leftTableEdge == nil { - panic("leftTableEdge is nil") - } - var leftTableConds []string - for _, k := range leftTableEdge.A.Keys { - leftTableConds = append(leftTableConds, fmt.Sprintf(`"%s"."%s"."%s"`, leftTableEdge.A.Table.Schema, leftTableEdge.A.Table.Name, k)) - } - query = fmt.Sprintf("((%s) IN (%s))", strings.Join(leftTableConds, ", "), query) +func generateJoinClauseV2(edge *Edge, joinType string, overriddenTables map[toolkit.Oid]string) string { + if joinType != joinTypeInner && joinType != joinTypeLeft { + panic(fmt.Sprintf("invalid join type: %s", joinType)) } - return query -} - -func generateJoinClause(edge *Edge) string { var conds []string - leftTable, rightTable := edge.A, edge.B - for idx := 0; idx < len(leftTable.Keys); idx++ { + + leftTable, rightTable := edge.from.table, edge.to.table + for idx := 0; idx < len(edge.from.keys); idx++ { leftPart := fmt.Sprintf( `"%s"."%s"."%s"`, leftTable.Table.Schema, leftTable.Table.Name, - leftTable.Keys[idx], + edge.from.keys[idx], ) rightPart := fmt.Sprintf( `"%s"."%s"."%s"`, rightTable.Table.Schema, rightTable.Table.Name, - rightTable.Keys[idx], + edge.to.keys[idx], ) + if override, ok := overriddenTables[rightTable.Table.Oid]; ok { + rightPart = fmt.Sprintf( + `"%s"."%s"`, + override, + edge.to.keys[idx], + ) + } conds = append(conds, fmt.Sprintf(`%s = %s`, leftPart, rightPart)) + if len(edge.to.table.SubsetConds) > 0 { + conds = append(conds, edge.to.table.SubsetConds...) 
+ } } rightTableName := fmt.Sprintf(`"%s"."%s"`, rightTable.Table.Schema, rightTable.Table.Name) + if override, ok := overriddenTables[rightTable.Table.Oid]; ok { + rightTableName = fmt.Sprintf(`"%s"`, override) + } joinClause := fmt.Sprintf( - `JOIN %s ON %s`, + `%s JOIN %s ON %s`, + joinType, rightTableName, strings.Join(conds, " AND "), ) @@ -133,7 +105,7 @@ func generateJoinClause(edge *Edge) string { func generateWhereClause(subsetConds []string) string { if len(subsetConds) == 0 { - return "" + return "WHERE TRUE" } escapedConds := make([]string, 0, len(subsetConds)) for _, cond := range subsetConds { @@ -142,9 +114,9 @@ func generateWhereClause(subsetConds []string) string { return "WHERE " + strings.Join(escapedConds, " AND ") } -func generateSelectByPrimaryKey(table *entries.Table) string { +func generateSelectByPrimaryKey(table *entries.Table, pk []string) string { var keys []string - for _, key := range table.PrimaryKey { + for _, key := range pk { keys = append(keys, fmt.Sprintf(`"%s"."%s"."%s"`, table.Schema, table.Name, key)) } return fmt.Sprintf( @@ -152,24 +124,3 @@ func generateSelectByPrimaryKey(table *entries.Table) string { strings.Join(keys, ", "), ) } - -func generateSelectDistinctByPrimaryKey(table *entries.Table) string { - var keys []string - for _, key := range table.PrimaryKey { - keys = append(keys, fmt.Sprintf(`"%s"."%s"."%s"`, table.Schema, table.Name, key)) - } - return fmt.Sprintf( - `SELECT DISTINCT ON (%s) "%s"."%s".*`, - strings.Join(keys, ", "), - table.Schema, - table.Name, - ) -} - -func generateSelectDistinctWithCast(table *entries.Table) string { - var columns []string - for _, c := range table.Columns { - columns = append(columns, fmt.Sprintf(`CAST("%s"."%s"."%s" AS text)`, table.Schema, table.Name, c.Name)) - } - return fmt.Sprintf(`SELECT DISTINCT %s`, strings.Join(columns, ", ")) -} diff --git a/internal/db/postgres/subset/scope_edge.go b/internal/db/postgres/subset/scope_edge.go new file mode 100644 index 00000000..7162cbf2 --- /dev/null +++ b/internal/db/postgres/subset/scope_edge.go @@ -0,0 +1,7 @@ +package subset + +type ScopeEdge struct { + scopeId int + originalCondensedEdge *CondensedEdge + isNullable bool +} diff --git a/internal/db/postgres/subset/set_queries.go b/internal/db/postgres/subset/set_queries.go index 9e7f81aa..4812da6d 100644 --- a/internal/db/postgres/subset/set_queries.go +++ b/internal/db/postgres/subset/set_queries.go @@ -4,8 +4,9 @@ import ( "context" "fmt" - "github.com/greenmaskio/greenmask/internal/db/postgres/entries" "github.com/jackc/pgx/v5" + + "github.com/greenmaskio/greenmask/internal/db/postgres/entries" ) func SetSubsetQueries(ctx context.Context, tx pgx.Tx, tables []*entries.Table) error { @@ -13,9 +14,14 @@ func SetSubsetQueries(ctx context.Context, tx pgx.Tx, tables []*entries.Table) e if err != nil { return fmt.Errorf("error creating graph: %w", err) } + graph.buildCondensedGraph() graph.findSubsetVertexes() - for _, p := range graph.Paths { - generateAndSetQuery(p, tables) + for _, p := range graph.paths { + if isPathForScc(p, graph) { + graph.generateAndSetQueryForScc(p) + } else { + graph.generateAndSetQueryForTable(p) + } } return nil } diff --git a/internal/db/postgres/subset/table_link.go b/internal/db/postgres/subset/table_link.go index a1ee8e9c..b96dc276 100644 --- a/internal/db/postgres/subset/table_link.go +++ b/internal/db/postgres/subset/table_link.go @@ -3,15 +3,15 @@ package subset import "github.com/greenmaskio/greenmask/internal/db/postgres/entries" type TableLink struct { - Idx int - Table 
*entries.Table - Keys []string + idx int + table *entries.Table + keys []string } func NewTableLink(idx int, t *entries.Table, keys []string) *TableLink { return &TableLink{ - Idx: idx, - Table: t, - Keys: keys, + idx: idx, + table: t, + keys: keys, } } diff --git a/internal/db/postgres/subset/topologocal_sort.go b/internal/db/postgres/subset/topologocal_sort.go deleted file mode 100644 index 00ecd25c..00000000 --- a/internal/db/postgres/subset/topologocal_sort.go +++ /dev/null @@ -1,30 +0,0 @@ -package subset - -import "slices" - -// TopologicalSort returns the topological sort of the graph provided -func TopologicalSort(graph [][]*Edge, path []int) []int { - visited := make([]int, len(graph)) - order := make([]int, 0, len(graph)) - var component int - for _, v := range path { - // - if visited[v] == 0 { - component++ - order = topologicalSortDfs(graph, v, visited, order, component, path) - } - } - //slices.Reverse(order) - return order -} - -// topologicalSortDfs - the basic DFS algorithm adapted to find the topological sort -func topologicalSortDfs(graph [][]*Edge, v int, visited []int, order []int, component int, path []int) []int { - visited[v] = component - for _, to := range graph[v] { - if visited[to.Idx] == 0 && slices.Contains(path, to.Idx) { - order = topologicalSortDfs(graph, to.Idx, visited, order, component, path) - } - } - return append(order, v) -} diff --git a/pkg/toolkit/table.go b/pkg/toolkit/table.go index 964845a4..b10663ec 100644 --- a/pkg/toolkit/table.go +++ b/pkg/toolkit/table.go @@ -22,6 +22,7 @@ type Reference struct { Name string // ReferencedKeys - list of foreign keys of current table ReferencedKeys []string + IsNullable bool } type Table struct {
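Note on the condensation step above: findScc and buildCondensedGraph follow the classic Kosaraju scheme, in which one DFS pass records vertices in finish order and a second pass over the reversed graph, taken in reverse finish order, labels the strongly connected components. The sketch below is a minimal, self-contained illustration of that idea; it uses plain []int adjacency lists and a hypothetical sccLabels helper rather than the *Edge and *Component types of this change.

package main

import "fmt"

// sccLabels assigns a component id to every vertex (hypothetical helper,
// mirroring the approach of Graph.findScc with simplified types).
func sccLabels(graph [][]int) []int {
	n := len(graph)

	// Build the reversed graph, as NewGraph does alongside the original one.
	reversed := make([][]int, n)
	for v, tos := range graph {
		for _, to := range tos {
			reversed[to] = append(reversed[to], v)
		}
	}

	// Pass 1: record vertices in DFS finish (post) order.
	visited := make([]bool, n)
	order := make([]int, 0, n)
	var sortDfs func(v int)
	sortDfs = func(v int) {
		visited[v] = true
		for _, to := range graph[v] {
			if !visited[to] {
				sortDfs(to)
			}
		}
		order = append(order, v)
	}
	for v := 0; v < n; v++ {
		if !visited[v] {
			sortDfs(v)
		}
	}

	// Pass 2: walk the reversed graph in reverse finish order;
	// every DFS tree of this pass is one strongly connected component.
	labels := make([]int, n)
	for i := range labels {
		labels[i] = -1
	}
	var markDfs func(v, c int)
	markDfs = func(v, c int) {
		labels[v] = c
		for _, to := range reversed[v] {
			if labels[to] == -1 {
				markDfs(to, c)
			}
		}
	}
	component := 0
	for i := n - 1; i >= 0; i-- {
		if v := order[i]; labels[v] == -1 {
			markDfs(v, component)
			component++
		}
	}
	return labels
}

func main() {
	// Vertices 0 -> 1 -> 2 -> 0 form a FK cycle; vertex 3 hangs off it.
	graph := [][]int{{1}, {2}, {0, 3}, {}}
	fmt.Println(sccLabels(graph)) // [0 0 0 1]
}

Each label then acts as a single vertex of the condensed graph, which is why every foreign-key cycle collapses into one Component and can be handled by the recursive CTE machinery above.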