diff --git a/flow/cmd/validate_mirror.go b/flow/cmd/validate_mirror.go index 91b33677ac..5ae070a32e 100644 --- a/flow/cmd/validate_mirror.go +++ b/flow/cmd/validate_mirror.go @@ -4,97 +4,11 @@ import ( "context" "fmt" "log/slog" - "strconv" - "strings" + connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/generated/protos" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" ) -func (h *FlowRequestHandler) CheckReplicationPermissions(ctx context.Context, pool *pgxpool.Pool, username string) error { - if pool == nil { - return fmt.Errorf("check replication permissions: pool is nil") - } - - var replicationRes bool - err := pool.QueryRow(ctx, "SELECT rolreplication FROM pg_roles WHERE rolname = $1;", username).Scan(&replicationRes) - if err != nil { - return err - } - - if !replicationRes { - return fmt.Errorf("postgres user does not have replication role") - } - - // check wal_level - var walLevel string - err = pool.QueryRow(ctx, "SHOW wal_level;").Scan(&walLevel) - if err != nil { - return err - } - - if walLevel != "logical" { - return fmt.Errorf("wal_level is not logical") - } - - // max_wal_senders must be at least 2 - var maxWalSendersRes string - err = pool.QueryRow(ctx, "SHOW max_wal_senders;").Scan(&maxWalSendersRes) - if err != nil { - return err - } - - maxWalSenders, err := strconv.Atoi(maxWalSendersRes) - if err != nil { - return err - } - - if maxWalSenders < 2 { - return fmt.Errorf("max_wal_senders must be at least 1") - } - - return nil -} - -func (h *FlowRequestHandler) CheckSourceTables(ctx context.Context, pool *pgxpool.Pool, tableNames []string, pubName string) error { - if pool == nil { - return fmt.Errorf("check tables: pool is nil") - } - - // Check that we can select from all tables - for _, tableName := range tableNames { - var row pgx.Row - err := pool.QueryRow(ctx, fmt.Sprintf("SELECT * FROM %s LIMIT 0;", tableName)).Scan(&row) - if err != nil && err != pgx.ErrNoRows { - return err - } - } - - // Check if tables belong to publication - tableArr := make([]string, 0, len(tableNames)) - for _, tableName := range tableNames { - tableArr = append(tableArr, fmt.Sprintf("'%s'", tableName)) - } - - tableStr := strings.Join(tableArr, ",") - - if pubName != "" { - var pubTableCount int - err := pool.QueryRow(ctx, fmt.Sprintf("select COUNT(DISTINCT(schemaname||'.'||tablename)) from pg_publication_tables "+ - "where schemaname||'.'||tablename in (%s) and pubname=$1;", tableStr), pubName).Scan(&pubTableCount) - if err != nil { - return err - } - - if pubTableCount != len(tableNames) { - return fmt.Errorf("not all tables belong to publication") - } - } - - return nil -} - func (h *FlowRequestHandler) ValidateCDCMirror( ctx context.Context, req *protos.CreateCDCFlowRequest, ) (*protos.ValidateCDCMirrorResponse, error) { @@ -112,7 +26,7 @@ func (h *FlowRequestHandler) ValidateCDCMirror( } // 2. Check permissions of postgres peer - err = h.CheckReplicationPermissions(ctx, sourcePool, sourcePeerConfig.User) + err = connpostgres.CheckReplicationPermissions(ctx, sourcePool, sourcePeerConfig.User) if err != nil { return &protos.ValidateCDCMirrorResponse{ Ok: false, @@ -125,7 +39,7 @@ func (h *FlowRequestHandler) ValidateCDCMirror( sourceTables = append(sourceTables, tableMapping.SourceTableIdentifier) } - err = h.CheckSourceTables(ctx, sourcePool, sourceTables, req.ConnectionConfigs.PublicationName) + err = connpostgres.CheckSourceTables(ctx, sourcePool, sourceTables, req.ConnectionConfigs.PublicationName) if err != nil { return &protos.ValidateCDCMirrorResponse{ Ok: false, diff --git a/flow/cmd/validate_peer.go b/flow/cmd/validate_peer.go index 52048fd5c5..573a5dfe1c 100644 --- a/flow/cmd/validate_peer.go +++ b/flow/cmd/validate_peer.go @@ -4,33 +4,14 @@ import ( "context" "fmt" "log/slog" - "strconv" "github.com/PeerDB-io/peer-flow/connectors" + connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/jackc/pgx/v5/pgxpool" ) -func (h *FlowRequestHandler) GetPostgresVersion(ctx context.Context, pool *pgxpool.Pool) (int, error) { - if pool == nil { - return -1, fmt.Errorf("version check: pool is nil") - } - - var versionRes string - err := pool.QueryRow(ctx, "SHOW server_version_num;").Scan(&versionRes) - if err != nil { - return -1, err - } - - version, err := strconv.Atoi(versionRes) - if err != nil { - return -1, err - } - - return version / 10000, nil -} - func (h *FlowRequestHandler) ValidatePeer( ctx context.Context, req *protos.ValidatePeerRequest, @@ -73,7 +54,7 @@ func (h *FlowRequestHandler) ValidatePeer( slog.Error("/peer/validate: failed to obtain peer connection", slog.Any("error", err)) return nil, err } - version, err := h.GetPostgresVersion(ctx, sourcePool) + version, err := connpostgres.GetPostgresVersion(ctx, sourcePool) if err != nil { slog.Error("/peer/validate: pg version check", slog.Any("error", err)) return nil, err diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index f18d38be52..5b411b0530 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -1,10 +1,12 @@ package connpostgres import ( + "context" "errors" "fmt" "log" "regexp" + "strconv" "strings" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -13,6 +15,7 @@ import ( "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" "github.com/lib/pq/oid" ) @@ -572,3 +575,105 @@ func (c *PostgresConnector) getCurrentLSN() (pglogrepl.LSN, error) { func (c *PostgresConnector) getDefaultPublicationName(jobName string) string { return fmt.Sprintf("peerflow_pub_%s", jobName) } + +func CheckSourceTables(ctx context.Context, pool *pgxpool.Pool, tableNames []string, pubName string) error { + if pool == nil { + return fmt.Errorf("check tables: pool is nil") + } + + // Check that we can select from all tables + for _, tableName := range tableNames { + var row pgx.Row + err := pool.QueryRow(ctx, fmt.Sprintf("SELECT * FROM %s LIMIT 0;", tableName)).Scan(&row) + if err != nil && err != pgx.ErrNoRows { + return err + } + } + + // Check if tables belong to publication + tableArr := make([]string, 0, len(tableNames)) + for _, tableName := range tableNames { + tableArr = append(tableArr, fmt.Sprintf("'%s'", tableName)) + } + + tableStr := strings.Join(tableArr, ",") + + if pubName != "" { + var pubTableCount int + err := pool.QueryRow(ctx, fmt.Sprintf("select COUNT(DISTINCT(schemaname||'.'||tablename)) from pg_publication_tables "+ + "where schemaname||'.'||tablename in (%s) and pubname=$1;", tableStr), pubName).Scan(&pubTableCount) + if err != nil { + return err + } + + if pubTableCount != len(tableNames) { + return fmt.Errorf("not all tables belong to publication") + } + } + + return nil +} + +func CheckReplicationPermissions(ctx context.Context, pool *pgxpool.Pool, username string) error { + if pool == nil { + return fmt.Errorf("check replication permissions: pool is nil") + } + + var replicationRes bool + err := pool.QueryRow(ctx, "SELECT rolreplication FROM pg_roles WHERE rolname = $1;", username).Scan(&replicationRes) + if err != nil { + return err + } + + if !replicationRes { + return fmt.Errorf("postgres user does not have replication role") + } + + // check wal_level + var walLevel string + err = pool.QueryRow(ctx, "SHOW wal_level;").Scan(&walLevel) + if err != nil { + return err + } + + if walLevel != "logical" { + return fmt.Errorf("wal_level is not logical") + } + + // max_wal_senders must be at least 2 + var maxWalSendersRes string + err = pool.QueryRow(ctx, "SHOW max_wal_senders;").Scan(&maxWalSendersRes) + if err != nil { + return err + } + + maxWalSenders, err := strconv.Atoi(maxWalSendersRes) + if err != nil { + return err + } + + if maxWalSenders < 2 { + return fmt.Errorf("max_wal_senders must be at least 2") + } + + return nil +} + +func GetPostgresVersion(ctx context.Context, pool *pgxpool.Pool) (int, error) { + if pool == nil { + return -1, fmt.Errorf("version check: pool is nil") + } + + var versionRes string + err := pool.QueryRow(ctx, "SHOW server_version_num;").Scan(&versionRes) + if err != nil { + return -1, err + } + + version, err := strconv.Atoi(versionRes) + if err != nil { + return -1, err + } + + return version / 10000, nil +}