diff --git a/internal/flypg/admin/admin.go b/internal/flypg/admin/admin.go index 1d3b4a86..2f1925b1 100644 --- a/internal/flypg/admin/admin.go +++ b/internal/flypg/admin/admin.go @@ -2,10 +2,7 @@ package admin import ( "context" - "database/sql" "fmt" - "log" - "regexp" "strconv" "strings" @@ -431,109 +428,3 @@ func ValidatePGSettings(ctx context.Context, conn *pgx.Conn, requested map[strin return nil } - -func fixCollationMismatch(ctx context.Context, db *sql.DB) error { - query := ` - SELECT pg_describe_object(refclassid, refobjid, refobjsubid) AS "Collation", - pg_describe_object(classid, objid, objsubid) AS "Object" - FROM pg_depend d JOIN pg_collation c - ON refclassid = 'pg_collation'::regclass AND refobjid = c.oid - WHERE c.collversion <> pg_collation_actual_version(c.oid) - ORDER BY 1, 2;` - - rows, err := db.Query(query) - if err != nil { - return fmt.Errorf("failed to query collation mismatches: %v", err) - } - defer rows.Close() - - var collation, object string - for rows.Next() { - if err := rows.Scan(&collation, &object); err != nil { - return fmt.Errorf("failed to scan row: %v", err) - } - - fixObject(db, object) - } - - if err := rows.Err(); err != nil { - return fmt.Errorf("failed to iterate over rows: %v", err) - } - - return nil -} - -func fixObject(db *sql.DB, object string) { - fmt.Printf("Fixing object: %s\n", object) - - switch { - case regexp.MustCompile(`index`).MatchString(object): - // reindex(db, object) - case regexp.MustCompile(`column`).MatchString(object): - // alterColumn(db, object) - case regexp.MustCompile(`constraint`).MatchString(object): - // dropAndRecreateConstraint(db, object) - case regexp.MustCompile(`materialized view`).MatchString(object): - // refreshMaterializedView(db, object) - case regexp.MustCompile(`function`).MatchString(object): - // recreateFunction(db, object) - case regexp.MustCompile(`view`).MatchString(object): - // recreateView(db, object) - case regexp.MustCompile(`trigger`).MatchString(object): - // 
recreateTrigger(db, object) - default: - log.Printf("Unknown object type: %s", object) - } -} - -const refreshCollationSQL = ` -DO $$ -DECLARE - r RECORD; -BEGIN - FOR r IN (SELECT datname FROM pg_database WHERE datallowconn = true) - LOOP - BEGIN - EXECUTE 'ALTER DATABASE ' || quote_ident(r.datname) || ' REFRESH COLLATION VERSION;'; - EXCEPTION - WHEN OTHERS THEN - RAISE NOTICE 'Failed to refresh collation for database: % - %', r.datname, SQLERRM; - END; - END LOOP; -END $$;` - -// RefreshCollationVersion will refresh the collation version for all databases. -func RefreshCollationVersion(ctx context.Context, conn *pgx.Conn) error { - _, err := conn.Exec(ctx, refreshCollationSQL) - return err -} - -const identifyCollationObjectsSQL = ` -SELECT pg_describe_object(refclassid, refobjid, refobjsubid) AS "Collation", - pg_describe_object(classid, objid, objsubid) AS "Object" - FROM pg_depend d JOIN pg_collation c - ON refclassid = 'pg_collation'::regclass AND refobjid = c.oid - WHERE c.collversion <> pg_collation_actual_version(c.oid) - ORDER BY 1, 2;` - -const reIndexSQL = ` -DO $$ -DECLARE - r RECORD; -BEGIN - FOR r IN (SELECT n.nspname, i.relname - FROM pg_index x - JOIN pg_class c ON c.oid = x.indrelid - JOIN pg_namespace n ON n.oid = c.relnamespace - JOIN pg_class i ON i.oid = x.indexrelid - JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(x.indkey) - JOIN pg_collation col ON col.oid = a.attcollation - WHERE col.collname = 'en_US.utf8') LOOP - EXECUTE 'REINDEX INDEX ' || quote_ident(r.nspname) || '.' 
|| quote_ident(r.relname); - END LOOP; -END $$;` - -func ReIndex(ctx context.Context, conn *pgx.Conn) error { - _, err := conn.Exec(ctx, reIndexSQL) - return err -} diff --git a/internal/flypg/admin/test b/internal/flypg/admin/test new file mode 100644 index 00000000..33789dda --- /dev/null +++ b/internal/flypg/admin/test @@ -0,0 +1,29 @@ +-- CREATE DATABASE test_db; + +-- -- Create a table with a collated column +-- CREATE TABLE test_table ( +-- id SERIAL PRIMARY KEY, +-- name TEXT COLLATE "en_US.utf8" +-- ); + +-- -- Insert data into the table +-- INSERT INTO test_table (name) +-- VALUES +-- ('apple'), +-- ('banana'), +-- ('cherry'); + +-- -- Create an index on the collated column +-- CREATE INDEX idx_test_table_name ON test_table (name); + +-- -- Create a view +-- CREATE VIEW test_view AS +-- SELECT id, name +-- FROM test_table +-- WHERE name LIKE 'a%'; + +-- -- Create a materialized view +-- CREATE MATERIALIZED VIEW test_materialized_view AS +-- SELECT id, name +-- FROM test_table +-- WHERE name LIKE 'b%'; diff --git a/internal/flypg/collation.go b/internal/flypg/collation.go new file mode 100644 index 00000000..8a6472ac --- /dev/null +++ b/internal/flypg/collation.go @@ -0,0 +1,287 @@ +package flypg + +import ( + "context" + "crypto/md5" + "database/sql" + "fmt" + "log" + "os" + + "github.com/fly-apps/postgres-flex/internal/utils" + "github.com/jackc/pgx/v5" + "golang.org/x/exp/slices" +) + +const collationVersionFile = "/data/.collationVersion" + +func collationChanged() (bool, error) { + // Short-circuit if there's no collation file. + if !collationVersionFileExists() { + return true, nil + } + + // Calculate the md5sum of the ldd version. + sum, err := calculateLDDVersionSum() + if err != nil { + return false, fmt.Errorf("failed to calculate collation sum: %w", err) + } + + // Read the collation lock file. 
+ body, err := os.ReadFile(collationVersionFile) + if err != nil { + return false, fmt.Errorf("failed to read collation lock file: %w", err) + } + + // Compare the md5sum of the ldd version with version in the collation version file. + return !slices.Equal(sum[:], body), nil +} + +func collationVersionFileExists() bool { + _, err := os.Stat(collationVersionFile) + return !os.IsNotExist(err) +} + +func calculateLDDVersionSum() ([16]byte, error) { + // md5sum the ldd version so we can verify if the system has changed. + output, err := utils.RunCommand("ldd --version", "postgres") + if err != nil { + return [16]byte{}, fmt.Errorf("failed to capture ldd version: %w", err) + } + + // Calculate the md5sum of the ldd version output + return md5.Sum(output), nil +} + +func writeCollationLock(sum [16]byte) error { + // Write the collation lock file. + if err := os.WriteFile(collationVersionFile, sum[:], 0600); err != nil { + return fmt.Errorf("failed to write collation lock file: %w", err) + } + + return nil +} + +const refreshCollationVersionPerDatabaseSQL = ` +DO $$ +DECLARE + r RECORD; +BEGIN + FOR r IN (SELECT datname FROM pg_database WHERE datallowconn = true) + LOOP + BEGIN + EXECUTE 'ALTER DATABASE ' || quote_ident(r.datname) || ' REFRESH COLLATION VERSION;'; + EXCEPTION + WHEN OTHERS THEN + RAISE NOTICE 'Failed to refresh collation for database: % - %', r.datname, SQLERRM; + END; + END LOOP; +END $$; +` + +func identifyCollationMismatches(ctx context.Context, conn *pgx.Conn) (bool, error) { + // Fetch the collation version for each database. + preVersionMap, err := fetchDatabaseCollationVersions(ctx, conn) + if err != nil { + return false, fmt.Errorf("failed to fetch collation versions: %w", err) + } + + // Refresh the collation version for each collation. + if err := refreshCollations(ctx, conn); err != nil { + return false, fmt.Errorf("failed to refresh collations: %w", err) + } + + // Fetch the collation version for each database after the refresh. 
+ postVersionMap, err := fetchDatabaseCollationVersions(ctx, conn) + if err != nil { + return false, fmt.Errorf("failed to fetch collation versions: %w", err) + } + + // Return whether any collation versions have changed. + return collationVersionChanged(preVersionMap, postVersionMap), nil +} + +const fetchCollationsSQL = `SELECT DISTINCT datcollate FROM pg_database WHERE datcollate != 'C';` + +// FetchCollations will fetch the distinct collations within the cluster. +func fetchCollations(ctx context.Context, conn *pgx.Conn) ([]string, error) { + rows, err := conn.Query(ctx, fetchCollationsSQL) + if err != nil { + return nil, fmt.Errorf("failed to fetch collations: %w", err) + } + defer rows.Close() + + var collations []string + for rows.Next() { + var collation sql.NullString + if err := rows.Scan(&collation); err != nil { + return nil, fmt.Errorf("failed to scan collation row: %w", err) + } + if collation.Valid { + collations = append(collations, collation.String) + } + } + + if rows.Err() != nil { + return nil, rows.Err() + } + return collations, nil +} + +const fetchCollationVersionSQL = ` +SELECT datname, datcollversion, collversion FROM pg_database + d JOIN pg_collation c ON d.datcollate = c.collname +WHERE datallowconn = true; +` + +// FetchDatabaseCollationVersions will fetch the collation version for each database within the cluster. 
+func fetchDatabaseCollationVersions(ctx context.Context, conn *pgx.Conn) (map[string][]string, error) {
+	rows, err := conn.Query(ctx, fetchCollationVersionSQL)
+	if err != nil {
+		return nil, fmt.Errorf("failed to fetch database collation versions: %w", err)
+	}
+	defer rows.Close()
+
+	// The result maps each database name to {datcollversion, collversion}.
+	versions := make(map[string][]string)
+	for rows.Next() {
+		// Both version columns are nullable text columns. Scanning a NULL into
+		// a plain string fails, so use sql.NullString (as fetchCollations does)
+		// and record a NULL version as the empty string.
+		var (
+			datname                     string
+			datcollversion, collversion sql.NullString
+		)
+		if err := rows.Scan(&datname, &datcollversion, &collversion); err != nil {
+			return nil, fmt.Errorf("failed to scan db collation row: %w", err)
+		}
+		versions[datname] = []string{datcollversion.String, collversion.String}
+	}
+
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	return versions, nil
+}
+
+// refreshCollationVersion issues ALTER COLLATION ... REFRESH VERSION for the
+// named collation in the pg_catalog schema and returns any execution error.
+func refreshCollationVersion(ctx context.Context, conn *pgx.Conn, collation string) error {
+	// Identifiers cannot be bound as query parameters, so quote the name with
+	// pgx.Identifier.Sanitize, which double-quotes each part and escapes any
+	// embedded double quotes.
+	name := pgx.Identifier{"pg_catalog", collation}.Sanitize()
+	_, err := conn.Exec(ctx, fmt.Sprintf("ALTER COLLATION %s REFRESH VERSION;", name))
+	return err
+}
+
+// collationVersionChanged reports whether any database present in before has a
+// different version pair in after. The template0 and template1 databases are
+// ignored. A database missing from after compares against a nil slice, so it
+// counts as changed whenever its before entry is non-empty.
+func collationVersionChanged(before, after map[string][]string) bool {
+	for db := range before {
+		if db == "template0" || db == "template1" {
+			continue
+		}
+
+		if !slices.Equal(before[db], after[db]) {
+			return true
+		}
+	}
+	return false
+}
+
+// detectCollationMismatchSQL counts the dependencies on collations whose
+// recorded version differs from the version currently reported by
+// pg_collation_actual_version.
+const detectCollationMismatchSQL = `
+SELECT COUNT(*) FROM pg_depend d
+  JOIN pg_collation c ON refclassid = 'pg_collation'::regclass AND refobjid = c.oid
+WHERE c.collversion <> pg_collation_actual_version(c.oid);
+`
+
+// CountCollationMismatchs will detect if there are any collation mismatches within the database.
+func countCollationMismatchs(ctx context.Context, conn *pgx.Conn) (int, error) {
+	// Runs detectCollationMismatchSQL on the database conn is connected to and
+	// returns the single COUNT(*) value; 0 is returned alongside any error.
+	var count int
+	if err := conn.QueryRow(ctx, detectCollationMismatchSQL).Scan(&count); err != nil {
+		return 0, err
+	}
+
+	return count, nil
+}
+
+// identifyImpactedCollationObjectsSQL lists, per collation, every dependent
+// object whose collation's recorded version differs from the actual version.
+const identifyImpactedCollationObjectsSQL = `
+SELECT pg_describe_object(refclassid, refobjid, refobjsubid) AS "Collation",
+       pg_describe_object(classid, objid, objsubid) AS "Object"
+FROM pg_depend d JOIN pg_collation c
+     ON refclassid = 'pg_collation'::regclass AND refobjid = c.oid
+WHERE c.collversion <> pg_collation_actual_version(c.oid)
+ORDER BY 1, 2;`
+
+// identifyImpactedCollationObjects logs a warning for every object within the
+// database that depends on a collation with a version mismatch. It only
+// reports; nothing is modified. Query, scan and iteration errors are returned
+// wrapped so callers can inspect the cause.
+func identifyImpactedCollationObjects(ctx context.Context, db *sql.DB) error {
+	rows, err := db.QueryContext(ctx, identifyImpactedCollationObjectsSQL)
+	if err != nil {
+		return fmt.Errorf("failed to query collation mismatches: %w", err)
+	}
+	defer rows.Close()
+
+	var collation, object string
+	for rows.Next() {
+		if err := rows.Scan(&collation, &object); err != nil {
+			return fmt.Errorf("failed to scan row: %w", err)
+		}
+
+		log.Printf("[WARN] Collation mismatch detected: %s - %s", collation, object)
+	}
+
+	if err := rows.Err(); err != nil {
+		return fmt.Errorf("failed to iterate over rows: %w", err)
+	}
+
+	return nil
+}
+
+// impactedCollationIndexesSQL selects the schema and name of every index that
+// has at least one column using a collation whose recorded version differs
+// from the actual version.
+//
+// The pg_attribute rows joined here belong to the index relation itself, whose
+// attributes are numbered 1..n. They must not be filtered by pg_index.indkey,
+// which holds column numbers of the underlying table; doing so drops any index
+// whose table column position differs from its index column position.
+// DISTINCT keeps a multi-column index from being returned more than once.
+const impactedCollationIndexesSQL = `
+SELECT DISTINCT n.nspname AS schema_name, c.relname AS index_name
+  FROM pg_index i
+  JOIN pg_class c ON c.oid = i.indexrelid
+  JOIN pg_namespace n ON n.oid = c.relnamespace
+  JOIN pg_attribute a ON a.attrelid = c.oid
+  JOIN pg_collation col ON col.oid = a.attcollation
+WHERE col.collversion <> pg_collation_actual_version(col.oid);
+`
+
+// ReIndexMismatchedIndexes will reindex any indexes that are impacted by collation mismatches.
+func reIndexMismatchedIndexes(ctx context.Context, conn *pgx.Conn) error {
+	rows, err := conn.Query(ctx, impactedCollationIndexesSQL)
+	if err != nil {
+		return fmt.Errorf("failed to fetch indexes impacted by collation mismatch: %w", err)
+	}
+	// Close is safe to call more than once; this covers the early returns.
+	defer rows.Close()
+
+	// Collect every impacted index before issuing any REINDEX. A pgx.Conn is
+	// busy while a result set is open, so calling Exec inside the rows loop
+	// fails for every index.
+	var indexes []pgx.Identifier
+	seen := make(map[string]struct{})
+	for rows.Next() {
+		var schemaName, indexName string
+		if err := rows.Scan(&schemaName, &indexName); err != nil {
+			return fmt.Errorf("failed to scan row: %w", err)
+		}
+
+		// The query may return an index once per matching column.
+		index := pgx.Identifier{schemaName, indexName}
+		key := index.Sanitize()
+		if _, dup := seen[key]; dup {
+			continue
+		}
+		seen[key] = struct{}{}
+		indexes = append(indexes, index)
+	}
+
+	// Release the connection before running further statements on it.
+	rows.Close()
+	if err := rows.Err(); err != nil {
+		return fmt.Errorf("failed to iterate over impacted indexes: %w", err)
+	}
+
+	for _, index := range indexes {
+		schemaName, indexName := index[0], index[1]
+
+		log.Printf("[WARN] Reindexing index %s.%s concurrently to address collation mismatch\n", schemaName, indexName)
+
+		// Identifiers cannot be bound as parameters; Sanitize double-quotes
+		// both parts so mixed-case or otherwise unusual names still resolve.
+		query := fmt.Sprintf("REINDEX INDEX CONCURRENTLY %s;", index.Sanitize())
+		if _, err := conn.Exec(ctx, query); err != nil {
+			// Best effort: one failed index must not block the others.
+			log.Printf("failed to reindex %s.%s: %v\n", schemaName, indexName, err)
+		}
+	}
+
+	return nil
+}
+
+// refreshCollations refreshes the recorded version of every distinct
+// collation returned by fetchCollations, then executes
+// refreshCollationVersionPerDatabaseSQL. The first failure aborts and is
+// returned wrapped.
+func refreshCollations(ctx context.Context, conn *pgx.Conn) error {
+	// Fetch the distinct collations within the cluster.
+	collations, err := fetchCollations(ctx, conn)
+	if err != nil {
+		return fmt.Errorf("failed to fetch collations: %w", err)
+	}
+
+	// Refresh the collation version for each collation.
+	for _, collation := range collations {
+		if err := refreshCollationVersion(ctx, conn, collation); err != nil {
+			return fmt.Errorf("failed to refresh collation version: %w", err)
+		}
+	}
+
+	// Refresh the collation version for each database.
+	if _, err := conn.Exec(ctx, refreshCollationVersionPerDatabaseSQL); err != nil {
+		return fmt.Errorf("failed to refresh collation version per database: %w", err)
+	}
+
+	return nil
+}
diff --git a/internal/flypg/node.go b/internal/flypg/node.go
index 9e70ef65..1ae244ff 100644
--- a/internal/flypg/node.go
+++ b/internal/flypg/node.go
@@ -306,9 +306,9 @@ func (n *Node) PostInit(ctx context.Context) error {
 			}
 		}
 
-		// Refresh collation for all databases.
-		if err := refreshCollation(ctx, conn); err != nil {
-			log.Printf("failed to refresh collation: %s", err)
+		// Confirm collation integrity.
+ if err := n.confirmCollationIntegrity(ctx, conn); err != nil { + log.Printf("[WARN] problem occurred while evaluating collation integrity: %s", err) } case StandbyRoleName: @@ -478,27 +478,63 @@ func setDirOwnership() error { return nil } -func (n *Node) fixCollationMismatch(ctx context.Context, conn *pgx.Conn) error { +func (n *Node) confirmCollationIntegrity(ctx context.Context, conn *pgx.Conn) error { + // Check to see if the version has changed + changed, err := collationChanged() + if err != nil { + return fmt.Errorf("failed to check collation version file: %s", err) + } + + if !changed { + log.Printf("[INFO] Collation version has not changed.\n") + return nil + } + + fmt.Printf("[INFO] Collation version has changed or has not been evaluated. Evaluating collation integrity.\n") + + // Evaluate collation integrity. + mismatch, err := identifyCollationMismatches(ctx, conn) + if err != nil { + log.Printf("[WARN] Failed to evaluate collation integrity: %s\n", err) + } + + if !mismatch { + return nil + } + + log.Printf("[WARN] Collation mismatches detected. Refreshing collation versions.\n") + + // Detect any indexes that are impacted by collation mismatches. + // Unfortunately, this needs to be checked per database. dbs, err := admin.ListDatabases(ctx, conn) if err != nil { return fmt.Errorf("failed to list databases: %s", err) } - // Add the template1 database to the list of databases to refresh. dbs = append(dbs, admin.DbInfo{Name: "template1"}) for _, db := range dbs { + // Establish a connection to the specified database. 
dbConn, err := n.NewLocalConnection(ctx, db.Name, n.SUCredentials) if err != nil { - return fmt.Errorf("failed to establish connection to local node: %s", err) + return fmt.Errorf("failed to establish connection to database %s: %s", db.Name, err) } defer func() { _ = dbConn.Close(ctx) }() - if err := admin.RefreshCollationVersion(ctx, dbConn); err != nil { - return fmt.Errorf("failed to refresh collation: %s", err) + // Count collation mismatches + count, err := countCollationMismatchs(ctx, dbConn) + if err != nil { + log.Printf("[WARN] Failed to count collation mismatches: %s\n", err) + } + + // Skip if no mismatches are found. + if err == nil && count == 0 { + continue } - if err := admin.ReIndex(ctx, dbConn); err != nil { + log.Printf("[WARN] %d collation mismatches detected %s\n", count, db.Name) + + if err := reIndexMismatchedIndexes(ctx, dbConn); err != nil { return fmt.Errorf("failed to reindex database: %s", err) } }