diff --git a/internal/db/postgres/subset/component.go b/internal/db/postgres/subset/component.go index fa6b38e1..c6817b7f 100644 --- a/internal/db/postgres/subset/component.go +++ b/internal/db/postgres/subset/component.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/greenmaskio/greenmask/internal/db/postgres/entries" + "github.com/greenmaskio/greenmask/pkg/toolkit" ) type Component struct { @@ -20,6 +21,11 @@ type Component struct { cycles [][]*Edge cyclesIdents map[string]struct{} keys []string + // groupedCycles - cycles grouped by the vertexes + groupedCycles map[string][]int + // groupedCyclesGraph - contains the mapping of the vertexes in the component to the edges in the original graph + // for grouped cycles. This required to join the separated cycles together + groupedCyclesGraph map[string][]*CycleEdge } func NewComponent(id int, componentGraph map[int][]*Edge, tables map[int]*entries.Table) *Component { @@ -35,6 +41,8 @@ func NewComponent(id int, componentGraph map[int][]*Edge, tables map[int]*entrie } else { c.keys = c.getOneTable().PrimaryKey } + c.groupCycles() + c.buildCyclesGraph() return c } @@ -58,6 +66,19 @@ func (c *Component) getOneTable() *entries.Table { panic("cannot call get one table method for cycled scc") } +func (c *Component) getOneCycleGroup() [][]*Edge { + if len(c.groupedCycles) == 1 { + for _, g := range c.groupedCycles { + var res [][]*Edge + for _, idx := range g { + res = append(res, c.cycles[idx]) + } + return res + } + } + panic("get one group cycle group is not allowed for multy cycles") +} + func (c *Component) hasCycle() bool { return len(c.cycles) > 0 } @@ -108,7 +129,7 @@ func (c *Component) findAllCyclesDfs(v int, visited map[int]bool, recStack map[i break } } - cycleId := getCycleIdent(cycle) + cycleId := getCycleId(cycle) if _, ok := c.cyclesIdents[cycleId]; !ok { res := slices.Clone(cycle) slices.Reverse(res) @@ -122,7 +143,18 @@ func (c *Component) findAllCyclesDfs(v int, visited map[int]bool, recStack map[i recStack[v] = false } -func getCycleIdent(cycle []*Edge) string { +// getCycleGroupId - returns the group id for the cycle based on the vertexes ID +func getCycleGroupId(cycle []*Edge) string { + ids := make([]string, 0, len(cycle)) + for _, edge := range cycle { + ids = append(ids, fmt.Sprintf("%d", edge.to.idx)) + } + slices.Sort(ids) + return strings.Join(ids, "_") +} + +// getCycleId - returns the unique identifier for the cycle based on the edges ID +func getCycleId(cycle []*Edge) string { ids := make([]string, 0, len(cycle)) for _, edge := range cycle { ids = append(ids, fmt.Sprintf("%d", edge.id)) @@ -131,21 +163,91 @@ func getCycleIdent(cycle []*Edge) string { return strings.Join(ids, "_") } -func (c *Component) getComponentKeys() []string { - if len(c.cycles) > 1 { - panic("IMPLEMENT ME: multiple cycles in the component") +func (c *Component) groupCycles() { + c.groupedCycles = make(map[string][]int) + for cycleIdx, cycle := range c.cycles { + cycleId := getCycleGroupId(cycle) + c.groupedCycles[cycleId] = append(c.groupedCycles[cycleId], cycleIdx) + } +} + +func (c *Component) buildCyclesGraph() { + // TODO: Need to loop through c.groupedCycles instead of c.cycles + var idSeq int + c.groupedCyclesGraph = make(map[string][]*CycleEdge) + for groupIdI, cyclesI := range c.groupedCycles { + for groupIdJ, cyclesJ := range c.groupedCycles { + if groupIdI == groupIdJ { + continue + } + commonVertexes := c.findCommonVertexes(cyclesI[0], cyclesJ[0]) + if len(commonVertexes) == 0 { + continue + } + if c.areCyclesLinked(cyclesI[0], cyclesJ[0]) { + continue + } + e := NewCycleEdge(idSeq, groupIdI, groupIdJ, commonVertexes) + c.groupedCyclesGraph[groupIdI] = append(c.groupedCyclesGraph[groupIdJ], e) + idSeq++ + } + } +} + +func (c *Component) findCommonVertexes(i, j int) (res []*entries.Table) { + common := make(map[toolkit.Oid]*entries.Table) + for _, edgeI := range c.cycles[i] { + for _, edgeJ := range c.cycles[j] { + if edgeI.to.idx == edgeJ.to.idx { + common[edgeI.to.table.Oid] = edgeI.to.table + } + } + } + for _, table := range common { + res = append(res, table) } + slices.SortFunc(res, func(i, j *entries.Table) int { + switch { + case i.Oid < j.Oid: + return -1 + case i.Oid > j.Oid: + return 1 + } + return 0 + }) + return +} + +func (c *Component) areCyclesLinked(i, j int) bool { + iId := getCycleGroupId(c.cycles[i]) + jId := getCycleGroupId(c.cycles[j]) + for _, to := range c.groupedCyclesGraph[iId] { + if to.to == jId { + return true + } + } + for _, to := range c.groupedCyclesGraph[jId] { + if to.to == iId { + return true + } + } + return false +} + +func (c *Component) getComponentKeys() []string { if !c.hasCycle() { return c.getOneTable().PrimaryKey } - var vertexes []int - for _, edge := range c.cycles[0] { - vertexes = append(vertexes, edge.to.idx) + vertexes := make(map[int]struct{}) + for _, cycle := range c.cycles { + for _, edge := range cycle { + vertexes[edge.to.idx] = struct{}{} + } } var keys []string - for _, v := range vertexes { + for v := range vertexes { table := c.tables[v] for _, key := range table.PrimaryKey { keys = append(keys, fmt.Sprintf(`%s__%s__%s`, table.Schema, table.Name, key)) diff --git a/internal/db/postgres/subset/cte.go b/internal/db/postgres/subset/cte.go index fa13f660..fae1e148 100644 --- a/internal/db/postgres/subset/cte.go +++ b/internal/db/postgres/subset/cte.go @@ -29,8 +29,8 @@ func (c *cteQuery) addItem(name, query string) { func (c *cteQuery) generateQuery(targetTable *entries.Table) string { var queries []string var excludedCteQueries []string - if len(c.c.cycles) > 1 { - panic("IMPLEMENT ME") + if len(c.c.groupedCycles) > 1 { + panic("FIXME: found more than one grouped cycle") } for _, edge := range c.c.cycles[0] { if edge.from.table.Oid == targetTable.Oid { diff --git a/internal/db/postgres/subset/cycle_edge.go b/internal/db/postgres/subset/cycle_edge.go new file mode 100644 index 00000000..95c83fbb --- /dev/null +++ b/internal/db/postgres/subset/cycle_edge.go @@ -0,0 +1,22 @@ +package subset + +import "github.com/greenmaskio/greenmask/internal/db/postgres/entries" + +type CycleEdge struct { + id int + from string + to string + tables []*entries.Table +} + +func NewCycleEdge(id int, from, to string, tables []*entries.Table) *CycleEdge { + if len(tables) == 0 { + panic("empty tables provided for cycle edge") + } + return &CycleEdge{ + id: id, + from: from, + to: to, + tables: tables, + } +} diff --git a/internal/db/postgres/subset/graph.go b/internal/db/postgres/subset/graph.go index 8b9faecf..468e511f 100644 --- a/internal/db/postgres/subset/graph.go +++ b/internal/db/postgres/subset/graph.go @@ -310,6 +310,8 @@ func (g *Graph) generateAndSetQueryForScc(path *Path) { g.generateQueriesSccDfs(cq, path, nil) for _, t := range rootVertex.tables { query := cq.generateQuery(t) + fmt.Printf("********************%s.%s*****************\n", t.Schema, t.Name) + println(query) t.Query = query } } @@ -338,32 +340,51 @@ func (g *Graph) generateQueryForScc(cq *cteQuery, scopeId int, path *Path, prevS edges = edges[1:] rootVertex = prevScopeEdge.originalCondensedEdge.to.component } - if len(rootVertex.cycles) > 1 { - panic("IMPLEMENT ME: more than one cycle found in SCC") + //cycle := orderCycle(rootVertex.cycles[0], edges, path.scopeGraph[scopeId]) + if len(rootVertex.groupedCycles) > 1 { + panic("IMPLEMENT ME: more than one cycle group found in SCC") + } + cycleGroup := rootVertex.getOneCycleGroup() + overlapMap := g.getOverlapMap(cycleGroup) + for _, cycle := range cycleGroup { + g.generateRecursiveQueriesForCycle(cq, scopeId, cycle, edges, nextScopeEdges, overlapMap) } - cycle := orderCycle(rootVertex.cycles[0], edges, path.scopeGraph[scopeId]) - g.generateRecursiveQueriesForCycle(cq, scopeId, cycle, edges, nextScopeEdges) - g.generateQueriesForVertexesInCycle(cq, scopeId, cycle) + g.generateFilteredQueries(cq, cycleGroup, scopeId) + g.generateQueriesForVertexesInCycle(cq, scopeId, cycleGroup) } -func (g *Graph) generateQueriesForVertexesInCycle(cq *cteQuery, scopeId int, cycle []*Edge) { - for _, t := range getTablesFromCycle(cycle) { +func (g *Graph) getOverlapMap(cycles [][]*Edge) map[string][][]*Edge { + cyclesOverlap := make(map[string][][]*Edge, len(cycles)) + for i, currCycle := range cycles { + cycleId := getCycleId(currCycle) + var overlapCycles [][]*Edge + for j, overlapCycle := range cycles { + if i == j { + continue + } + overlapCycles = append(overlapCycles, overlapCycle) + } + cyclesOverlap[cycleId] = overlapCycles + } + return cyclesOverlap +} + +func (g *Graph) generateQueriesForVertexesInCycle(cq *cteQuery, scopeId int, cycles [][]*Edge) { + for _, t := range getTablesFromCycle(cycles[0]) { queryName := fmt.Sprintf("%s__%s__ids", t.Schema, t.Name) - query := generateAllTablesValidPkSelection(cycle, scopeId, t) + query := generateAllTablesValidPkSelection(cycles, scopeId, t) cq.addItem(queryName, query) } } func (g *Graph) generateRecursiveQueriesForCycle( cq *cteQuery, scopeId int, cycle []*Edge, rest []*CondensedEdge, nextScopeEdges []*ScopeEdge, + overlapMap map[string][][]*Edge, ) { - var ( - cycleId = getCycleIdent(cycle) - overriddenTableNames = make(map[toolkit.Oid]string) - ) - + overriddenTableNames := make(map[toolkit.Oid]string) rest = slices.Clone(rest) + for _, se := range nextScopeEdges { t := se.originalCondensedEdge.originalEdge.to.table overriddenTableNames[t.Oid] = fmt.Sprintf("%s__%s__ids", t.Schema, t.Name) @@ -373,25 +394,42 @@ func (g *Graph) generateRecursiveQueriesForCycle( //var unionQueries []string shiftedCycle := slices.Clone(cycle) for idx := 1; idx <= len(cycle); idx++ { - var ( - mainTable = shiftedCycle[0].from.table - // queryName - name of a query in the recursive CTE - // where: - // * s - scope id - // * c - cycle id - // * pt1 - part 1 of the recursive query - queryName = fmt.Sprintf("__s%d__c%s__%s__%s", scopeId, cycleId, mainTable.Schema, mainTable.Name) - filteredQueryName = fmt.Sprintf("%s__filtered", queryName) - ) - + queryName := getCycleQueryName(scopeId, shiftedCycle, "") query := generateQuery(queryName, shiftedCycle, rest, overriddenTableNames) cq.addItem(queryName, query) - filteredQuery := generateIntegrityCheckJoinConds(shiftedCycle, mainTable, queryName) - cq.addItem(filteredQueryName, filteredQuery) + cycleId := getCycleId(shiftedCycle) + if len(overlapMap[cycleId]) > 0 { + overlapQueryName := getCycleQueryName(scopeId, shiftedCycle, "overlap") + overlapQuery := generateOverlapQuery(scopeId, overlapQueryName, shiftedCycle, rest, overriddenTableNames, overlapMap[cycleId]) + cq.addItem(overlapQueryName, overlapQuery) + } shiftedCycle = shiftCycle(shiftedCycle) } } +func (g *Graph) generateFilteredQueries(cq *cteQuery, groupedCycles [][]*Edge, scopeId int) { + + // Clone cycles group + cycles := make([][]*Edge, 0, len(groupedCycles)) + for _, cycle := range groupedCycles { + cycles = append(cycles, slices.Clone(cycle)) + } + for idx := 1; idx <= len(cycles[0]); idx++ { + groupQueryNamePrefix := getCyclesGroupQueryName(scopeId, cycles[0]) + filteredQueryName := fmt.Sprintf("%s__filtered", groupQueryNamePrefix) + if len(cycles) > 1 { + unitedQuery := generateUnitedCyclesQuery(scopeId, cycles) + groupQueryName := getCyclesGroupQueryName(scopeId, cycles[0]) + unitedQueryName := fmt.Sprintf("%s__united", groupQueryName) + cq.addItem(unitedQueryName, unitedQuery) + } + filteredQuery := generateIntegrityCheckJoinConds(scopeId, cycles) + cq.addItem(filteredQueryName, filteredQuery) + shiftCycleGroup(cycles) + } + +} + func (g *Graph) generateQueriesDfs(path *Path, scopeEdge *ScopeEdge) string { // TODO: // 1. Add scopeEdges support and LEFT JOIN @@ -566,53 +604,146 @@ func isPathForScc(path *Path, graph *Graph) bool { return graph.scc[path.rootVertex].hasCycle() } -func orderCycle(cycle []*Edge, subsetJoins []*CondensedEdge, scopeEdges []*ScopeEdge) []*Edge { +func generateQuery( + queryName string, cycle []*Edge, rest []*CondensedEdge, overriddenTables map[toolkit.Oid]string, +) string { var ( - vertexes []int - valuableEdgesIdx int + selectKeys []string + initialJoins, recursiveJoins []string + initialWhereConds, recursiveWhereConds []string + integrityCheck string + cycleSubsetConds []string + edges = slices.Clone(cycle[:len(cycle)-1]) + droppedEdge = cycle[len(cycle)-1] ) + for _, ce := range rest { + edges = append(edges, ce.originalEdge) + } - for _, e := range cycle { - vertexes = append(vertexes, e.from.idx) + for _, t := range getTablesFromCycle(cycle) { + var keysWithAliases []string + for _, k := range t.PrimaryKey { + keysWithAliases = append(keysWithAliases, fmt.Sprintf(`"%s"."%s"."%s" as "%s__%s__%s"`, t.Schema, t.Name, k, t.Schema, t.Name, k)) + } + selectKeys = append(selectKeys, keysWithAliases...) + if len(t.SubsetConds) > 0 { + cycleSubsetConds = append(cycleSubsetConds, t.SubsetConds...) + } } - for _, sj := range subsetJoins { - if slices.Contains(vertexes, sj.originalEdge.from.idx) { - valuableEdgesIdx = slices.IndexFunc(cycle, func(e *Edge) bool { - return sj.originalEdge.from.idx == e.from.idx - }) - if !sj.originalEdge.isNullable { - break - } + var droppedKeysWithAliases []string + for _, k := range droppedEdge.from.keys { + t := droppedEdge.from.table + droppedKeysWithAliases = append(droppedKeysWithAliases, fmt.Sprintf(`"%s"."%s"."%s" as "%s__%s__%s"`, t.Schema, t.Name, k, t.Schema, t.Name, k)) + } + selectKeys = append(selectKeys, droppedKeysWithAliases...) + + var initialPathSelectionKeys []string + for _, k := range edges[0].from.table.PrimaryKey { + t := edges[0].from.table + pathName := fmt.Sprintf( + `ARRAY["%s"."%s"."%s"] AS %s__%s__%s__path`, + t.Schema, t.Name, k, + t.Schema, t.Name, k, + ) + initialPathSelectionKeys = append(initialPathSelectionKeys, pathName) + } + + initialKeys := slices.Clone(selectKeys) + initialKeys = append(initialKeys, initialPathSelectionKeys...) + initFromClause := fmt.Sprintf(`FROM "%s"."%s" `, edges[0].from.table.Schema, edges[0].from.table.Name) + integrityCheck = "TRUE AS valid" + initialKeys = append(initialKeys, integrityCheck) + initialWhereConds = append(initialWhereConds, cycleSubsetConds...) + + initialSelect := fmt.Sprintf("SELECT %s", strings.Join(initialKeys, ", ")) + nullabilityMap := make(map[int]bool) + for _, e := range edges { + isNullable := e.isNullable + if !isNullable { + isNullable = nullabilityMap[e.from.idx] } + nullabilityMap[e.to.idx] = isNullable + joinType := joinTypeInner + if isNullable { + joinType = joinTypeLeft + } + initialJoins = append(initialJoins, generateJoinClauseV2(e, joinType, overriddenTables)) } - for _, se := range scopeEdges { - if slices.Contains(vertexes, se.originalCondensedEdge.from.idx) { - valuableEdgesIdx = slices.IndexFunc(cycle, func(e *Edge) bool { - return se.originalCondensedEdge.originalEdge.from.idx == e.from.idx - }) - if !se.originalCondensedEdge.originalEdge.isNullable { - break - } + integrityChecks := generateIntegrityChecksForNullableEdges(nullabilityMap, edges, overriddenTables) + initialWhereConds = append(initialWhereConds, integrityChecks...) + initialWhereClause := generateWhereClause(initialWhereConds) + initialQuery := fmt.Sprintf(`%s %s %s %s`, + initialSelect, initFromClause, strings.Join(initialJoins, " "), initialWhereClause, + ) + + recursiveIntegrityChecks := slices.Clone(cycleSubsetConds) + recursiveIntegrityChecks = append(recursiveIntegrityChecks, integrityChecks...) + recursiveIntegrityCheck := fmt.Sprintf("(%s) AS valid", strings.Join(recursiveIntegrityChecks, " AND ")) + recursiveKeys := slices.Clone(selectKeys) + for _, k := range edges[0].from.table.PrimaryKey { + t := edges[0].from.table + //recursivePathSelectionKeys = append(recursivePathSelectionKeys, fmt.Sprintf(`coalesce("%s"."%s"."%s"::TEXT, 'NULL')`, t.Schema, t.Name, k)) + + pathName := fmt.Sprintf( + `"%s__%s__%s__path" || ARRAY["%s"."%s"."%s"]`, + t.Schema, t.Name, k, + t.Schema, t.Name, k, + ) + recursiveKeys = append(recursiveKeys, pathName) + } + recursiveKeys = append(recursiveKeys, recursiveIntegrityCheck) + + recursiveSelect := fmt.Sprintf("SELECT %s", strings.Join(recursiveKeys, ", ")) + recursiveFromClause := fmt.Sprintf(`FROM "%s" `, queryName) + recursiveJoins = append(recursiveJoins, generateJoinClauseForDroppedEdge(droppedEdge, queryName)) + nullabilityMap = make(map[int]bool) + for _, e := range edges { + isNullable := e.isNullable + if !isNullable { + isNullable = nullabilityMap[e.from.idx] + } + nullabilityMap[e.to.idx] = isNullable + joinType := joinTypeInner + if isNullable { + joinType = joinTypeLeft } + recursiveJoins = append(recursiveJoins, generateJoinClauseV2(e, joinType, overriddenTables)) } - if valuableEdgesIdx == -1 { - panic("is not found") + recursiveValidCond := fmt.Sprintf(`"%s"."%s"`, queryName, "valid") + recursiveWhereConds = append(recursiveWhereConds, recursiveValidCond) + for _, k := range edges[0].from.table.PrimaryKey { + t := edges[0].from.table + + recursivePathCheck := fmt.Sprintf( + `NOT "%s"."%s"."%s" = ANY("%s"."%s__%s__%s__%s")`, + t.Schema, t.Name, k, + queryName, t.Schema, t.Name, k, "path", + ) + + recursiveWhereConds = append(recursiveWhereConds, recursivePathCheck) } + recursiveWhereClause := generateWhereClause(recursiveWhereConds) + + recursiveQuery := fmt.Sprintf(`%s %s %s %s`, + recursiveSelect, recursiveFromClause, strings.Join(recursiveJoins, " "), recursiveWhereClause, + ) - resCycle := slices.Clone(cycle[valuableEdgesIdx:]) - resCycle = append(resCycle, cycle[:valuableEdgesIdx]...) - return resCycle + query := fmt.Sprintf("( %s ) UNION ( %s )", initialQuery, recursiveQuery) + return query } -func generateQuery(queryName string, cycle []*Edge, rest []*CondensedEdge, overriddenTables map[toolkit.Oid]string) string { +func generateOverlapQuery( + scopeId int, + queryName string, cycle []*Edge, rest []*CondensedEdge, overriddenTables map[toolkit.Oid]string, + overlap [][]*Edge, +) string { var ( selectKeys []string initialJoins, recursiveJoins []string initialWhereConds, recursiveWhereConds []string - integrityCheck string cycleSubsetConds []string edges = slices.Clone(cycle[:len(cycle)-1]) droppedEdge = cycle[len(cycle)-1] @@ -653,11 +784,8 @@ func generateQuery(queryName string, cycle []*Edge, rest []*CondensedEdge, overr initialKeys := slices.Clone(selectKeys) initialKeys = append(initialKeys, initialPathSelectionKeys...) initFromClause := fmt.Sprintf(`FROM "%s"."%s" `, edges[0].from.table.Schema, edges[0].from.table.Name) - integrityCheck = "TRUE AS valid" - initialKeys = append(initialKeys, integrityCheck) - initialWhereConds = append(initialWhereConds, cycleSubsetConds...) + initialWhereConds = append(initialWhereConds, generateInClauseForOverlap(scopeId, edges, overlap)) - initialSelect := fmt.Sprintf("SELECT %s", strings.Join(initialKeys, ", ")) nullabilityMap := make(map[int]bool) for _, e := range edges { isNullable := e.isNullable @@ -673,15 +801,18 @@ func generateQuery(queryName string, cycle []*Edge, rest []*CondensedEdge, overr } integrityChecks := generateIntegrityChecksForNullableEdges(nullabilityMap, edges, overriddenTables) - initialWhereConds = append(initialWhereConds, integrityChecks...) + integrityChecks = append(integrityChecks, cycleSubsetConds...) + initialIntegrityCheck := fmt.Sprintf("(%s) AS valid", strings.Join(integrityChecks, " AND ")) + initialKeys = append(initialKeys, initialIntegrityCheck) + initialSelect := fmt.Sprintf("SELECT %s", strings.Join(initialKeys, ", ")) + + //initialWhereConds = append(initialWhereConds, integrityChecks...) initialWhereClause := generateWhereClause(initialWhereConds) initialQuery := fmt.Sprintf(`%s %s %s %s`, initialSelect, initFromClause, strings.Join(initialJoins, " "), initialWhereClause, ) - recursiveIntegrityChecks := slices.Clone(cycleSubsetConds) - recursiveIntegrityChecks = append(recursiveIntegrityChecks, integrityChecks...) - recursiveIntegrityCheck := fmt.Sprintf("(%s) AS valid", strings.Join(recursiveIntegrityChecks, " AND ")) + recursiveIntegrityCheck := fmt.Sprintf("(%s) AS valid", strings.Join(integrityChecks, " AND ")) recursiveKeys := slices.Clone(selectKeys) for _, k := range edges[0].from.table.PrimaryKey { t := edges[0].from.table @@ -736,6 +867,31 @@ func generateQuery(queryName string, cycle []*Edge, rest []*CondensedEdge, overr return query } +func generateInClauseForOverlap(scopeId int, edges []*Edge, overlap [][]*Edge) string { + var ( + overlapTables []string + unionQueryParts []string + rightTableKeys, leftTableKeys []string + ) + + for _, c := range overlap { + overlapTables = append(overlapTables, getCycleQueryName(scopeId, c, "")) + } + for _, k := range edges[0].from.table.PrimaryKey { + rightTableKey := fmt.Sprintf(`"%s__%s__%s"`, edges[0].from.table.Schema, edges[0].from.table.Name, k) + rightTableKeys = append(rightTableKeys, rightTableKey) + leftTableKey := fmt.Sprintf(`"%s"."%s"."%s"`, edges[0].from.table.Schema, edges[0].from.table.Name, k) + leftTableKeys = append(leftTableKeys, leftTableKey) + } + for _, t := range overlapTables { + unionQueryParts = append(unionQueryParts, fmt.Sprintf(`SELECT %s FROM "%s"`, strings.Join(rightTableKeys, ", "), t)) + } + unionQuery := strings.Join(unionQueryParts, " UNION ") + + res := fmt.Sprintf(`(%s) IN (%s)`, strings.Join(leftTableKeys, ", "), unionQuery) + return res +} + func getTablesFromCycle(cycle []*Edge) (res []*entries.Table) { for _, e := range cycle { res = append(res, e.to.table) @@ -783,15 +939,34 @@ func generateIntegrityChecksForNullableEdges(nullabilityMap map[int]bool, edges return } -func generateIntegrityCheckJoinConds(cycle []*Edge, table *entries.Table, tableName string) string { +func generateUnitedCyclesQuery(scopeId int, cycles [][]*Edge) string { + var tablesSelection []string + for _, c := range cycles { + q1 := fmt.Sprintf(`SELECT * FROM "%s"`, getCycleQueryName(scopeId, c, "")) + tablesSelection = append(tablesSelection, q1) + q2 := fmt.Sprintf(`SELECT * FROM "%s"`, getCycleQueryName(scopeId, c, "overlap")) + tablesSelection = append(tablesSelection, q2) + } + res := strings.Join(tablesSelection, " UNION ") + return res +} + +func generateIntegrityCheckJoinConds(scopeId int, cycles [][]*Edge) string { var ( + table = cycles[0][0].from.table allPks []string mainTablePks []string unnestSelections []string + tableName = getCycleQueryName(scopeId, cycles[0], "") ) - for _, t := range getTablesFromCycle(cycle) { + if len(cycles) > 1 { + prefix := getCyclesGroupQueryName(scopeId, cycles[0]) + tableName = fmt.Sprintf("%s__united", prefix) + } + + for _, t := range getTablesFromCycle(cycles[0]) { for _, k := range t.PrimaryKey { key := fmt.Sprintf(`"%s"."%s__%s__%s"`, tableName, t.Schema, t.Name, k) allPks = append(allPks, key) @@ -822,15 +997,15 @@ func generateIntegrityCheckJoinConds(cycle []*Edge, table *entries.Table, tableN return filteredQuery } -func generateAllTablesValidPkSelection(cycle []*Edge, scopeId int, forTable *entries.Table) string { +func generateAllTablesValidPkSelection(cycles [][]*Edge, scopeId int, forTable *entries.Table) string { var unionParts []string - for _, t := range getTablesFromCycle(cycle) { + for _, t := range getTablesFromCycle(cycles[0]) { var ( selectionKeys []string - cycleId = getCycleIdent(cycle) - filteredQueryName = fmt.Sprintf("__s%d__c%s__%s__%s__filtered", scopeId, cycleId, t.Schema, t.Name) + groupId = getCycleGroupId(cycles[0]) + filteredQueryName = fmt.Sprintf("__s%d__g%s__%s__%s__filtered", scopeId, groupId, t.Schema, t.Name) ) for _, k := range forTable.PrimaryKey { @@ -854,3 +1029,42 @@ func shiftCycle(cycle []*Edge) (res []*Edge) { res = append(res, cycle[:len(cycle)-1]...) return } + +func getCycleQueryName(scopeId int, cycle []*Edge, postfix string) string { + // queryName - name of a query in the recursive CTE + // where: + // * s - scope id + // * g - group id + // * c - cycle id + // * postfix with table name + mainTable := cycle[0].from.table + groupId := getCycleGroupId(cycle) + cycleId := getCycleId(cycle) + res := fmt.Sprintf("__s%d__g%s__c%s__%s__%s", scopeId, groupId, cycleId, mainTable.Schema, mainTable.Name) + if postfix != "" { + res = fmt.Sprintf("%s__%s", res, postfix) + } + return res +} + +func getCyclesGroupQueryName(scopeId int, cycle []*Edge) string { + // queryName - name of a query in the recursive CTE + // where: + // * s - scope id + // * g - group id + // * postfix with table name + mainTable := cycle[0].from.table + groupId := getCycleGroupId(cycle) + return getCyclesGroupQueryNameByMainTable(scopeId, groupId, mainTable) +} + +func getCyclesGroupQueryNameByMainTable(scopeId int, groupId string, mainTable *entries.Table) string { + return fmt.Sprintf("__s%d__g%s__%s__%s", scopeId, groupId, mainTable.Schema, mainTable.Name) +} + +func shiftCycleGroup(g [][]*Edge) [][]*Edge { + for idx := range g { + g[idx] = shiftCycle(g[idx]) + } + return g +}