Skip to content

Commit

Permalink
move to client.go
Browse files Browse the repository at this point in the history
  • Loading branch information
Amogh-Bharadwaj committed Jan 19, 2024
1 parent 4ee8165 commit 876eeb8
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 110 deletions.
92 changes: 3 additions & 89 deletions flow/cmd/validate_mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,97 +4,11 @@ import (
"context"
"fmt"
"log/slog"
"strconv"
"strings"

connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)

func (h *FlowRequestHandler) CheckReplicationPermissions(ctx context.Context, pool *pgxpool.Pool, username string) error {
if pool == nil {
return fmt.Errorf("check replication permissions: pool is nil")
}

var replicationRes bool
err := pool.QueryRow(ctx, "SELECT rolreplication FROM pg_roles WHERE rolname = $1;", username).Scan(&replicationRes)
if err != nil {
return err
}

if !replicationRes {
return fmt.Errorf("postgres user does not have replication role")
}

// check wal_level
var walLevel string
err = pool.QueryRow(ctx, "SHOW wal_level;").Scan(&walLevel)
if err != nil {
return err
}

if walLevel != "logical" {
return fmt.Errorf("wal_level is not logical")
}

// max_wal_senders must be at least 2
var maxWalSendersRes string
err = pool.QueryRow(ctx, "SHOW max_wal_senders;").Scan(&maxWalSendersRes)
if err != nil {
return err
}

maxWalSenders, err := strconv.Atoi(maxWalSendersRes)
if err != nil {
return err
}

if maxWalSenders < 2 {
return fmt.Errorf("max_wal_senders must be at least 1")
}

return nil
}

func (h *FlowRequestHandler) CheckSourceTables(ctx context.Context, pool *pgxpool.Pool, tableNames []string, pubName string) error {
if pool == nil {
return fmt.Errorf("check tables: pool is nil")
}

// Check that we can select from all tables
for _, tableName := range tableNames {
var row pgx.Row
err := pool.QueryRow(ctx, fmt.Sprintf("SELECT * FROM %s LIMIT 0;", tableName)).Scan(&row)
if err != nil && err != pgx.ErrNoRows {
return err
}
}

// Check if tables belong to publication
tableArr := make([]string, 0, len(tableNames))
for _, tableName := range tableNames {
tableArr = append(tableArr, fmt.Sprintf("'%s'", tableName))
}

tableStr := strings.Join(tableArr, ",")

if pubName != "" {
var pubTableCount int
err := pool.QueryRow(ctx, fmt.Sprintf("select COUNT(DISTINCT(schemaname||'.'||tablename)) from pg_publication_tables "+
"where schemaname||'.'||tablename in (%s) and pubname=$1;", tableStr), pubName).Scan(&pubTableCount)
if err != nil {
return err
}

if pubTableCount != len(tableNames) {
return fmt.Errorf("not all tables belong to publication")
}
}

return nil
}

func (h *FlowRequestHandler) ValidateCDCMirror(
ctx context.Context, req *protos.CreateCDCFlowRequest,
) (*protos.ValidateCDCMirrorResponse, error) {
Expand All @@ -112,7 +26,7 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}

// 2. Check permissions of postgres peer
err = h.CheckReplicationPermissions(ctx, sourcePool, sourcePeerConfig.User)
err = connpostgres.CheckReplicationPermissions(ctx, sourcePool, sourcePeerConfig.User)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
Expand All @@ -125,7 +39,7 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
sourceTables = append(sourceTables, tableMapping.SourceTableIdentifier)
}

err = h.CheckSourceTables(ctx, sourcePool, sourceTables, req.ConnectionConfigs.PublicationName)
err = connpostgres.CheckSourceTables(ctx, sourcePool, sourceTables, req.ConnectionConfigs.PublicationName)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
Expand Down
23 changes: 2 additions & 21 deletions flow/cmd/validate_peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,14 @@ import (
"context"
"fmt"
"log/slog"
"strconv"

"github.com/PeerDB-io/peer-flow/connectors"
connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/jackc/pgx/v5/pgxpool"
)

func (h *FlowRequestHandler) GetPostgresVersion(ctx context.Context, pool *pgxpool.Pool) (int, error) {
if pool == nil {
return -1, fmt.Errorf("version check: pool is nil")
}

var versionRes string
err := pool.QueryRow(ctx, "SHOW server_version_num;").Scan(&versionRes)
if err != nil {
return -1, err
}

version, err := strconv.Atoi(versionRes)
if err != nil {
return -1, err
}

return version / 10000, nil
}

