Commit 75fea24: Merge branch 'main' into stable

Amogh-Bharadwaj committed Mar 21, 2024
2 parents: 4c5f527 + 2b8a575
Showing 32 changed files with 472 additions and 206 deletions.
4 changes: 4 additions & 0 deletions flow/activities/flowable.go
@@ -616,6 +616,9 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
 	partition *protos.QRepPartition,
 	runUUID string,
 ) error {
+	msg := fmt.Sprintf("replicating partition - %s: %d of %d total.", partition.PartitionId, idx, total)
+	activity.RecordHeartbeat(ctx, msg)
+
 	ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName)
 	logger := log.With(activity.GetLogger(ctx), slog.String(string(shared.FlowNameKey), config.FlowJobName))

@@ -643,6 +646,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
 	}
 	if done {
 		logger.Info("no records to push for partition " + partition.PartitionId)
+		activity.RecordHeartbeat(ctx, "no records to push for partition "+partition.PartitionId)
 		return nil
 	}
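RecordHeartbeat here is the Temporal SDK's liveness signal: a long-running activity that stops heartbeating can be cancelled by its heartbeat timeout, and the details payload surfaces in the Temporal UI. A minimal sketch of the pattern this hunk follows (the activity and its arguments are illustrative, not from this commit):

    package example

    import (
        "context"
        "fmt"

        "go.temporal.io/sdk/activity"
    )

    // processPartitions mimics replicateQRepPartition's heartbeat usage:
    // report progress before each unit of work so the worker is never
    // silent for longer than the heartbeat timeout.
    func processPartitions(ctx context.Context, partitions []string) error {
        total := len(partitions)
        for idx, p := range partitions {
            activity.RecordHeartbeat(ctx, fmt.Sprintf("replicating partition - %s: %d of %d total.", p, idx+1, total))
            // ... replicate partition p ...
        }
        return nil
    }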
12 changes: 12 additions & 0 deletions flow/cmd/validate_peer.go
@@ -55,6 +55,18 @@ func (h *FlowRequestHandler) ValidatePeer(
 		}
 	}

+	validationConn, ok := conn.(connectors.ValidationConnector)
+	if ok {
+		validErr := validationConn.ValidateCheck(ctx)
+		if validErr != nil {
+			return &protos.ValidatePeerResponse{
+				Status: protos.ValidatePeerStatus_INVALID,
+				Message: fmt.Sprintf("failed to validate %s peer %s: %v",
+					req.Peer.Type, req.Peer.Name, validErr),
+			}, nil
+		}
+	}
+
 	connErr := conn.ConnectionActive(ctx)
 	if connErr != nil {
 		return &protos.ValidatePeerResponse{
18 changes: 6 additions & 12 deletions flow/connectors/bigquery/bigquery.go
@@ -131,14 +131,14 @@ func (bqsa *BigQueryServiceAccount) CreateStorageClient(ctx context.Context) (*s
 	return client, nil
 }

-// TableCheck:
+// ValidateCheck:
 // 1. Creates a table
 // 2. Inserts one row into the table
 // 3. Deletes the table
-func TableCheck(ctx context.Context, client *bigquery.Client, dataset string, project string) error {
+func (c *BigQueryConnector) ValidateCheck(ctx context.Context) error {
 	dummyTable := "peerdb_validate_dummy_" + shared.RandomString(4)

-	newTable := client.DatasetInProject(project, dataset).Table(dummyTable)
+	newTable := c.client.DatasetInProject(c.projectID, c.datasetID).Table(dummyTable)

 	createErr := newTable.Create(ctx, &bigquery.TableMetadata{
 		Schema: []*bigquery.FieldSchema{
@@ -155,9 +155,9 @@ func TableCheck(ctx context.Context, client *bigquery.Client, dataset string, pr
 	}

 	var errs []error
-	insertQuery := client.Query(fmt.Sprintf("INSERT INTO %s VALUES(true)", dummyTable))
-	insertQuery.DefaultDatasetID = dataset
-	insertQuery.DefaultProjectID = project
+	insertQuery := c.client.Query(fmt.Sprintf("INSERT INTO %s VALUES(true)", dummyTable))
+	insertQuery.DefaultDatasetID = c.datasetID
+	insertQuery.DefaultProjectID = c.projectID
 	_, insertErr := insertQuery.Run(ctx)
 	if insertErr != nil {
 		errs = append(errs, fmt.Errorf("unable to validate insertion into table: %w. ", insertErr))
@@ -207,12 +207,6 @@ func NewBigQueryConnector(ctx context.Context, config *protos.BigqueryConfig) (*
 		return nil, fmt.Errorf("failed to get dataset metadata: %v", datasetErr)
 	}

-	permissionErr := TableCheck(ctx, client, datasetID, projectID)
-	if permissionErr != nil {
-		logger.Error("failed to get run mock table check", "error", permissionErr)
-		return nil, permissionErr
-	}
-
 	storageClient, err := bqsa.CreateStorageClient(ctx)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create Storage client: %v", err)
23 changes: 9 additions & 14 deletions flow/connectors/clickhouse/clickhouse.go
@@ -56,27 +56,32 @@ func ValidateS3(ctx context.Context, creds *utils.ClickhouseS3Credentials) error
 }

 // Creates and drops a dummy table to validate the peer
-func ValidateClickhouse(ctx context.Context, conn *sql.DB) error {
+func (c *ClickhouseConnector) ValidateCheck(ctx context.Context) error {
 	validateDummyTableName := "peerdb_validation_" + shared.RandomString(4)
 	// create a table
-	_, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id UInt64) ENGINE = Memory",
+	_, err := c.database.ExecContext(ctx, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id UInt64) ENGINE = Memory",
 		validateDummyTableName))
 	if err != nil {
 		return fmt.Errorf("failed to create validation table %s: %w", validateDummyTableName, err)
 	}

 	// insert a row
-	_, err = conn.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s VALUES (1)", validateDummyTableName))
+	_, err = c.database.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s VALUES (1)", validateDummyTableName))
 	if err != nil {
 		return fmt.Errorf("failed to insert into validation table %s: %w", validateDummyTableName, err)
 	}

 	// drop the table
-	_, err = conn.ExecContext(ctx, "DROP TABLE IF EXISTS "+validateDummyTableName)
+	_, err = c.database.ExecContext(ctx, "DROP TABLE IF EXISTS "+validateDummyTableName)
 	if err != nil {
 		return fmt.Errorf("failed to drop validation table %s: %w", validateDummyTableName, err)
 	}

+	validateErr := ValidateS3(ctx, c.creds)
+	if validateErr != nil {
+		return fmt.Errorf("failed to validate S3 bucket: %w", validateErr)
+	}
+
 	return nil
 }

@@ -90,11 +95,6 @@ func NewClickhouseConnector(
 		return nil, fmt.Errorf("failed to open connection to Clickhouse peer: %w", err)
 	}

-	err = ValidateClickhouse(ctx, database)
-	if err != nil {
-		return nil, fmt.Errorf("invalidated Clickhouse peer: %w", err)
-	}
-
 	pgMetadata, err := metadataStore.NewPostgresMetadataStore(ctx)
 	if err != nil {
 		logger.Error("failed to create postgres metadata store", "error", err)
@@ -122,11 +122,6 @@ func NewClickhouseConnector(
 		clickhouseS3Creds = utils.GetClickhouseAWSSecrets(bucketPathSuffix)
 	}

-	validateErr := ValidateS3(ctx, clickhouseS3Creds)
-	if validateErr != nil {
-		return nil, fmt.Errorf("failed to validate S3 bucket: %w", validateErr)
-	}
-
 	return &ClickhouseConnector{
 		database:   database,
 		pgMetadata: pgMetadata,
13 changes: 13 additions & 0 deletions flow/connectors/core.go
@@ -27,6 +27,14 @@ type Connector interface {
 	ConnectionActive(context.Context) error
 }

+type ValidationConnector interface {
+	Connector
+
+	// ValidateCheck performs validation for the connector,
+	// usually including permissions to create and use objects (tables, schema, etc.).
+	ValidateCheck(context.Context) error
+}
+
 type GetTableSchemaConnector interface {
 	Connector

@@ -279,4 +287,9 @@
 	_ QRepConsolidateConnector = &connsnowflake.SnowflakeConnector{}
 	_ QRepConsolidateConnector = &connclickhouse.ClickhouseConnector{}
+
+	_ ValidationConnector = &connsnowflake.SnowflakeConnector{}
+	_ ValidationConnector = &connclickhouse.ClickhouseConnector{}
+	_ ValidationConnector = &connbigquery.BigQueryConnector{}
+	_ ValidationConnector = &conns3.S3Connector{}
 )
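The interface is consumed through an optional type assertion, as in the validate_peer.go hunk above. A hedged sketch of the call shape (the helper name is hypothetical):

    package example

    import (
        "context"
        "fmt"

        "github.com/PeerDB-io/peer-flow/connectors"
    )

    // validatePeer runs the deeper permission check only when the
    // connector opts in by implementing ValidationConnector, then the
    // cheap liveness check every Connector provides.
    func validatePeer(ctx context.Context, conn connectors.Connector) error {
        if vc, ok := conn.(connectors.ValidationConnector); ok {
            if err := vc.ValidateCheck(ctx); err != nil {
                return fmt.Errorf("validation check failed: %w", err)
            }
        }
        return conn.ConnectionActive(ctx)
    }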
11 changes: 7 additions & 4 deletions flow/connectors/postgres/cdc.go
@@ -739,10 +739,13 @@ func (p *PostgresCDCSource) processRelationMessage(
 	for _, column := range currRel.Columns {
 		// not present in previous relation message, but in current one, so added.
 		if _, ok := prevRelMap[column.Name]; !ok {
-			schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.DeltaAddedColumn{
-				ColumnName: column.Name,
-				ColumnType: string(currRelMap[column.Name]),
-			})
+			// only add to delta if not excluded
+			if _, ok := p.tableNameMapping[p.srcTableIDNameMapping[currRel.RelationID]].Exclude[column.Name]; !ok {
+				schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.DeltaAddedColumn{
+					ColumnName: column.Name,
+					ColumnType: string(currRelMap[column.Name]),
+				})
+			}
 			// present in previous and current relation messages, but data types have changed.
 			// so we add it to AddedColumns and DroppedColumns, knowing that we process DroppedColumns first.
 		} else if prevRelMap[column.Name] != currRelMap[column.Name] {
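The new guard resolves the relation ID to a source table name, then consults that table's Exclude set, so columns the user excluded from the mirror never produce schema deltas. A trimmed sketch of the lookup shape (types simplified from the real mappings):

    package example

    // relation OID -> source table name -> per-table excluded columns.
    type cdcSource struct {
        srcTableIDNameMapping map[uint32]string
        tableNameMapping      map[string]struct{ Exclude map[string]struct{} }
    }

    // columnExcluded reports whether an added column should be omitted
    // from the schema delta because the user excluded it from the mirror.
    func (p *cdcSource) columnExcluded(relID uint32, column string) bool {
        tbl := p.srcTableIDNameMapping[relID]
        _, excluded := p.tableNameMapping[tbl].Exclude[column]
        return excluded
    }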
24 changes: 24 additions & 0 deletions flow/connectors/postgres/qvalue_convert.go
@@ -12,6 +12,7 @@ import (
 	"github.com/lib/pq/oid"
 	"github.com/shopspring/decimal"

+	peerdb_interval "github.com/PeerDB-io/peer-flow/interval"
 	"github.com/PeerDB-io/peer-flow/model/qvalue"
 	"github.com/PeerDB-io/peer-flow/shared"
@@ -80,6 +81,8 @@ func (c *PostgresConnector) postgresOIDToQValueKind(recvOID uint32) qvalue.QValu
 		return qvalue.QValueKindArrayTimestampTZ
 	case pgtype.TextArrayOID, pgtype.VarcharArrayOID, pgtype.BPCharArrayOID:
 		return qvalue.QValueKindArrayString
+	case pgtype.IntervalOID:
+		return qvalue.QValueKindInterval
 	default:
 		typeName, ok := pgtype.NewMap().TypeForOID(recvOID)
 		if !ok {
@@ -225,6 +228,27 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
 	case qvalue.QValueKindTimestampTZ:
 		timestamp := value.(time.Time)
 		val = qvalue.QValue{Kind: qvalue.QValueKindTimestampTZ, Value: timestamp}
+	case qvalue.QValueKindInterval:
+		intervalObject := value.(pgtype.Interval)
+		var interval peerdb_interval.PeerDBInterval
+		interval.Hours = int(intervalObject.Microseconds / 3600000000)
+		interval.Minutes = int((intervalObject.Microseconds % 3600000000) / 60000000)
+		interval.Seconds = float64(intervalObject.Microseconds%60000000) / 1000000.0
+		interval.Days = int(intervalObject.Days)
+		interval.Years = int(intervalObject.Months / 12)
+		interval.Months = int(intervalObject.Months % 12)
+		interval.Valid = intervalObject.Valid
+
+		intervalJSON, err := json.Marshal(interval)
+		if err != nil {
+			return qvalue.QValue{}, fmt.Errorf("failed to parse interval: %w", err)
+		}
+
+		if !interval.Valid {
+			return qvalue.QValue{}, fmt.Errorf("invalid interval: %v", value)
+		}
+
+		return qvalue.QValue{Kind: qvalue.QValueKindString, Value: string(intervalJSON)}, nil
 	case qvalue.QValueKindDate:
 		date := value.(time.Time)
 		val = qvalue.QValue{Kind: qvalue.QValueKindDate, Value: date}
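pgtype.Interval stores months, days, and microseconds separately; the hunk above splits the microseconds into hours, minutes, and seconds, and the months into years and months, then serializes the struct to JSON. A worked example under assumed JSON tags (PeerDBInterval's real tags may differ):

    package example

    import (
        "encoding/json"
        "fmt"
    )

    // interval mirrors the PeerDBInterval shape used above (JSON tags assumed).
    type interval struct {
        Hours   int     `json:"hours,omitempty"`
        Minutes int     `json:"minutes,omitempty"`
        Seconds float64 `json:"seconds,omitempty"`
        Days    int     `json:"days,omitempty"`
        Months  int     `json:"months,omitempty"`
        Years   int     `json:"years,omitempty"`
        Valid   bool    `json:"valid"`
    }

    func main() {
        // Postgres '1 year 2 months 3 days 04:05:06' arrives as
        // Months=14, Days=3, Microseconds=14706000000.
        micros := int64(14706000000)
        months := int32(14)

        out := interval{
            Hours:   int(micros / 3600000000),              // 4
            Minutes: int((micros % 3600000000) / 60000000), // 5
            Seconds: float64(micros%60000000) / 1000000.0,  // 6
            Days:    3,
            Years:   int(months / 12), // 1
            Months:  int(months % 12), // 2
            Valid:   true,
        }
        b, _ := json.Marshal(out)
        fmt.Println(string(b))
    }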
18 changes: 6 additions & 12 deletions flow/connectors/s3/s3.go
@@ -89,18 +89,18 @@ func (c *S3Connector) Close() error {
 	return nil
 }

-func ValidCheck(ctx context.Context, s3Client *s3.Client, bucketURL string, metadataDB *metadataStore.PostgresMetadataStore) error {
+func (c *S3Connector) ValidateCheck(ctx context.Context) error {
 	reader := strings.NewReader(time.Now().Format(time.RFC3339))

-	bucketPrefix, parseErr := utils.NewS3BucketAndPrefix(bucketURL)
+	bucketPrefix, parseErr := utils.NewS3BucketAndPrefix(c.url)
 	if parseErr != nil {
 		return fmt.Errorf("failed to parse bucket url: %w", parseErr)
 	}

 	// Write an empty file and then delete it
 	// to check if we have write permissions
 	bucketName := aws.String(bucketPrefix.Bucket)
-	_, putErr := s3Client.PutObject(ctx, &s3.PutObjectInput{
+	_, putErr := c.client.PutObject(ctx, &s3.PutObjectInput{
 		Bucket: bucketName,
 		Key:    aws.String(_peerDBCheck),
 		Body:   reader,
@@ -109,7 +109,7 @@ func ValidCheck(ctx context.Context, s3Client *s3.Client, bucketURL string, meta
 		return fmt.Errorf("failed to write to bucket: %w", putErr)
 	}

-	_, delErr := s3Client.DeleteObject(ctx, &s3.DeleteObjectInput{
+	_, delErr := c.client.DeleteObject(ctx, &s3.DeleteObjectInput{
 		Bucket: bucketName,
 		Key:    aws.String(_peerDBCheck),
 	})
@@ -118,8 +118,8 @@
 		return fmt.Errorf("failed to delete from bucket: %w", delErr)
 	}

 	// check if we can ping external metadata
-	if metadataDB != nil {
-		err := metadataDB.Ping(ctx)
+	if c.pgMetadata != nil {
+		err := c.pgMetadata.Ping(ctx)
 		if err != nil {
 			return fmt.Errorf("failed to ping external metadata: %w", err)
 		}
@@ -129,12 +129,6 @@
 	}

 func (c *S3Connector) ConnectionActive(ctx context.Context) error {
-	validErr := ValidCheck(ctx, &c.client, c.url, c.pgMetadata)
-	if validErr != nil {
-		c.logger.Error("failed to validate s3 connector:", "error", validErr)
-		return validErr
-	}
-
 	return nil
 }
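The write-then-delete probe is a cheap way to verify both s3:PutObject and s3:DeleteObject permissions without leaving anything behind. A standalone sketch with aws-sdk-go-v2 (bucket and marker key are placeholders):

    package example

    import (
        "context"
        "fmt"
        "strings"
        "time"

        "github.com/aws/aws-sdk-go-v2/aws"
        "github.com/aws/aws-sdk-go-v2/service/s3"
    )

    // probeBucket verifies write and delete permissions by round-tripping
    // a tiny marker object, mirroring ValidateCheck above.
    func probeBucket(ctx context.Context, client *s3.Client, bucket string) error {
        key := aws.String("_peerdb_check") // placeholder marker key
        body := strings.NewReader(time.Now().Format(time.RFC3339))

        if _, err := client.PutObject(ctx, &s3.PutObjectInput{
            Bucket: aws.String(bucket), Key: key, Body: body,
        }); err != nil {
            return fmt.Errorf("failed to write to bucket: %w", err)
        }
        if _, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
            Bucket: aws.String(bucket), Key: key,
        }); err != nil {
            return fmt.Errorf("failed to delete from bucket: %w", err)
        }
        return nil
    }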
2 changes: 1 addition & 1 deletion flow/connectors/snowflake/merge_stmt_generator.go
@@ -51,7 +51,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
 			flattenedCastsSQLArray = append(flattenedCastsSQLArray,
 				fmt.Sprintf("TO_GEOMETRY(CAST(%s:\"%s\" AS STRING),true) AS %s",
 					toVariantColumnName, column.Name, targetColumnName))
-		case qvalue.QValueKindJSON, qvalue.QValueKindHStore:
+		case qvalue.QValueKindJSON, qvalue.QValueKindHStore, qvalue.QValueKindInterval:
 			flattenedCastsSQLArray = append(flattenedCastsSQLArray,
 				fmt.Sprintf("PARSE_JSON(CAST(%s:\"%s\" AS STRING)) AS %s",
 					toVariantColumnName, column.Name, targetColumnName))
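Concretely, with interval columns now routed through the JSON branch, the Sprintf above emits a flattened cast of this shape for a hypothetical interval column named duration (the variant column name _peerdb_data and the target column quoting are illustrative):

    PARSE_JSON(CAST(_peerdb_data:"duration" AS STRING)) AS "DURATION"

PARSE_JSON yields a Snowflake VARIANT, so the interval's JSON encoding from the Postgres hunk above survives the merge intact.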
5 changes: 4 additions & 1 deletion flow/connectors/snowflake/qrep.go
@@ -135,7 +135,10 @@ func (c *SnowflakeConnector) createExternalStage(stageName string, config *proto

 	cleanURL := fmt.Sprintf("s3://%s/%s/%s", s3o.Bucket, s3o.Prefix, config.FlowJobName)

-	s3Int := config.DestinationPeer.GetSnowflakeConfig().S3Integration
+	var s3Int string
+	if config.DestinationPeer != nil {
+		s3Int = config.DestinationPeer.GetSnowflakeConfig().S3Integration
+	}
 	if s3Int == "" {
 		credsStr := fmt.Sprintf("CREDENTIALS=(AWS_KEY_ID='%s' AWS_SECRET_KEY='%s')",
 			awsCreds.AccessKeyID, awsCreds.SecretAccessKey)
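The nil guard matters because of how protobuf-generated Go code behaves: Get* methods tolerate nil receivers, but the trailing .S3Integration is a plain field access and panics on a nil pointer. A hedged illustration with types trimmed to the relevant shape (not the real generated code):

    package example

    import "fmt"

    // Trimmed stand-ins for the generated protobuf types.
    type SnowflakeConfig struct{ S3Integration string }

    type Peer struct{ Config *SnowflakeConfig }

    // Generated protobuf getters are nil-receiver safe:
    func (p *Peer) GetSnowflakeConfig() *SnowflakeConfig {
        if p == nil {
            return nil
        }
        return p.Config
    }

    func main() {
        var peer *Peer // nil, like an unset DestinationPeer

        fmt.Println(peer.GetSnowflakeConfig()) // fine: prints <nil>

        // peer.GetSnowflakeConfig().S3Integration would panic here, because
        // the final step is a plain field access on a nil pointer. The hunk
        // above guards with `if config.DestinationPeer != nil` first.
        var s3Int string
        if peer != nil {
            s3Int = peer.GetSnowflakeConfig().S3Integration
        }
        fmt.Printf("s3Int=%q\n", s3Int)
    }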
18 changes: 7 additions & 11 deletions flow/connectors/snowflake/snowflake.go
@@ -104,10 +104,11 @@ type UnchangedToastColumnResult struct {
 	UnchangedToastColumns ArrayString
 }

-func ValidationCheck(ctx context.Context, database *sql.DB, schemaName string) error {
+func (c *SnowflakeConnector) ValidateCheck(ctx context.Context) error {
+	schemaName := c.rawSchema
 	// check if schema exists
 	var schemaExists sql.NullBool
-	err := database.QueryRowContext(ctx, checkIfSchemaExistsSQL, schemaName).Scan(&schemaExists)
+	err := c.database.QueryRowContext(ctx, checkIfSchemaExistsSQL, schemaName).Scan(&schemaExists)
 	if err != nil {
 		return fmt.Errorf("error while checking if schema exists: %w", err)
 	}
@@ -116,9 +117,9 @@ func ValidationCheck(ctx context.Context, database *sql.DB, schemaName string) e

 	// In a transaction, create a table, insert a row into the table and then drop the table
 	// If any of these steps fail, the transaction will be rolled back
-	tx, err := database.BeginTx(ctx, nil)
+	tx, err := c.database.BeginTx(ctx, nil)
 	if err != nil {
-		return fmt.Errorf("failed to begin transaction: %w", err)
+		return fmt.Errorf("failed to begin transaction for table check: %w", err)
 	}
 	// in case we return after error, ensure transaction is rolled back
 	defer func() {
@@ -158,7 +159,7 @@ func ValidationCheck(ctx context.Context, database *sql.DB, schemaName string) e
 	// commit transaction
 	err = tx.Commit()
 	if err != nil {
-		return fmt.Errorf("failed to commit transaction: %w", err)
+		return fmt.Errorf("failed to commit transaction for table check: %w", err)
 	}

 	return nil
@@ -212,11 +213,6 @@ func NewSnowflakeConnector(
 		rawSchema = *snowflakeProtoConfig.MetadataSchema
 	}

-	err = ValidationCheck(ctx, database, rawSchema)
-	if err != nil {
-		return nil, fmt.Errorf("could not validate snowflake peer: %w", err)
-	}
-
 	pgMetadata, err := metadataStore.NewPostgresMetadataStore(ctx)
 	if err != nil {
 		return nil, fmt.Errorf("could not connect to metadata store: %w", err)
@@ -459,7 +455,7 @@ func (c *SnowflakeConnector) syncRecordsViaAvro(
 	}

 	qrepConfig := &protos.QRepConfig{
-		StagingPath: "",
+		StagingPath: req.StagingPath,
 		FlowJobName: req.FlowJobName,
 		DestinationTableIdentifier: strings.ToLower(fmt.Sprintf("%s.%s", c.rawSchema,
 			rawTableIdentifier)),
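The validation method leans on a standard database/sql shape: run the risky statements inside a transaction, defer a rollback, and commit only at the end, so the deferred rollback becomes a no-op after a successful commit. A minimal standalone sketch of that pattern (table name is a placeholder):

    package example

    import (
        "context"
        "database/sql"
        "fmt"
    )

    // tableCheck round-trips CREATE/INSERT/DROP inside one transaction.
    // Rollback after a successful Commit returns sql.ErrTxDone, which is
    // safe to ignore here.
    func tableCheck(ctx context.Context, db *sql.DB) error {
        tx, err := db.BeginTx(ctx, nil)
        if err != nil {
            return fmt.Errorf("failed to begin transaction: %w", err)
        }
        defer tx.Rollback() //nolint:errcheck

        for _, stmt := range []string{
            "CREATE TABLE _peerdb_validation (id INT)", // placeholder table name
            "INSERT INTO _peerdb_validation VALUES (1)",
            "DROP TABLE _peerdb_validation",
        } {
            if _, err := tx.ExecContext(ctx, stmt); err != nil {
                return fmt.Errorf("validation statement %q failed: %w", stmt, err)
            }
        }
        return tx.Commit()
    }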
12 changes: 12 additions & 0 deletions flow/connectors/sql/query_executor.go
@@ -138,6 +138,18 @@ func (g *GenericSQLQueryExecutor) CountNonNullRows(
 	return count.Int64, err
 }

+func (g *GenericSQLQueryExecutor) CountSRIDs(
+	ctx context.Context,
+	schemaName string,
+	tableName string,
+	columnName string,
+) (int64, error) {
+	var count pgtype.Int8
+	err := g.db.QueryRowxContext(ctx, "SELECT COUNT(CASE WHEN ST_SRID("+columnName+
+		") <> 0 THEN 1 END) AS not_zero FROM "+schemaName+"."+tableName).Scan(&count)
+	return count.Int64, err
+}
+
 func (g *GenericSQLQueryExecutor) columnTypeToQField(ct *sql.ColumnType) (model.QField, error) {
 	qvKind, ok := g.dbtypeToQValueKind[ct.DatabaseTypeName()]
 	if !ok {
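Note that CountSRIDs splices identifiers directly into the SQL string, so it assumes trusted schema, table, and column names. A hypothetical call site (the executor value and names are placeholders):

    // Count geometry values whose SRID is set, e.g. to decide whether
    // spatial data needs SRID-aware handling downstream.
    count, err := executor.CountSRIDs(ctx, "public", "places", "geom")
    if err != nil {
        return fmt.Errorf("failed to count SRIDs: %w", err)
    }
    fmt.Printf("%d rows in public.places have a nonzero SRID\n", count)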