diff --git a/docs/database_subset.md b/docs/database_subset.md index 4db7d1ec..ec7b8be7 100644 --- a/docs/database_subset.md +++ b/docs/database_subset.md @@ -5,6 +5,12 @@ when you need to dump only a part of the database, such as a specific table or a ensures data consistency by including all related data from other tables that are required to maintain the integrity of the subset. +!!! info + + Greenmask generates queries for subset conditions based on the introspected schema using joins and recursive queries. + It cannot be responsible for query optimization. The subset queries might be slow due to the complexity of + the queries and/or lack of indexes. Resolving circular dependencies requires executing recursive queries. + ## Detail The subset is a list of SQL conditions that are applied to table. The conditions are combined with `AND` operator. **You @@ -29,6 +35,16 @@ Greenmask will automatically generate the appropriate queries for the table subs system ensures data consistency by validating all records found through the recursive queries. If a record does not meet the subset condition, it will be excluded along with its parent records, preventing constraint violations. +!!! warning + + Currently, Greenmask can resolve multiple cycles in one strongly connected component, but only for one group of vertexes. For + instance, if you have an SCC that contains 2 groups of vertexes, Greenmask will not be able to resolve it. For instance, + if we have 2 cycles with tables `A, B, C` (first group) and `D, E, F` (second group), Greenmask will not be able to + resolve it. But if you have only one group of vertexes with one or more cycles in the same group of tables (for instance + `A, B, C`), Greenmask will be able to resolve it. This might be fixed in the future. See the second example below. + +You can read the Wikipedia article about Circular reference [here](https://en.wikipedia.org/wiki/Circular_reference). + ## Example: Dump a subset of the database !!! 
info @@ -61,3 +77,103 @@ transformation: - > person.password.passwordsalt = '329eacbe-c883-4f48-b8b6-17aa4627efff' ``` + +## Example: Dump a subset with circular reference + +```postgresql title="Create tables with multi cycles" +-- Step 1: Create tables without foreign keys +DROP TABLE IF EXISTS employees CASCADE; +CREATE TABLE employees +( + employee_id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + department_id INT -- Will reference departments(department_id) +); + +DROP TABLE IF EXISTS departments CASCADE; +CREATE TABLE departments +( + department_id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + project_id INT -- Will reference projects(project_id) +); + +DROP TABLE IF EXISTS projects CASCADE; +CREATE TABLE projects +( + project_id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + lead_employee_id INT, -- Will reference employees(employee_id) + head_employee_id INT -- Will reference employees(employee_id) +); + +-- Step 2: Alter tables to add foreign key constraints +ALTER TABLE employees + ADD CONSTRAINT fk_department + FOREIGN KEY (department_id) REFERENCES departments (department_id); + +ALTER TABLE departments + ADD CONSTRAINT fk_project + FOREIGN KEY (project_id) REFERENCES projects (project_id); + +ALTER TABLE projects + ADD CONSTRAINT fk_lead_employee + FOREIGN KEY (lead_employee_id) REFERENCES employees (employee_id); + +ALTER TABLE projects + ADD CONSTRAINT fk_lead_employee2 + FOREIGN KEY (head_employee_id) REFERENCES employees (employee_id); + +-- Insert projects +INSERT INTO projects (name, lead_employee_id) +SELECT 'Project ' || i, NULL +FROM generate_series(1, 10) AS s(i); + +-- Insert departments +INSERT INTO departments (name, project_id) +SELECT 'Department ' || i, i +FROM generate_series(1, 10) AS s(i); + +-- Insert employees and assign 10 of them as project leads +INSERT INTO employees (name, department_id) +SELECT 'Employee ' || i, (i / 10) + 1 +FROM generate_series(1, 99) AS s(i); + +-- Assign 10 employees as project leads 
+UPDATE projects +SET lead_employee_id = (SELECT employee_id + FROM employees + WHERE employees.department_id = projects.project_id + LIMIT 1), + head_employee_id = 3 +WHERE project_id <= 10; +``` + +This schema has two cycles: + +* `employees (department_id) -> departments (project_id) -> projects (lead_employee_id) -> employees (employee_id)` +* `employees (department_id) -> departments (project_id) -> projects (head_employee_id) -> employees (employee_id)` + +Greenmask can simply resolve it by generating a recursive query with integrity checks for subset and join conditions. + +The example below will fetch the data for all 3 employees and the related departments and projects. + +```yaml title="Subset configuration example" +transformation: + - schema: "public" + name: "employees" + subset_conds: + - "public.employees.employee_id in (1, 2, 3)" +``` + +But this will return an empty result, because the subset condition is not met for all related tables: the project with +`project_id=1` references the employee with `employee_id=3`, which is invalid for the subset condition. + +```yaml title="Subset configuration example" +transformation: + - schema: "public" + name: "employees" + subset_conds: + - "public.employees.employee_id in (1, 2)" +``` + diff --git a/internal/db/postgres/subset/component.go b/internal/db/postgres/subset/component.go index fa6b38e1..c6817b7f 100644 --- a/internal/db/postgres/subset/component.go +++ b/internal/db/postgres/subset/component.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/greenmaskio/greenmask/internal/db/postgres/entries" + "github.com/greenmaskio/greenmask/pkg/toolkit" ) type Component struct { @@ -20,6 +21,11 @@ type Component struct { cycles [][]*Edge cyclesIdents map[string]struct{} keys []string + // groupedCycles - cycles grouped by the vertexes + groupedCycles map[string][]int + // groupedCyclesGraph - contains the mapping of the vertexes in the component to the edges in the original graph + // for grouped cycles. 
This required to join the separated cycles together + groupedCyclesGraph map[string][]*CycleEdge } func NewComponent(id int, componentGraph map[int][]*Edge, tables map[int]*entries.Table) *Component { @@ -35,6 +41,8 @@ func NewComponent(id int, componentGraph map[int][]*Edge, tables map[int]*entrie } else { c.keys = c.getOneTable().PrimaryKey } + c.groupCycles() + c.buildCyclesGraph() return c } @@ -58,6 +66,19 @@ func (c *Component) getOneTable() *entries.Table { panic("cannot call get one table method for cycled scc") } +func (c *Component) getOneCycleGroup() [][]*Edge { + if len(c.groupedCycles) == 1 { + for _, g := range c.groupedCycles { + var res [][]*Edge + for _, idx := range g { + res = append(res, c.cycles[idx]) + } + return res + } + } + panic("get one group cycle group is not allowed for multy cycles") +} + func (c *Component) hasCycle() bool { return len(c.cycles) > 0 } @@ -108,7 +129,7 @@ func (c *Component) findAllCyclesDfs(v int, visited map[int]bool, recStack map[i break } } - cycleId := getCycleIdent(cycle) + cycleId := getCycleId(cycle) if _, ok := c.cyclesIdents[cycleId]; !ok { res := slices.Clone(cycle) slices.Reverse(res) @@ -122,7 +143,18 @@ func (c *Component) findAllCyclesDfs(v int, visited map[int]bool, recStack map[i recStack[v] = false } -func getCycleIdent(cycle []*Edge) string { +// getCycleGroupId - returns the group id for the cycle based on the vertexes ID +func getCycleGroupId(cycle []*Edge) string { + ids := make([]string, 0, len(cycle)) + for _, edge := range cycle { + ids = append(ids, fmt.Sprintf("%d", edge.to.idx)) + } + slices.Sort(ids) + return strings.Join(ids, "_") +} + +// getCycleId - returns the unique identifier for the cycle based on the edges ID +func getCycleId(cycle []*Edge) string { ids := make([]string, 0, len(cycle)) for _, edge := range cycle { ids = append(ids, fmt.Sprintf("%d", edge.id)) @@ -131,21 +163,91 @@ func getCycleIdent(cycle []*Edge) string { return strings.Join(ids, "_") } -func (c *Component) 
getComponentKeys() []string { - if len(c.cycles) > 1 { - panic("IMPLEMENT ME: multiple cycles in the component") +func (c *Component) groupCycles() { + c.groupedCycles = make(map[string][]int) + for cycleIdx, cycle := range c.cycles { + cycleId := getCycleGroupId(cycle) + c.groupedCycles[cycleId] = append(c.groupedCycles[cycleId], cycleIdx) + } +} + +func (c *Component) buildCyclesGraph() { + // TODO: Need to loop through c.groupedCycles instead of c.cycles + var idSeq int + c.groupedCyclesGraph = make(map[string][]*CycleEdge) + for groupIdI, cyclesI := range c.groupedCycles { + for groupIdJ, cyclesJ := range c.groupedCycles { + if groupIdI == groupIdJ { + continue + } + commonVertexes := c.findCommonVertexes(cyclesI[0], cyclesJ[0]) + if len(commonVertexes) == 0 { + continue + } + if c.areCyclesLinked(cyclesI[0], cyclesJ[0]) { + continue + } + e := NewCycleEdge(idSeq, groupIdI, groupIdJ, commonVertexes) + c.groupedCyclesGraph[groupIdI] = append(c.groupedCyclesGraph[groupIdJ], e) + idSeq++ + } + } +} + +func (c *Component) findCommonVertexes(i, j int) (res []*entries.Table) { + common := make(map[toolkit.Oid]*entries.Table) + for _, edgeI := range c.cycles[i] { + for _, edgeJ := range c.cycles[j] { + if edgeI.to.idx == edgeJ.to.idx { + common[edgeI.to.table.Oid] = edgeI.to.table + } + } + } + for _, table := range common { + res = append(res, table) } + slices.SortFunc(res, func(i, j *entries.Table) int { + switch { + case i.Oid < j.Oid: + return -1 + case i.Oid > j.Oid: + return 1 + } + return 0 + }) + return +} + +func (c *Component) areCyclesLinked(i, j int) bool { + iId := getCycleGroupId(c.cycles[i]) + jId := getCycleGroupId(c.cycles[j]) + for _, to := range c.groupedCyclesGraph[iId] { + if to.to == jId { + return true + } + } + for _, to := range c.groupedCyclesGraph[jId] { + if to.to == iId { + return true + } + } + return false +} + +func (c *Component) getComponentKeys() []string { if !c.hasCycle() { return c.getOneTable().PrimaryKey } - var vertexes []int 
- for _, edge := range c.cycles[0] { - vertexes = append(vertexes, edge.to.idx) + vertexes := make(map[int]struct{}) + for _, cycle := range c.cycles { + for _, edge := range cycle { + vertexes[edge.to.idx] = struct{}{} + } } var keys []string - for _, v := range vertexes { + for v := range vertexes { table := c.tables[v] for _, key := range table.PrimaryKey { keys = append(keys, fmt.Sprintf(`%s__%s__%s`, table.Schema, table.Name, key)) diff --git a/internal/db/postgres/subset/cte.go b/internal/db/postgres/subset/cte.go index fa13f660..fae1e148 100644 --- a/internal/db/postgres/subset/cte.go +++ b/internal/db/postgres/subset/cte.go @@ -29,8 +29,8 @@ func (c *cteQuery) addItem(name, query string) { func (c *cteQuery) generateQuery(targetTable *entries.Table) string { var queries []string var excludedCteQueries []string - if len(c.c.cycles) > 1 { - panic("IMPLEMENT ME") + if len(c.c.groupedCycles) > 1 { + panic("FIXME: found more than one grouped cycle") } for _, edge := range c.c.cycles[0] { if edge.from.table.Oid == targetTable.Oid { diff --git a/internal/db/postgres/subset/cycle_edge.go b/internal/db/postgres/subset/cycle_edge.go new file mode 100644 index 00000000..95c83fbb --- /dev/null +++ b/internal/db/postgres/subset/cycle_edge.go @@ -0,0 +1,22 @@ +package subset + +import "github.com/greenmaskio/greenmask/internal/db/postgres/entries" + +type CycleEdge struct { + id int + from string + to string + tables []*entries.Table +} + +func NewCycleEdge(id int, from, to string, tables []*entries.Table) *CycleEdge { + if len(tables) == 0 { + panic("empty tables provided for cycle edge") + } + return &CycleEdge{ + id: id, + from: from, + to: to, + tables: tables, + } +} diff --git a/internal/db/postgres/subset/graph.go b/internal/db/postgres/subset/graph.go index 8b9faecf..f03e7ade 100644 --- a/internal/db/postgres/subset/graph.go +++ b/internal/db/postgres/subset/graph.go @@ -310,6 +310,8 @@ func (g *Graph) generateAndSetQueryForScc(path *Path) { 
g.generateQueriesSccDfs(cq, path, nil) for _, t := range rootVertex.tables { query := cq.generateQuery(t) + fmt.Printf("********************%s.%s*****************\n", t.Schema, t.Name) + println(query) t.Query = query } } @@ -338,32 +340,51 @@ func (g *Graph) generateQueryForScc(cq *cteQuery, scopeId int, path *Path, prevS edges = edges[1:] rootVertex = prevScopeEdge.originalCondensedEdge.to.component } - if len(rootVertex.cycles) > 1 { - panic("IMPLEMENT ME: more than one cycle found in SCC") + //cycle := orderCycle(rootVertex.cycles[0], edges, path.scopeGraph[scopeId]) + if len(rootVertex.groupedCycles) > 1 { + panic("IMPLEMENT ME: more than one cycle group found in SCC") + } + cycleGroup := rootVertex.getOneCycleGroup() + overlapMap := g.getOverlapMap(cycleGroup) + for _, cycle := range cycleGroup { + g.generateRecursiveQueriesForCycle(cq, scopeId, cycle, edges, nextScopeEdges, overlapMap) } - cycle := orderCycle(rootVertex.cycles[0], edges, path.scopeGraph[scopeId]) - g.generateRecursiveQueriesForCycle(cq, scopeId, cycle, edges, nextScopeEdges) - g.generateQueriesForVertexesInCycle(cq, scopeId, cycle) + g.generateFilteredQueries(cq, cycleGroup, scopeId) + g.generateQueriesForVertexesInCycle(cq, scopeId, cycleGroup) } -func (g *Graph) generateQueriesForVertexesInCycle(cq *cteQuery, scopeId int, cycle []*Edge) { - for _, t := range getTablesFromCycle(cycle) { +func (g *Graph) getOverlapMap(cycles [][]*Edge) map[string][][]*Edge { + cyclesOverlap := make(map[string][][]*Edge, len(cycles)) + for i, currCycle := range cycles { + cycleId := getCycleId(currCycle) + var overlapCycles [][]*Edge + for j, overlapCycle := range cycles { + if i == j { + continue + } + overlapCycles = append(overlapCycles, overlapCycle) + } + cyclesOverlap[cycleId] = overlapCycles + } + return cyclesOverlap +} + +func (g *Graph) generateQueriesForVertexesInCycle(cq *cteQuery, scopeId int, cycles [][]*Edge) { + for _, t := range getTablesFromCycle(cycles[0]) { queryName := 
fmt.Sprintf("%s__%s__ids", t.Schema, t.Name) - query := generateAllTablesValidPkSelection(cycle, scopeId, t) + query := generateAllTablesValidPkSelection(cycles, scopeId, t) cq.addItem(queryName, query) } } func (g *Graph) generateRecursiveQueriesForCycle( cq *cteQuery, scopeId int, cycle []*Edge, rest []*CondensedEdge, nextScopeEdges []*ScopeEdge, + overlapMap map[string][][]*Edge, ) { - var ( - cycleId = getCycleIdent(cycle) - overriddenTableNames = make(map[toolkit.Oid]string) - ) - + overriddenTableNames := make(map[toolkit.Oid]string) rest = slices.Clone(rest) + for _, se := range nextScopeEdges { t := se.originalCondensedEdge.originalEdge.to.table overriddenTableNames[t.Oid] = fmt.Sprintf("%s__%s__ids", t.Schema, t.Name) @@ -373,25 +394,42 @@ func (g *Graph) generateRecursiveQueriesForCycle( //var unionQueries []string shiftedCycle := slices.Clone(cycle) for idx := 1; idx <= len(cycle); idx++ { - var ( - mainTable = shiftedCycle[0].from.table - // queryName - name of a query in the recursive CTE - // where: - // * s - scope id - // * c - cycle id - // * pt1 - part 1 of the recursive query - queryName = fmt.Sprintf("__s%d__c%s__%s__%s", scopeId, cycleId, mainTable.Schema, mainTable.Name) - filteredQueryName = fmt.Sprintf("%s__filtered", queryName) - ) - + queryName := getCycleQueryName(scopeId, shiftedCycle, "") query := generateQuery(queryName, shiftedCycle, rest, overriddenTableNames) cq.addItem(queryName, query) - filteredQuery := generateIntegrityCheckJoinConds(shiftedCycle, mainTable, queryName) - cq.addItem(filteredQueryName, filteredQuery) + cycleId := getCycleId(shiftedCycle) + if len(overlapMap[cycleId]) > 0 { + overlapQueryName := getCycleQueryName(scopeId, shiftedCycle, "overlap") + overlapQuery := generateOverlapQuery(scopeId, overlapQueryName, shiftedCycle, rest, overriddenTableNames, overlapMap[cycleId]) + cq.addItem(overlapQueryName, overlapQuery) + } shiftedCycle = shiftCycle(shiftedCycle) } } +func (g *Graph) generateFilteredQueries(cq 
*cteQuery, groupedCycles [][]*Edge, scopeId int) { + + // Clone cycles group + cycles := make([][]*Edge, 0, len(groupedCycles)) + for _, cycle := range groupedCycles { + cycles = append(cycles, slices.Clone(cycle)) + } + for idx := 1; idx <= len(cycles[0]); idx++ { + groupQueryNamePrefix := getCyclesGroupQueryName(scopeId, cycles[0]) + filteredQueryName := fmt.Sprintf("%s__filtered", groupQueryNamePrefix) + if len(cycles) > 1 { + unitedQuery := generateUnitedCyclesQuery(scopeId, cycles) + groupQueryName := getCyclesGroupQueryName(scopeId, cycles[0]) + unitedQueryName := fmt.Sprintf("%s__united", groupQueryName) + cq.addItem(unitedQueryName, unitedQuery) + } + filteredQuery := generateIntegrityCheckJoinConds(scopeId, cycles) + cq.addItem(filteredQueryName, filteredQuery) + shiftCycleGroup(cycles) + } + +} + func (g *Graph) generateQueriesDfs(path *Path, scopeEdge *ScopeEdge) string { // TODO: // 1. Add scopeEdges support and LEFT JOIN @@ -566,53 +604,146 @@ func isPathForScc(path *Path, graph *Graph) bool { return graph.scc[path.rootVertex].hasCycle() } -func orderCycle(cycle []*Edge, subsetJoins []*CondensedEdge, scopeEdges []*ScopeEdge) []*Edge { +func generateQuery( + queryName string, cycle []*Edge, rest []*CondensedEdge, overriddenTables map[toolkit.Oid]string, +) string { var ( - vertexes []int - valuableEdgesIdx int + selectKeys []string + initialJoins, recursiveJoins []string + initialWhereConds, recursiveWhereConds []string + integrityCheck string + cycleSubsetConds []string + edges = slices.Clone(cycle[:len(cycle)-1]) + droppedEdge = cycle[len(cycle)-1] ) + for _, ce := range rest { + edges = append(edges, ce.originalEdge) + } - for _, e := range cycle { - vertexes = append(vertexes, e.from.idx) + for _, t := range getTablesFromCycle(cycle) { + var keysWithAliases []string + for _, k := range t.PrimaryKey { + keysWithAliases = append(keysWithAliases, fmt.Sprintf(`"%s"."%s"."%s" as "%s__%s__%s"`, t.Schema, t.Name, k, t.Schema, t.Name, k)) + } + selectKeys = 
append(selectKeys, keysWithAliases...) + if len(t.SubsetConds) > 0 { + cycleSubsetConds = append(cycleSubsetConds, t.SubsetConds...) + } } - for _, sj := range subsetJoins { - if slices.Contains(vertexes, sj.originalEdge.from.idx) { - valuableEdgesIdx = slices.IndexFunc(cycle, func(e *Edge) bool { - return sj.originalEdge.from.idx == e.from.idx - }) - if !sj.originalEdge.isNullable { - break - } + var droppedKeysWithAliases []string + for _, k := range droppedEdge.from.keys { + t := droppedEdge.from.table + droppedKeysWithAliases = append(droppedKeysWithAliases, fmt.Sprintf(`"%s"."%s"."%s" as "%s__%s__%s"`, t.Schema, t.Name, k, t.Schema, t.Name, k)) + } + selectKeys = append(selectKeys, droppedKeysWithAliases...) + + var initialPathSelectionKeys []string + for _, k := range cycle[0].from.table.PrimaryKey { + t := cycle[0].from.table + pathName := fmt.Sprintf( + `ARRAY["%s"."%s"."%s"] AS %s__%s__%s__path`, + t.Schema, t.Name, k, + t.Schema, t.Name, k, + ) + initialPathSelectionKeys = append(initialPathSelectionKeys, pathName) + } + + initialKeys := slices.Clone(selectKeys) + initialKeys = append(initialKeys, initialPathSelectionKeys...) + initFromClause := fmt.Sprintf(`FROM "%s"."%s" `, cycle[0].from.table.Schema, cycle[0].from.table.Name) + integrityCheck = "TRUE AS valid" + initialKeys = append(initialKeys, integrityCheck) + initialWhereConds = append(initialWhereConds, cycleSubsetConds...) 
+ + initialSelect := fmt.Sprintf("SELECT %s", strings.Join(initialKeys, ", ")) + nullabilityMap := make(map[int]bool) + for _, e := range edges { + isNullable := e.isNullable + if !isNullable { + isNullable = nullabilityMap[e.from.idx] + } + nullabilityMap[e.to.idx] = isNullable + joinType := joinTypeInner + if isNullable { + joinType = joinTypeLeft } + initialJoins = append(initialJoins, generateJoinClauseV2(e, joinType, overriddenTables)) } - for _, se := range scopeEdges { - if slices.Contains(vertexes, se.originalCondensedEdge.from.idx) { - valuableEdgesIdx = slices.IndexFunc(cycle, func(e *Edge) bool { - return se.originalCondensedEdge.originalEdge.from.idx == e.from.idx - }) - if !se.originalCondensedEdge.originalEdge.isNullable { - break - } + integrityChecks := generateIntegrityChecksForNullableEdges(nullabilityMap, edges, overriddenTables) + initialWhereConds = append(initialWhereConds, integrityChecks...) + initialWhereClause := generateWhereClause(initialWhereConds) + initialQuery := fmt.Sprintf(`%s %s %s %s`, + initialSelect, initFromClause, strings.Join(initialJoins, " "), initialWhereClause, + ) + + recursiveIntegrityChecks := slices.Clone(cycleSubsetConds) + recursiveIntegrityChecks = append(recursiveIntegrityChecks, integrityChecks...) 
+ recursiveIntegrityCheck := fmt.Sprintf("(%s) AS valid", strings.Join(recursiveIntegrityChecks, " AND ")) + recursiveKeys := slices.Clone(selectKeys) + for _, k := range cycle[0].from.table.PrimaryKey { + t := cycle[0].from.table + //recursivePathSelectionKeys = append(recursivePathSelectionKeys, fmt.Sprintf(`coalesce("%s"."%s"."%s"::TEXT, 'NULL')`, t.Schema, t.Name, k)) + + pathName := fmt.Sprintf( + `"%s__%s__%s__path" || ARRAY["%s"."%s"."%s"]`, + t.Schema, t.Name, k, + t.Schema, t.Name, k, + ) + recursiveKeys = append(recursiveKeys, pathName) + } + recursiveKeys = append(recursiveKeys, recursiveIntegrityCheck) + + recursiveSelect := fmt.Sprintf("SELECT %s", strings.Join(recursiveKeys, ", ")) + recursiveFromClause := fmt.Sprintf(`FROM "%s" `, queryName) + recursiveJoins = append(recursiveJoins, generateJoinClauseForDroppedEdge(droppedEdge, queryName)) + nullabilityMap = make(map[int]bool) + for _, e := range edges { + isNullable := e.isNullable + if !isNullable { + isNullable = nullabilityMap[e.from.idx] + } + nullabilityMap[e.to.idx] = isNullable + joinType := joinTypeInner + if isNullable { + joinType = joinTypeLeft } + recursiveJoins = append(recursiveJoins, generateJoinClauseV2(e, joinType, overriddenTables)) } - if valuableEdgesIdx == -1 { - panic("is not found") + recursiveValidCond := fmt.Sprintf(`"%s"."%s"`, queryName, "valid") + recursiveWhereConds = append(recursiveWhereConds, recursiveValidCond) + for _, k := range cycle[0].from.table.PrimaryKey { + t := cycle[0].from.table + + recursivePathCheck := fmt.Sprintf( + `NOT "%s"."%s"."%s" = ANY("%s"."%s__%s__%s__%s")`, + t.Schema, t.Name, k, + queryName, t.Schema, t.Name, k, "path", + ) + + recursiveWhereConds = append(recursiveWhereConds, recursivePathCheck) } + recursiveWhereClause := generateWhereClause(recursiveWhereConds) + + recursiveQuery := fmt.Sprintf(`%s %s %s %s`, + recursiveSelect, recursiveFromClause, strings.Join(recursiveJoins, " "), recursiveWhereClause, + ) - resCycle := 
slices.Clone(cycle[valuableEdgesIdx:]) - resCycle = append(resCycle, cycle[:valuableEdgesIdx]...) - return resCycle + query := fmt.Sprintf("( %s ) UNION ( %s )", initialQuery, recursiveQuery) + return query } -func generateQuery(queryName string, cycle []*Edge, rest []*CondensedEdge, overriddenTables map[toolkit.Oid]string) string { +func generateOverlapQuery( + scopeId int, + queryName string, cycle []*Edge, rest []*CondensedEdge, overriddenTables map[toolkit.Oid]string, + overlap [][]*Edge, +) string { var ( selectKeys []string initialJoins, recursiveJoins []string initialWhereConds, recursiveWhereConds []string - integrityCheck string cycleSubsetConds []string edges = slices.Clone(cycle[:len(cycle)-1]) droppedEdge = cycle[len(cycle)-1] @@ -653,11 +784,8 @@ func generateQuery(queryName string, cycle []*Edge, rest []*CondensedEdge, overr initialKeys := slices.Clone(selectKeys) initialKeys = append(initialKeys, initialPathSelectionKeys...) initFromClause := fmt.Sprintf(`FROM "%s"."%s" `, edges[0].from.table.Schema, edges[0].from.table.Name) - integrityCheck = "TRUE AS valid" - initialKeys = append(initialKeys, integrityCheck) - initialWhereConds = append(initialWhereConds, cycleSubsetConds...) + initialWhereConds = append(initialWhereConds, generateInClauseForOverlap(scopeId, edges, overlap)) - initialSelect := fmt.Sprintf("SELECT %s", strings.Join(initialKeys, ", ")) nullabilityMap := make(map[int]bool) for _, e := range edges { isNullable := e.isNullable @@ -673,15 +801,18 @@ func generateQuery(queryName string, cycle []*Edge, rest []*CondensedEdge, overr } integrityChecks := generateIntegrityChecksForNullableEdges(nullabilityMap, edges, overriddenTables) - initialWhereConds = append(initialWhereConds, integrityChecks...) + integrityChecks = append(integrityChecks, cycleSubsetConds...) 
+ initialIntegrityCheck := fmt.Sprintf("(%s) AS valid", strings.Join(integrityChecks, " AND ")) + initialKeys = append(initialKeys, initialIntegrityCheck) + initialSelect := fmt.Sprintf("SELECT %s", strings.Join(initialKeys, ", ")) + + //initialWhereConds = append(initialWhereConds, integrityChecks...) initialWhereClause := generateWhereClause(initialWhereConds) initialQuery := fmt.Sprintf(`%s %s %s %s`, initialSelect, initFromClause, strings.Join(initialJoins, " "), initialWhereClause, ) - recursiveIntegrityChecks := slices.Clone(cycleSubsetConds) - recursiveIntegrityChecks = append(recursiveIntegrityChecks, integrityChecks...) - recursiveIntegrityCheck := fmt.Sprintf("(%s) AS valid", strings.Join(recursiveIntegrityChecks, " AND ")) + recursiveIntegrityCheck := fmt.Sprintf("(%s) AS valid", strings.Join(integrityChecks, " AND ")) recursiveKeys := slices.Clone(selectKeys) for _, k := range edges[0].from.table.PrimaryKey { t := edges[0].from.table @@ -736,6 +867,36 @@ func generateQuery(queryName string, cycle []*Edge, rest []*CondensedEdge, overr return query } +func generateInClauseForOverlap(scopeId int, edges []*Edge, overlap [][]*Edge) string { + var ( + overlapTables []string + unionQueryParts []string + rightTableKeys, leftTableKeys []string + ) + + var shiftedOverlaps [][]*Edge + for _, oc := range overlap { + shiftedOverlaps = append(shiftedOverlaps, shiftUntilVertexWillBeFirst(edges[0], oc)) + } + + for _, c := range shiftedOverlaps { + overlapTables = append(overlapTables, getCycleQueryName(scopeId, c, "")) + } + for _, k := range edges[0].from.table.PrimaryKey { + rightTableKey := fmt.Sprintf(`"%s__%s__%s"`, edges[0].from.table.Schema, edges[0].from.table.Name, k) + rightTableKeys = append(rightTableKeys, rightTableKey) + leftTableKey := fmt.Sprintf(`"%s"."%s"."%s"`, edges[0].from.table.Schema, edges[0].from.table.Name, k) + leftTableKeys = append(leftTableKeys, leftTableKey) + } + for _, t := range overlapTables { + unionQueryParts = 
append(unionQueryParts, fmt.Sprintf(`SELECT %s FROM "%s"`, strings.Join(rightTableKeys, ", "), t)) + } + unionQuery := strings.Join(unionQueryParts, " UNION ") + + res := fmt.Sprintf(`(%s) IN (%s)`, strings.Join(leftTableKeys, ", "), unionQuery) + return res +} + func getTablesFromCycle(cycle []*Edge) (res []*entries.Table) { for _, e := range cycle { res = append(res, e.to.table) @@ -783,15 +944,34 @@ func generateIntegrityChecksForNullableEdges(nullabilityMap map[int]bool, edges return } -func generateIntegrityCheckJoinConds(cycle []*Edge, table *entries.Table, tableName string) string { +func generateUnitedCyclesQuery(scopeId int, cycles [][]*Edge) string { + var tablesSelection []string + for _, c := range cycles { + q1 := fmt.Sprintf(`SELECT * FROM "%s"`, getCycleQueryName(scopeId, c, "")) + tablesSelection = append(tablesSelection, q1) + q2 := fmt.Sprintf(`SELECT * FROM "%s"`, getCycleQueryName(scopeId, c, "overlap")) + tablesSelection = append(tablesSelection, q2) + } + res := strings.Join(tablesSelection, " UNION ") + return res +} + +func generateIntegrityCheckJoinConds(scopeId int, cycles [][]*Edge) string { var ( + table = cycles[0][0].from.table allPks []string mainTablePks []string unnestSelections []string + tableName = getCycleQueryName(scopeId, cycles[0], "") ) - for _, t := range getTablesFromCycle(cycle) { + if len(cycles) > 1 { + prefix := getCyclesGroupQueryName(scopeId, cycles[0]) + tableName = fmt.Sprintf("%s__united", prefix) + } + + for _, t := range getTablesFromCycle(cycles[0]) { for _, k := range t.PrimaryKey { key := fmt.Sprintf(`"%s"."%s__%s__%s"`, tableName, t.Schema, t.Name, k) allPks = append(allPks, key) @@ -822,15 +1002,15 @@ func generateIntegrityCheckJoinConds(cycle []*Edge, table *entries.Table, tableN return filteredQuery } -func generateAllTablesValidPkSelection(cycle []*Edge, scopeId int, forTable *entries.Table) string { +func generateAllTablesValidPkSelection(cycles [][]*Edge, scopeId int, forTable *entries.Table) string { 
var unionParts []string - for _, t := range getTablesFromCycle(cycle) { + for _, t := range getTablesFromCycle(cycles[0]) { var ( selectionKeys []string - cycleId = getCycleIdent(cycle) - filteredQueryName = fmt.Sprintf("__s%d__c%s__%s__%s__filtered", scopeId, cycleId, t.Schema, t.Name) + groupId = getCycleGroupId(cycles[0]) + filteredQueryName = fmt.Sprintf("__s%d__g%s__%s__%s__filtered", scopeId, groupId, t.Schema, t.Name) ) for _, k := range forTable.PrimaryKey { @@ -854,3 +1034,54 @@ func shiftCycle(cycle []*Edge) (res []*Edge) { res = append(res, cycle[:len(cycle)-1]...) return } + +func getCycleQueryName(scopeId int, cycle []*Edge, postfix string) string { + // queryName - name of a query in the recursive CTE + // where: + // * s - scope id + // * g - group id + // * c - cycle id + // * postfix with table name + mainTable := cycle[0].from.table + groupId := getCycleGroupId(cycle) + cycleId := getCycleId(cycle) + res := fmt.Sprintf("__s%d__g%s__c%s__%s__%s", scopeId, groupId, cycleId, mainTable.Schema, mainTable.Name) + if postfix != "" { + res = fmt.Sprintf("%s__%s", res, postfix) + } + return res +} + +func getCyclesGroupQueryName(scopeId int, cycle []*Edge) string { + // queryName - name of a query in the recursive CTE + // where: + // * s - scope id + // * g - group id + // * postfix with table name + mainTable := cycle[0].from.table + groupId := getCycleGroupId(cycle) + return getCyclesGroupQueryNameByMainTable(scopeId, groupId, mainTable) +} + +func getCyclesGroupQueryNameByMainTable(scopeId int, groupId string, mainTable *entries.Table) string { + return fmt.Sprintf("__s%d__g%s__%s__%s", scopeId, groupId, mainTable.Schema, mainTable.Name) +} + +func shiftCycleGroup(g [][]*Edge) [][]*Edge { + for idx := range g { + g[idx] = shiftCycle(g[idx]) + } + return g +} + +func shiftUntilVertexWillBeFirst(v *Edge, c []*Edge) []*Edge { + //generateInClauseForOverlap + res := slices.Clone(c) + for { + if res[0].from.idx == v.from.idx { + break + } + res = 
shiftCycle(res) + } + return res +}