func (h *FlowRequestHandler) ValidatePeer(
ctx context.Context,
req *protos.ValidatePeerRequest,
Expand Down Expand Up @@ -73,7 +54,7 @@ func (h *FlowRequestHandler) ValidatePeer(
slog.Error("/peer/validate: failed to obtain peer connection", slog.Any("error", err))
return nil, err
}
version, err := h.GetPostgresVersion(ctx, sourcePool)
version, err := connpostgres.GetPostgresVersion(ctx, sourcePool)
if err != nil {
slog.Error("/peer/validate: pg version check", slog.Any("error", err))
return nil, err
Expand Down
105 changes: 105 additions & 0 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package connpostgres

import (
"context"
"errors"
"fmt"
"log"
"regexp"
"strconv"
"strings"

"github.com/PeerDB-io/peer-flow/connectors/utils"
Expand All @@ -13,6 +15,7 @@ import (
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/lib/pq/oid"
)

Expand Down Expand Up @@ -572,3 +575,105 @@ func (c *PostgresConnector) getCurrentLSN() (pglogrepl.LSN, error) {
func (c *PostgresConnector) getDefaultPublicationName(jobName string) string {
return fmt.Sprintf("peerflow_pub_%s", jobName)
}

func CheckSourceTables(ctx context.Context, pool *pgxpool.Pool, tableNames []string, pubName string) error {
if pool == nil {
return fmt.Errorf("check tables: pool is nil")
}

// Check that we can select from all tables
for _, tableName := range tableNames {
var row pgx.Row
err := pool.QueryRow(ctx, fmt.Sprintf("SELECT * FROM %s LIMIT 0;", tableName)).Scan(&row)
if err != nil && err != pgx.ErrNoRows {
return err
}
}

// Check if tables belong to publication
tableArr := make([]string, 0, len(tableNames))
for _, tableName := range tableNames {
tableArr = append(tableArr, fmt.Sprintf("'%s'", tableName))
}

tableStr := strings.Join(tableArr, ",")

if pubName != "" {
var pubTableCount int
err := pool.QueryRow(ctx, fmt.Sprintf("select COUNT(DISTINCT(schemaname||'.'||tablename)) from pg_publication_tables "+
"where schemaname||'.'||tablename in (%s) and pubname=$1;", tableStr), pubName).Scan(&pubTableCount)
if err != nil {
return err
}

if pubTableCount != len(tableNames) {
return fmt.Errorf("not all tables belong to publication")
}
}

return nil
}

func CheckReplicationPermissions(ctx context.Context, pool *pgxpool.Pool, username string) error {
if pool == nil {
return fmt.Errorf("check replication permissions: pool is nil")
}

var replicationRes bool
err := pool.QueryRow(ctx, "SELECT rolreplication FROM pg_roles WHERE rolname = $1;", username).Scan(&replicationRes)
if err != nil {
return err
}

if !replicationRes {
return fmt.Errorf("postgres user does not have replication role")
}

// check wal_level
var walLevel string
err = pool.QueryRow(ctx, "SHOW wal_level;").Scan(&walLevel)
if err != nil {
return err
}

if walLevel != "logical" {
return fmt.Errorf("wal_level is not logical")
}

// max_wal_senders must be at least 2
var maxWalSendersRes string
err = pool.QueryRow(ctx, "SHOW max_wal_senders;").Scan(&maxWalSendersRes)
if err != nil {
return err
}

maxWalSenders, err := strconv.Atoi(maxWalSendersRes)
if err != nil {
return err
}

if maxWalSenders < 2 {
return fmt.Errorf("max_wal_senders must be at least 2")
}

return nil
}

func GetPostgresVersion(ctx context.Context, pool *pgxpool.Pool) (int, error) {
if pool == nil {
return -1, fmt.Errorf("version check: pool is nil")
}

var versionRes string
err := pool.QueryRow(ctx, "SHOW server_version_num;").Scan(&versionRes)
if err != nil {
return -1, err
}

version, err := strconv.Atoi(versionRes)
if err != nil {
return -1, err
}

return version / 10000, nil
}

0 comments on commit 876eeb8

Please sign in to comment.