From af39551eb034d6a21548bbb3ec016e9a7e9f155a Mon Sep 17 00:00:00 2001 From: Kevin Biju <52661649+heavycrystal@users.noreply.github.com> Date: Thu, 21 Mar 2024 02:09:30 +0530 Subject: [PATCH 1/9] column exclusion - properly handle schema changes (#1512) Column exclusions works by removing columns from the schema we fetch of the source table. This exclusion was not being done in the code path for schema changes [where we fetch the schema again], causing a disconnect and normalize to fail. Fixed by moving the exclusion code to a separate function and making both code paths use it. Also CDC handles excluded columns earlier to prevent spurious logs. --- flow/connectors/postgres/cdc.go | 11 +++-- flow/model/cdc_record_stream.go | 22 ++------- flow/shared/additional_tables.go | 26 ----------- flow/shared/schema_helpers.go | 76 ++++++++++++++++++++++++++++++++ flow/workflows/setup_flow.go | 31 +------------ flow/workflows/sync_flow.go | 12 +++-- 6 files changed, 94 insertions(+), 84 deletions(-) delete mode 100644 flow/shared/additional_tables.go create mode 100644 flow/shared/schema_helpers.go diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 4683adb072..acf3ba7a40 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -739,10 +739,13 @@ func (p *PostgresCDCSource) processRelationMessage( for _, column := range currRel.Columns { // not present in previous relation message, but in current one, so added. if _, ok := prevRelMap[column.Name]; !ok { - schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.DeltaAddedColumn{ - ColumnName: column.Name, - ColumnType: string(currRelMap[column.Name]), - }) + // only add to delta if not excluded + if _, ok := p.tableNameMapping[p.srcTableIDNameMapping[currRel.RelationID]].Exclude[column.Name]; !ok { + schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.DeltaAddedColumn{ + ColumnName: column.Name, + ColumnType: string(currRelMap[column.Name]), + }) + } // present in previous and current relation messages, but data types have changed. // so we add it to AddedColumns and DroppedColumns, knowing that we process DroppedColumns first. 
} else if prevRelMap[column.Name] != currRelMap[column.Name] { diff --git a/flow/model/cdc_record_stream.go b/flow/model/cdc_record_stream.go index dcdadfbb67..0e2e633d4c 100644 --- a/flow/model/cdc_record_stream.go +++ b/flow/model/cdc_record_stream.go @@ -76,22 +76,8 @@ func (r *CDCRecordStream) GetRecords() <-chan Record { return r.records } -func (r *CDCRecordStream) AddSchemaDelta(tableNameMapping map[string]NameAndExclude, delta *protos.TableSchemaDelta) { - if tm, ok := tableNameMapping[delta.SrcTableName]; ok && len(tm.Exclude) != 0 { - added := make([]*protos.DeltaAddedColumn, 0, len(delta.AddedColumns)) - for _, column := range delta.AddedColumns { - if _, has := tm.Exclude[column.ColumnName]; !has { - added = append(added, column) - } - } - if len(added) != 0 { - r.SchemaDeltas = append(r.SchemaDeltas, &protos.TableSchemaDelta{ - SrcTableName: delta.SrcTableName, - DstTableName: delta.DstTableName, - AddedColumns: added, - }) - } - } else { - r.SchemaDeltas = append(r.SchemaDeltas, delta) - } +func (r *CDCRecordStream) AddSchemaDelta(tableNameMapping map[string]NameAndExclude, + delta *protos.TableSchemaDelta, +) { + r.SchemaDeltas = append(r.SchemaDeltas, delta) } diff --git a/flow/shared/additional_tables.go b/flow/shared/additional_tables.go deleted file mode 100644 index 0eb0b79f35..0000000000 --- a/flow/shared/additional_tables.go +++ /dev/null @@ -1,26 +0,0 @@ -package shared - -import ( - "github.com/PeerDB-io/peer-flow/generated/protos" -) - -func AdditionalTablesHasOverlap(currentTableMappings []*protos.TableMapping, - additionalTableMappings []*protos.TableMapping, -) bool { - currentSrcTables := make([]string, 0, len(currentTableMappings)) - currentDstTables := make([]string, 0, len(currentTableMappings)) - additionalSrcTables := make([]string, 0, len(additionalTableMappings)) - additionalDstTables := make([]string, 0, len(additionalTableMappings)) - - for _, currentTableMapping := range currentTableMappings { - currentSrcTables = append(currentSrcTables, currentTableMapping.SourceTableIdentifier) - currentDstTables = append(currentDstTables, currentTableMapping.DestinationTableIdentifier) - } - for _, additionalTableMapping := range additionalTableMappings { - additionalSrcTables = append(additionalSrcTables, additionalTableMapping.SourceTableIdentifier) - additionalDstTables = append(additionalDstTables, additionalTableMapping.DestinationTableIdentifier) - } - - return ArraysHaveOverlap(currentSrcTables, additionalSrcTables) || - ArraysHaveOverlap(currentDstTables, additionalDstTables) -} diff --git a/flow/shared/schema_helpers.go b/flow/shared/schema_helpers.go new file mode 100644 index 0000000000..2c92195e6f --- /dev/null +++ b/flow/shared/schema_helpers.go @@ -0,0 +1,76 @@ +package shared + +import ( + "log/slog" + "slices" + + "go.temporal.io/sdk/log" + "golang.org/x/exp/maps" + + "github.com/PeerDB-io/peer-flow/generated/protos" +) + +func AdditionalTablesHasOverlap(currentTableMappings []*protos.TableMapping, + additionalTableMappings []*protos.TableMapping, +) bool { + currentSrcTables := make([]string, 0, len(currentTableMappings)) + currentDstTables := make([]string, 0, len(currentTableMappings)) + additionalSrcTables := make([]string, 0, len(additionalTableMappings)) + additionalDstTables := make([]string, 0, len(additionalTableMappings)) + + for _, currentTableMapping := range currentTableMappings { + currentSrcTables = append(currentSrcTables, currentTableMapping.SourceTableIdentifier) + currentDstTables = append(currentDstTables, 
currentTableMapping.DestinationTableIdentifier) + } + for _, additionalTableMapping := range additionalTableMappings { + additionalSrcTables = append(additionalSrcTables, additionalTableMapping.SourceTableIdentifier) + additionalDstTables = append(additionalDstTables, additionalTableMapping.DestinationTableIdentifier) + } + + return ArraysHaveOverlap(currentSrcTables, additionalSrcTables) || + ArraysHaveOverlap(currentDstTables, additionalDstTables) +} + +// given the output of GetTableSchema, processes it to be used by CDCFlow +// 1) changes the map key to be the destination table name instead of the source table name +// 2) performs column exclusion using protos.TableMapping as input. +func BuildProcessedSchemaMapping(tableMappings []*protos.TableMapping, + tableNameSchemaMapping map[string]*protos.TableSchema, + logger log.Logger, +) map[string]*protos.TableSchema { + processedSchemaMapping := make(map[string]*protos.TableSchema) + sortedSourceTables := maps.Keys(tableNameSchemaMapping) + slices.Sort(sortedSourceTables) + + for _, srcTableName := range sortedSourceTables { + tableSchema := tableNameSchemaMapping[srcTableName] + var dstTableName string + for _, mapping := range tableMappings { + if mapping.SourceTableIdentifier == srcTableName { + dstTableName = mapping.DestinationTableIdentifier + if len(mapping.Exclude) != 0 { + columnCount := len(tableSchema.Columns) + columns := make([]*protos.FieldDescription, 0, columnCount) + for _, column := range tableSchema.Columns { + if !slices.Contains(mapping.Exclude, column.Name) { + columns = append(columns, column) + } + } + tableSchema = &protos.TableSchema{ + TableIdentifier: tableSchema.TableIdentifier, + PrimaryKeyColumns: tableSchema.PrimaryKeyColumns, + IsReplicaIdentityFull: tableSchema.IsReplicaIdentityFull, + Columns: columns, + } + } + break + } + } + processedSchemaMapping[dstTableName] = tableSchema + + logger.Info("normalized table schema", + slog.String("table", dstTableName), + slog.Any("schema", tableSchema)) + } + return processedSchemaMapping +} diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go index 0574f0d24e..4355bd832c 100644 --- a/flow/workflows/setup_flow.go +++ b/flow/workflows/setup_flow.go @@ -3,7 +3,6 @@ package peerflow import ( "fmt" "log/slog" - "slices" "sort" "time" @@ -201,34 +200,8 @@ func (s *SetupFlowExecution) fetchTableSchemaAndSetupNormalizedTables( sort.Strings(sortedSourceTables) s.logger.Info("setting up normalized tables for peer flow") - normalizedTableMapping := make(map[string]*protos.TableSchema) - for _, srcTableName := range sortedSourceTables { - tableSchema := tableNameSchemaMapping[srcTableName] - normalizedTableName := s.tableNameMapping[srcTableName] - for _, mapping := range flowConnectionConfigs.TableMappings { - if mapping.SourceTableIdentifier == srcTableName { - if len(mapping.Exclude) != 0 { - columnCount := len(tableSchema.Columns) - columns := make([]*protos.FieldDescription, 0, columnCount) - for _, column := range tableSchema.Columns { - if !slices.Contains(mapping.Exclude, column.Name) { - columns = append(columns, column) - } - } - tableSchema = &protos.TableSchema{ - TableIdentifier: tableSchema.TableIdentifier, - PrimaryKeyColumns: tableSchema.PrimaryKeyColumns, - IsReplicaIdentityFull: tableSchema.IsReplicaIdentityFull, - Columns: columns, - } - } - break - } - } - normalizedTableMapping[normalizedTableName] = tableSchema - - s.logger.Info("normalized table schema", slog.String("table", normalizedTableName), slog.Any("schema", tableSchema)) - } + 
normalizedTableMapping := shared.BuildProcessedSchemaMapping(flowConnectionConfigs.TableMappings, + tableNameSchemaMapping, s.logger) // now setup the normalized tables on the destination peer setupConfig := &protos.SetupNormalizedTableBatchInput{ diff --git a/flow/workflows/sync_flow.go b/flow/workflows/sync_flow.go index 9958c1c79c..890da2e0fa 100644 --- a/flow/workflows/sync_flow.go +++ b/flow/workflows/sync_flow.go @@ -6,6 +6,7 @@ import ( "go.temporal.io/sdk/log" "go.temporal.io/sdk/workflow" + "golang.org/x/exp/maps" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" @@ -139,12 +140,10 @@ func SyncFlowWorkflow( tableSchemaDeltasCount := len(childSyncFlowRes.TableSchemaDeltas) // slightly hacky: table schema mapping is cached, so we need to manually update it if schema changes. - if tableSchemaDeltasCount != 0 { + if tableSchemaDeltasCount > 0 { modifiedSrcTables := make([]string, 0, tableSchemaDeltasCount) - modifiedDstTables := make([]string, 0, tableSchemaDeltasCount) for _, tableSchemaDelta := range childSyncFlowRes.TableSchemaDeltas { modifiedSrcTables = append(modifiedSrcTables, tableSchemaDelta.SrcTableName) - modifiedDstTables = append(modifiedDstTables, tableSchemaDelta.DstTableName) } getModifiedSchemaCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ @@ -167,10 +166,9 @@ func SyncFlowWorkflow( nil, ).Get(ctx, nil) } else { - for i, srcTable := range modifiedSrcTables { - dstTable := modifiedDstTables[i] - options.TableNameSchemaMapping[dstTable] = getModifiedSchemaRes.TableNameSchemaMapping[srcTable] - } + processedSchemaMapping := shared.BuildProcessedSchemaMapping(options.TableMappings, + getModifiedSchemaRes.TableNameSchemaMapping, logger) + maps.Copy(options.TableNameSchemaMapping, processedSchemaMapping) } } From 61a73d95c16c56e8d9cac421e491e0d1768a92b5 Mon Sep 17 00:00:00 2001 From: Amogh Bharadwaj Date: Thu, 21 Mar 2024 03:08:42 +0530 Subject: [PATCH 2/9] UI: remove pub load check (#1510) Currently configuration form for create CDC is gated by publications being loaded, but there's no user facing info that this wait is happening. 
Better to just have the loading be indicated by the publication dropdown and render the form no matter what --- ui/app/mirrors/create/cdc/cdc.tsx | 26 ++++++++++---------------- ui/app/mirrors/create/cdc/fields.tsx | 9 ++++++++- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/ui/app/mirrors/create/cdc/cdc.tsx b/ui/app/mirrors/create/cdc/cdc.tsx index cb9ea4a381..3c194f0a8c 100644 --- a/ui/app/mirrors/create/cdc/cdc.tsx +++ b/ui/app/mirrors/create/cdc/cdc.tsx @@ -38,6 +38,7 @@ export default function CDCConfigForm({ setRows, }: MirrorConfigProps) { const [publications, setPublications] = useState(); + const [pubLoading, setPubLoading] = useState(true); const [show, setShow] = useState(false); const handleChange = (val: string | boolean, setting: MirrorSetting) => { let stateVal: string | boolean = val; @@ -66,26 +67,15 @@ export default function CDCConfigForm({ return true; }; - const optionsForField = (setting: MirrorSetting) => { - switch (setting.label) { - case 'Publication Name': - return publications; - default: - return []; - } - }; - useEffect(() => { + setPubLoading(true); fetchPublications(mirrorConfig.source?.name || '').then((pubs) => { setPublications(pubs); + setPubLoading(false); }); }, [mirrorConfig.source?.name]); - if ( - mirrorConfig.source != undefined && - mirrorConfig.destination != undefined && - publications != undefined - ) + if (mirrorConfig.source != undefined && mirrorConfig.destination != undefined) return ( <> {normalSettings.map((setting, id) => { @@ -95,7 +85,12 @@ export default function CDCConfigForm({ key={id} handleChange={handleChange} setting={setting} - options={optionsForField(setting)} + options={ + setting.label === 'Publication Name' + ? publications + : undefined + } + publicationsLoading={pubLoading} /> ) ); @@ -126,7 +121,6 @@ export default function CDCConfigForm({ key={setting.label} handleChange={handleChange} setting={setting} - options={optionsForField(setting)} /> ); })} diff --git a/ui/app/mirrors/create/cdc/fields.tsx b/ui/app/mirrors/create/cdc/fields.tsx index 8bf802ffa1..7d8842e89a 100644 --- a/ui/app/mirrors/create/cdc/fields.tsx +++ b/ui/app/mirrors/create/cdc/fields.tsx @@ -13,9 +13,15 @@ interface FieldProps { setting: MirrorSetting; handleChange: (val: string | boolean, setting: MirrorSetting) => void; options?: string[]; + publicationsLoading?: boolean; } -const CDCField = ({ setting, handleChange, options }: FieldProps) => { +const CDCField = ({ + setting, + handleChange, + options, + publicationsLoading, +}: FieldProps) => { return setting.type === 'switch' ? ( { getOptionLabel={(option) => option.label} getOptionValue={(option) => option.option} theme={SelectTheme} + isLoading={publicationsLoading} /> {setting.tips && ( From 472d279b1fbeecb617fb6358a6641e50bc9dd41f Mon Sep 17 00:00:00 2001 From: Kevin Biju <52661649+heavycrystal@users.noreply.github.com> Date: Thu, 21 Mar 2024 08:30:29 +0530 Subject: [PATCH 3/9] SF storage integration fixes (#1517) Code is in need of general cleanup 1) CDC was hardcoding to internal stage, make it respect external stages. 2) fixing 1 leads to nil pointer dereference in CDC because integration detection accesses a field not set in CDC. 
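A condensed sketch of both fixes, pieced together from the hunks below (the surrounding `createExternalStage` and `syncRecordsViaAvro` code is omitted):

```go
// 1) CDC now forwards the requested staging path instead of forcing the
//    internal stage via an empty string.
qrepConfig := &protos.QRepConfig{
	StagingPath: req.StagingPath, // was hardcoded to ""
	FlowJobName: req.FlowJobName,
}

// 2) Stage creation only reads the S3 integration when a destination peer
//    is attached to the config; CDC does not set one, which is what caused
//    the nil pointer dereference.
var s3Int string
if config.DestinationPeer != nil {
	s3Int = config.DestinationPeer.GetSnowflakeConfig().S3Integration
}
```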
--- flow/connectors/snowflake/qrep.go | 5 ++++- flow/connectors/snowflake/snowflake.go | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/flow/connectors/snowflake/qrep.go b/flow/connectors/snowflake/qrep.go index 96e337a3f1..7ba84741a9 100644 --- a/flow/connectors/snowflake/qrep.go +++ b/flow/connectors/snowflake/qrep.go @@ -135,7 +135,10 @@ func (c *SnowflakeConnector) createExternalStage(stageName string, config *proto cleanURL := fmt.Sprintf("s3://%s/%s/%s", s3o.Bucket, s3o.Prefix, config.FlowJobName) - s3Int := config.DestinationPeer.GetSnowflakeConfig().S3Integration + var s3Int string + if config.DestinationPeer != nil { + s3Int = config.DestinationPeer.GetSnowflakeConfig().S3Integration + } if s3Int == "" { credsStr := fmt.Sprintf("CREDENTIALS=(AWS_KEY_ID='%s' AWS_SECRET_KEY='%s')", awsCreds.AccessKeyID, awsCreds.SecretAccessKey) diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 746768bf51..b1877320e1 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -459,7 +459,7 @@ func (c *SnowflakeConnector) syncRecordsViaAvro( } qrepConfig := &protos.QRepConfig{ - StagingPath: "", + StagingPath: req.StagingPath, FlowJobName: req.FlowJobName, DestinationTableIdentifier: strings.ToLower(fmt.Sprintf("%s.%s", c.rawSchema, rawTableIdentifier)), From 156a9b29f2ef1068df28acccc9d5565ef8b16712 Mon Sep 17 00:00:00 2001 From: Amogh Bharadwaj Date: Thu, 21 Mar 2024 13:32:53 +0530 Subject: [PATCH 4/9] Geospatial data types: set SRID for geometry (#1514) From Snowflake docs: ``` For GeoJSON, WKT, and WKB input, if the srid argument is not specified, the resulting GEOMETRY object has the SRID set to 0. ``` So we need to explicitly set the SRID in our WKT geospatial strings so that this is set on the target rows and can be seen with `ST_SRID`. 
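In code terms the change is roughly the following, condensed from the `geo.go` and `qrecord_batch.go` hunks below: the validator emits extended WKT, and the copy-from source strips the prefix again before WKB conversion.

```go
// Producer: prefix the WKT with its SRID when one is set, yielding
// extended WKT such as "SRID=5678;LINESTRING (1 2, 3 4)".
wkt := geometryObject.ToWKT()
if srid := geometryObject.SRID(); srid != 0 {
	wkt = fmt.Sprintf("SRID=%d;%s", srid, wkt)
}

// Consumer: drop the "SRID=...;" prefix before converting back to WKB,
// since the WKB encoder expects bare WKT.
geoWkt := wkt
if strings.HasPrefix(geoWkt, "SRID=") {
	if _, bare, found := strings.Cut(geoWkt, ";"); found {
		geoWkt = bare
	}
}
```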
Test added Functionally tested --- flow/connectors/sql/query_executor.go | 12 ++++++++++++ flow/e2e/snowflake/peer_flow_sf_test.go | 18 +++++++++++++++--- flow/e2e/snowflake/snowflake_helper.go | 9 +++++++++ flow/geo/geo.go | 4 ++++ flow/model/qrecord_batch.go | 13 +++++++++++-- flow/model/qvalue/qvalue.go | 10 +++++++++- 6 files changed, 60 insertions(+), 6 deletions(-) diff --git a/flow/connectors/sql/query_executor.go b/flow/connectors/sql/query_executor.go index 91972a75c3..05279fdde4 100644 --- a/flow/connectors/sql/query_executor.go +++ b/flow/connectors/sql/query_executor.go @@ -138,6 +138,18 @@ func (g *GenericSQLQueryExecutor) CountNonNullRows( return count.Int64, err } +func (g *GenericSQLQueryExecutor) CountSRIDs( + ctx context.Context, + schemaName string, + tableName string, + columnName string, +) (int64, error) { + var count pgtype.Int8 + err := g.db.QueryRowxContext(ctx, "SELECT COUNT(CASE WHEN ST_SRID("+columnName+ + ") <> 0 THEN 1 END) AS not_zero FROM "+schemaName+"."+tableName).Scan(&count) + return count.Int64, err +} + func (g *GenericSQLQueryExecutor) columnTypeToQField(ct *sql.ColumnType) (model.QField, error) { qvKind, ok := g.dbtypeToQValueKind[ct.DatabaseTypeName()] if !ok { diff --git a/flow/e2e/snowflake/peer_flow_sf_test.go b/flow/e2e/snowflake/peer_flow_sf_test.go index 56084a1a27..635a7f3e0d 100644 --- a/flow/e2e/snowflake/peer_flow_sf_test.go +++ b/flow/e2e/snowflake/peer_flow_sf_test.go @@ -127,7 +127,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() { for range 6 { _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` INSERT INTO %s (line,poly) VALUES ($1,$2) - `, srcTableName), "010200000002000000000000000000F03F000000000000004000000000000008400000000000001040", + `, srcTableName), "SRID=5678;010200000002000000000000000000F03F000000000000004000000000000008400000000000001040", "010300000001000000050000000000000000000000000000000000000000000000"+ "00000000000000000000f03f000000000000f03f000000000000f03f0000000000"+ "00f03f000000000000000000000000000000000000000000000000") @@ -143,6 +143,13 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() { return false } + // Make sure SRIDs are set + sridCount, err := s.sfHelper.CountSRIDs("test_invalid_geo_sf_avro_cdc", "line") + if err != nil { + s.t.Log(err) + return false + } + polyCount, err := s.sfHelper.CountNonNullRows("test_invalid_geo_sf_avro_cdc", "poly") if err != nil { return false @@ -151,9 +158,14 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Geo_SF_Avro_CDC() { if lineCount != 6 || polyCount != 6 { s.t.Logf("wrong counts, expect 6 lines 6 polies, not %d lines %d polies", lineCount, polyCount) return false - } else { - return true } + + if sridCount != 6 { + s.t.Logf("there are some srids that are 0, expected 6 non-zero srids, got %d non-zero srids", sridCount) + return false + } + + return true }) env.Cancel() diff --git a/flow/e2e/snowflake/snowflake_helper.go b/flow/e2e/snowflake/snowflake_helper.go index e0d41e838d..14ca9dc35f 100644 --- a/flow/e2e/snowflake/snowflake_helper.go +++ b/flow/e2e/snowflake/snowflake_helper.go @@ -136,6 +136,15 @@ func (s *SnowflakeTestHelper) CountNonNullRows(tableName string, columnName stri return int(res), nil } +func (s *SnowflakeTestHelper) CountSRIDs(tableName string, columnName string) (int, error) { + res, err := s.testClient.CountSRIDs(context.Background(), s.testSchemaName, tableName, columnName) + if err != nil { + return 0, err + } + + return int(res), nil +} + func (s *SnowflakeTestHelper) CheckNull(tableName string, 
colNames []string) (bool, error) { return s.testClient.CheckNull(context.Background(), s.testSchemaName, tableName, colNames) } diff --git a/flow/geo/geo.go b/flow/geo/geo.go index 6602a26d53..7e35d9bfe6 100644 --- a/flow/geo/geo.go +++ b/flow/geo/geo.go @@ -31,6 +31,10 @@ func GeoValidate(hexWkb string) (string, error) { } wkt := geometryObject.ToWKT() + + if SRID := geometryObject.SRID(); SRID != 0 { + wkt = fmt.Sprintf("SRID=%d;%s", geometryObject.SRID(), wkt) + } return wkt, nil } diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go index 08c5ce7770..dd55ef7ecc 100644 --- a/flow/model/qrecord_batch.go +++ b/flow/model/qrecord_batch.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log/slog" + "strings" "time" "github.com/google/uuid" @@ -237,9 +238,17 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) { return nil, src.err } - wkb, err := geo.GeoToWKB(v) + geoWkt := v + if strings.HasPrefix(v, "SRID=") { + _, wkt, found := strings.Cut(v, ";") + if found { + geoWkt = wkt + } + } + + wkb, err := geo.GeoToWKB(geoWkt) if err != nil { - src.err = errors.New("failed to convert Geospatial value to wkb") + src.err = fmt.Errorf("failed to convert Geospatial value to wkb: %v", err) return nil, src.err } diff --git a/flow/model/qvalue/qvalue.go b/flow/model/qvalue/qvalue.go index ae0a3945ab..4a495a31cc 100644 --- a/flow/model/qvalue/qvalue.go +++ b/flow/model/qvalue/qvalue.go @@ -292,7 +292,15 @@ func compareGeometry(value1, value2 interface{}) bool { case *geom.Geom: return v1.Equals(geo2) case string: - geo1, err := geom.NewGeomFromWKT(v1) + geoWkt := v1 + if strings.HasPrefix(geoWkt, "SRID=") { + _, wkt, found := strings.Cut(geoWkt, ";") + if found { + geoWkt = wkt + } + } + + geo1, err := geom.NewGeomFromWKT(geoWkt) if err != nil { panic(err) } From 9e28839a38a3e78c09299b0b18910172eb2cc973 Mon Sep 17 00:00:00 2001 From: Amogh Bharadwaj Date: Thu, 21 Mar 2024 13:47:42 +0530 Subject: [PATCH 5/9] Mirror page: actions dropdown (#1513) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR moves edit mirror and resync mirror buttons to an actions dropdown Screenshot 2024-03-21 at 1 00 51 AM --- ui/app/mirrors/[mirrorId]/page.tsx | 44 ++++++--------- ui/components/AlertDropdown.tsx | 4 -- ui/components/EditButton.tsx | 16 ++---- ui/components/MirrorActionsDropdown.tsx | 72 +++++++++++++++++++++++++ ui/components/ResyncDialog.tsx | 4 +- 5 files changed, 96 insertions(+), 44 deletions(-) create mode 100644 ui/components/MirrorActionsDropdown.tsx diff --git a/ui/app/mirrors/[mirrorId]/page.tsx b/ui/app/mirrors/[mirrorId]/page.tsx index 939b82a3bd..d7a78f9187 100644 --- a/ui/app/mirrors/[mirrorId]/page.tsx +++ b/ui/app/mirrors/[mirrorId]/page.tsx @@ -1,7 +1,6 @@ import { SyncStatusRow } from '@/app/dto/MirrorsDTO'; import prisma from '@/app/utils/prisma'; -import EditButton from '@/components/EditButton'; -import { ResyncDialog } from '@/components/ResyncDialog'; +import MirrorActions from '@/components/MirrorActionsDropdown'; import { FlowConnectionConfigs, FlowStatus } from '@/grpc_generated/flow'; import { DBType } from '@/grpc_generated/peers'; import { MirrorStatusResponse } from '@/grpc_generated/route'; @@ -77,8 +76,7 @@ export default async function ViewMirror({ } let syncStatusChild = null; - let resyncComponent = null; - let editButtonHTML = null; + let actionsDropdown = null; if (mirrorStatus.cdcStatus) { let rowsSynced = syncs.reduce((acc, sync) => { @@ -88,32 +86,27 @@ export default async function ViewMirror({ 
return acc; }, 0); const mirrorConfig = FlowConnectionConfigs.decode(mirrorInfo.config_proto!); + syncStatusChild = ( + + ); + const dbType = mirrorConfig.destination!.type; const canResync = dbType.valueOf() === DBType.BIGQUERY.valueOf() || dbType.valueOf() === DBType.SNOWFLAKE.valueOf(); - if (canResync) { - resyncComponent = ( - - ); - } - syncStatusChild = ( - - ); const isNotPaused = mirrorStatus.currentFlowState.toString() !== FlowStatus[FlowStatus.STATUS_PAUSED]; - editButtonHTML = ( -
- -
+ + actionsDropdown = ( + ); } else { redirect(`/mirrors/status/qrep/${mirrorId}`); @@ -129,11 +122,8 @@ export default async function ViewMirror({ paddingRight: '2rem', }} > -
-
{mirrorId}
- {editButtonHTML} -
- {resyncComponent} +
{mirrorId}
+ {actionsDropdown} !prevOpen); }; - const handleClose = () => { - setOpen(false); - }; - return ( diff --git a/ui/components/EditButton.tsx b/ui/components/EditButton.tsx index 1e598de2c9..5e87cbfaa8 100644 --- a/ui/components/EditButton.tsx +++ b/ui/components/EditButton.tsx @@ -1,6 +1,5 @@ 'use client'; import { Button } from '@/lib/Button'; -import { Icon } from '@/lib/Icon'; import { Label } from '@/lib/Label'; import { ProgressCircle } from '@/lib/ProgressCircle'; import { useRouter } from 'next/navigation'; @@ -25,22 +24,17 @@ const EditButton = ({ className='IconButton' onClick={handleEdit} aria-label='sort up' + variant='normal' style={{ display: 'flex', - marginLeft: '1rem', - alignItems: 'center', - backgroundColor: 'whitesmoke', - border: '1px solid rgba(0,0,0,0.1)', - borderRadius: '0.5rem', + alignItems: 'flex-start', + columnGap: '0.3rem', + width: '100%', }} disabled={disabled} > - {loading ? ( - - ) : ( - - )} + {loading && } ); }; diff --git a/ui/components/MirrorActionsDropdown.tsx b/ui/components/MirrorActionsDropdown.tsx new file mode 100644 index 0000000000..c11e68e019 --- /dev/null +++ b/ui/components/MirrorActionsDropdown.tsx @@ -0,0 +1,72 @@ +'use client'; +import EditButton from '@/components/EditButton'; +import { ResyncDialog } from '@/components/ResyncDialog'; +import { FlowConnectionConfigs } from '@/grpc_generated/flow'; +import { Button } from '@/lib/Button/Button'; +import { Icon } from '@/lib/Icon'; +import { Label } from '@/lib/Label/Label'; +import * as DropdownMenu from '@radix-ui/react-dropdown-menu'; +import { useEffect, useState } from 'react'; + +const MirrorActions = ({ + mirrorConfig, + workflowId, + editLink, + canResync, + isNotPaused, +}: { + mirrorConfig: FlowConnectionConfigs; + workflowId: string; + editLink: string; + canResync: boolean; + isNotPaused: boolean; +}) => { + const [mounted, setMounted] = useState(false); + const [open, setOpen] = useState(false); + const handleToggle = () => { + setOpen((prevOpen) => !prevOpen); + }; + useEffect(() => setMounted(true), []); + if (mounted) + return ( + + + + + + + + + + {canResync && ( + + )} + + + + ); + return <>; +}; + +export default MirrorActions; diff --git a/ui/components/ResyncDialog.tsx b/ui/components/ResyncDialog.tsx index a8ef4f39aa..6caeea3821 100644 --- a/ui/components/ResyncDialog.tsx +++ b/ui/components/ResyncDialog.tsx @@ -62,8 +62,8 @@ export const ResyncDialog = ({ noInteract={true} size='xLarge' triggerButton={ - } > From 73a39e994ec0d8299bb87f88270c21b3261f60a5 Mon Sep 17 00:00:00 2001 From: Amogh Bharadwaj Date: Thu, 21 Mar 2024 18:31:05 +0530 Subject: [PATCH 6/9] Snowflake: Support interval data type (#1515) This PR maps PostgreSQL's Interval data type to VARIANT in Snowflake of the form: ```json { "hours": 9, "days": 131, "months": 10, "years": 51, "valid": true } ``` --- flow/connectors/postgres/qvalue_convert.go | 24 +++++++++++++++++++ .../snowflake/merge_stmt_generator.go | 2 +- flow/e2e/snowflake/peer_flow_sf_test.go | 16 +++++++++++-- flow/interval/interval.go | 11 +++++++++ flow/model/qvalue/avro_converter.go | 8 +++++-- flow/model/qvalue/kind.go | 3 +++ 6 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 flow/interval/interval.go diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index 79c817d28c..d0e3e8cc0b 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -12,6 +12,7 @@ import ( "github.com/lib/pq/oid" "github.com/shopspring/decimal" + 
peerdb_interval "github.com/PeerDB-io/peer-flow/interval" "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/shared" ) @@ -80,6 +81,8 @@ func (c *PostgresConnector) postgresOIDToQValueKind(recvOID uint32) qvalue.QValu return qvalue.QValueKindArrayTimestampTZ case pgtype.TextArrayOID, pgtype.VarcharArrayOID, pgtype.BPCharArrayOID: return qvalue.QValueKindArrayString + case pgtype.IntervalOID: + return qvalue.QValueKindInterval default: typeName, ok := pgtype.NewMap().TypeForOID(recvOID) if !ok { @@ -225,6 +228,27 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( case qvalue.QValueKindTimestampTZ: timestamp := value.(time.Time) val = qvalue.QValue{Kind: qvalue.QValueKindTimestampTZ, Value: timestamp} + case qvalue.QValueKindInterval: + intervalObject := value.(pgtype.Interval) + var interval peerdb_interval.PeerDBInterval + interval.Hours = int(intervalObject.Microseconds / 3600000000) + interval.Minutes = int((intervalObject.Microseconds % 3600000000) / 60000000) + interval.Seconds = float64(intervalObject.Microseconds%60000000) / 1000000.0 + interval.Days = int(intervalObject.Days) + interval.Years = int(intervalObject.Months / 12) + interval.Months = int(intervalObject.Months % 12) + interval.Valid = intervalObject.Valid + + intervalJSON, err := json.Marshal(interval) + if err != nil { + return qvalue.QValue{}, fmt.Errorf("failed to parse interval: %w", err) + } + + if !interval.Valid { + return qvalue.QValue{}, fmt.Errorf("invalid interval: %v", value) + } + + return qvalue.QValue{Kind: qvalue.QValueKindString, Value: string(intervalJSON)}, nil case qvalue.QValueKindDate: date := value.(time.Time) val = qvalue.QValue{Kind: qvalue.QValueKindDate, Value: date} diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go index 19be0cfd94..b25d465a74 100644 --- a/flow/connectors/snowflake/merge_stmt_generator.go +++ b/flow/connectors/snowflake/merge_stmt_generator.go @@ -51,7 +51,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TO_GEOMETRY(CAST(%s:\"%s\" AS STRING),true) AS %s", toVariantColumnName, column.Name, targetColumnName)) - case qvalue.QValueKindJSON, qvalue.QValueKindHStore: + case qvalue.QValueKindJSON, qvalue.QValueKindHStore, qvalue.QValueKindInterval: flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("PARSE_JSON(CAST(%s:\"%s\" AS STRING)) AS %s", toVariantColumnName, column.Name, targetColumnName)) diff --git a/flow/e2e/snowflake/peer_flow_sf_test.go b/flow/e2e/snowflake/peer_flow_sf_test.go index 635a7f3e0d..90b16be522 100644 --- a/flow/e2e/snowflake/peer_flow_sf_test.go +++ b/flow/e2e/snowflake/peer_flow_sf_test.go @@ -439,7 +439,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Types_SF() { `, srcTableName)) e2e.EnvNoError(s.t, env, err) - e2e.EnvWaitFor(s.t, env, 2*time.Minute, "normalize types", func() bool { + e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize types", func() bool { noNulls, err := s.sfHelper.CheckNull("test_types_sf", []string{ "c41", "c1", "c2", "c3", "c4", "c6", "c39", "c40", "id", "c9", "c11", "c12", "c13", "c14", "c15", "c16", "c17", "c18", @@ -448,7 +448,19 @@ func (s PeerFlowE2ETestSuiteSF) Test_Types_SF() { "c50", "c51", "c52", "c53", "c54", }) if err != nil { - s.t.Log(err) + return false + } + + // interval checks + if err := s.checkJSONValue(dstTableName, "c16", "years", "5"); err != nil { + return false + } + + if err := 
s.checkJSONValue(dstTableName, "c16", "months", "2"); err != nil { + return false + } + + if err := s.checkJSONValue(dstTableName, "c16", "days", "29"); err != nil { return false } diff --git a/flow/interval/interval.go b/flow/interval/interval.go new file mode 100644 index 0000000000..79fbdc3ecf --- /dev/null +++ b/flow/interval/interval.go @@ -0,0 +1,11 @@ +package peerdb_interval + +type PeerDBInterval struct { + Hours int `json:"hours,omitempty"` + Minutes int `json:"minutes,omitempty"` + Seconds float64 `json:"seconds,omitempty"` + Days int `json:"days,omitempty"` + Months int `json:"months,omitempty"` + Years int `json:"years,omitempty"` + Valid bool `json:"valid"` +} diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index 0a299cf82f..3df8738209 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -58,7 +58,11 @@ type AvroSchemaField struct { // will return an error. func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH QDWHType, precision int16, scale int16) (interface{}, error) { switch kind { - case QValueKindString, QValueKindQChar, QValueKindCIDR, QValueKindINET: + case QValueKindString: + return "string", nil + case QValueKindQChar, QValueKindCIDR, QValueKindINET: + return "string", nil + case QValueKindInterval: return "string", nil case QValueKindUUID: return AvroSchemaLogical{ @@ -285,7 +289,7 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { return t, nil case QValueKindQChar: return c.processNullableUnion("string", string(c.Value.(uint8))) - case QValueKindString, QValueKindCIDR, QValueKindINET, QValueKindMacaddr: + case QValueKindString, QValueKindCIDR, QValueKindINET, QValueKindMacaddr, QValueKindInterval: if c.TargetDWH == QDWHTypeSnowflake && c.Value != nil && (len(c.Value.(string)) > 15*1024*1024) { slog.Warn("Truncating TEXT value > 15MB for Snowflake!") diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go index 9def7821f4..9ed9ac0beb 100644 --- a/flow/model/qvalue/kind.go +++ b/flow/model/qvalue/kind.go @@ -24,6 +24,7 @@ const ( QValueKindDate QValueKind = "date" QValueKindTime QValueKind = "time" QValueKindTimeTZ QValueKind = "timetz" + QValueKindInterval QValueKind = "interval" QValueKindNumeric QValueKind = "numeric" QValueKindBytes QValueKind = "bytes" QValueKindUUID QValueKind = "uuid" @@ -69,6 +70,7 @@ var QValueKindToSnowflakeTypeMap = map[QValueKind]string{ QValueKindJSON: "VARIANT", QValueKindTimestamp: "TIMESTAMP_NTZ", QValueKindTimestampTZ: "TIMESTAMP_TZ", + QValueKindInterval: "VARIANT", QValueKindTime: "TIME", QValueKindTimeTZ: "TIME", QValueKindDate: "DATE", @@ -117,6 +119,7 @@ var QValueKindToClickhouseTypeMap = map[QValueKind]string{ QValueKindTimeTZ: "String", QValueKindInvalid: "String", QValueKindHStore: "String", + // array types will be mapped to VARIANT QValueKindArrayFloat32: "Array(Float32)", QValueKindArrayFloat64: "Array(Float64)", From 970232d0c1cbe737d842922230bfff86d26bcb3f Mon Sep 17 00:00:00 2001 From: Amogh Bharadwaj Date: Thu, 21 Mar 2024 21:37:46 +0530 Subject: [PATCH 7/9] Add heartbeats to account for QRepPartitionIsSynced (#1518) When ReplicateQRepPartition fails for some reason midway after it's synced a bunch of partitions, it now enters a retry flow where it skims through partitions it's synced already, checking if they're synced and skipping if so. 
This process can take longer than the heartbeat for that activity (1 minute) So we need to add heartbeats in this process to prevent heartbeat timeouts Co-authored-by: Kaushik Iska --- flow/activities/flowable.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 68f764a890..52ab1a8ca7 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -616,6 +616,9 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, partition *protos.QRepPartition, runUUID string, ) error { + msg := fmt.Sprintf("replicating partition - %s: %d of %d total.", partition.PartitionId, idx, total) + activity.RecordHeartbeat(ctx, msg) + ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowJobName) logger := log.With(activity.GetLogger(ctx), slog.String(string(shared.FlowNameKey), config.FlowJobName)) @@ -643,6 +646,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, } if done { logger.Info("no records to push for partition " + partition.PartitionId) + activity.RecordHeartbeat(ctx, "no records to push for partition "+partition.PartitionId) return nil } From 15c1a68382cddcb10dfa26864e9d8cbd0ba807c9 Mon Sep 17 00:00:00 2001 From: Kevin Biju <52661649+heavycrystal@users.noreply.github.com> Date: Thu, 21 Mar 2024 22:39:05 +0530 Subject: [PATCH 8/9] test for column exclusion with schema changes (#1519) should have caught #1512 --- flow/e2e/snowflake/peer_flow_sf_test.go | 84 +++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/flow/e2e/snowflake/peer_flow_sf_test.go b/flow/e2e/snowflake/peer_flow_sf_test.go index 90b16be522..27e6156245 100644 --- a/flow/e2e/snowflake/peer_flow_sf_test.go +++ b/flow/e2e/snowflake/peer_flow_sf_test.go @@ -1143,3 +1143,87 @@ func (s PeerFlowE2ETestSuiteSF) Test_Supported_Mixed_Case_Table_SF() { e2e.RequireEnvCanceled(s.t, env) } + +func (s PeerFlowE2ETestSuiteSF) Test_Column_Exclusion_With_Schema_Changes() { + tc := e2e.NewTemporalClient(s.t) + + tableName := "test_exclude_schema_changes_sf" + srcTableName := s.attachSchemaSuffix(tableName) + dstTableName := fmt.Sprintf("%s.%s", s.sfHelper.testSchemaName, tableName) + + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id INT GENERATED ALWAYS AS IDENTITY, + c1 INT GENERATED BY DEFAULT AS IDENTITY, + c2 INT, + t TEXT, + PRIMARY KEY(id,t) + ); + `, srcTableName)) + require.NoError(s.t, err) + + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: s.attachSuffix(tableName), + } + + config := &protos.FlowConnectionConfigs{ + FlowJobName: connectionGen.FlowJobName, + Destination: s.sfHelper.Peer, + TableMappings: []*protos.TableMapping{ + { + SourceTableIdentifier: srcTableName, + DestinationTableIdentifier: dstTableName, + Exclude: []string{"c2"}, + }, + }, + Source: e2e.GeneratePostgresPeer(), + CdcStagingPath: connectionGen.CdcStagingPath, + SyncedAtColName: "_PEERDB_SYNCED_AT", + MaxBatchSize: 100, + } + + // wait for PeerFlowStatusQuery to finish setup + // and then insert, update and delete rows in the table. 
+ env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, config, nil) + e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen) + + // insert 10 rows into the source table + for i := range 10 { + testValue := fmt.Sprintf("test_value_%d", i) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c2,t) VALUES ($1,$2) + `, srcTableName), i, testValue) + e2e.EnvNoError(s.t, env, err) + } + s.t.Log("Inserted 10 rows into the source table") + + e2e.EnvWaitForEqualTables(env, s, "normalize table", tableName, "id,c1,t") + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf("ALTER TABLE %s ADD COLUMN t2 TEXT", srcTableName)) + e2e.EnvNoError(s.t, env, err) + // insert 10 more rows into the source table + for i := range 10 { + testValue := fmt.Sprintf("test_value_%d", i) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` + INSERT INTO %s(c2,t,t2) VALUES ($1,$2,random_string(100)) + `, srcTableName), i, testValue) + e2e.EnvNoError(s.t, env, err) + } + _, err = s.Conn().Exec(context.Background(), + fmt.Sprintf(`UPDATE %s SET c1=c1+1 WHERE MOD(c2,2)=1`, srcTableName)) + e2e.EnvNoError(s.t, env, err) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`DELETE FROM %s WHERE MOD(c2,2)=0`, srcTableName)) + e2e.EnvNoError(s.t, env, err) + e2e.EnvWaitForEqualTables(env, s, "normalize update/delete", tableName, "id,c1,t,t2") + + env.Cancel() + + e2e.RequireEnvCanceled(s.t, env) + + sfRows, err := s.GetRows(tableName, "*") + require.NoError(s.t, err) + + for _, field := range sfRows.Schema.Fields { + require.NotEqual(s.t, "c2", field.Name) + } + require.Len(s.t, sfRows.Schema.Fields, 5) +} From 2b8a575fc7a7706d80f52ab5f6516f2f910a83d9 Mon Sep 17 00:00:00 2001 From: Amogh Bharadwaj Date: Fri, 22 Mar 2024 01:40:10 +0530 Subject: [PATCH 9/9] Refactor peer validate check functions (#1521) Validation functions for most of our peers were inside NewSnowflakeConnector() , NewBigQueryConnector() and so on. We instantiate new connectors often during mirrors, where it isn't necessary to validate again and again. This PR introduces a new method for Connector interface - `ValidateCheck()` which does the validation and this is called only in validate peer. The validation which was done in NewConnector() is now done here. 
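The shape of the refactor, condensed from the hunks below (response construction in `ValidatePeer` trimmed):

```go
// Connectors that can verify permissions opt in to this interface on top of
// the base Connector; the NewXConnector constructors no longer run these checks.
type ValidationConnector interface {
	Connector

	// ValidateCheck performs permission checks (create/use tables, schemas, etc.).
	ValidateCheck(context.Context) error
}

// ValidatePeer is now the only caller: it type-asserts the connector and runs
// the check once, instead of on every connector instantiation during mirrors.
validationConn, ok := conn.(connectors.ValidationConnector)
if ok {
	if validErr := validationConn.ValidateCheck(ctx); validErr != nil {
		// surface the validation failure in the ValidatePeerResponse
	}
}
```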
Functionally tested for all peers where validation is there: - Snowflake - BigQuery - Clickhouse - S3 --- flow/cmd/validate_peer.go | 12 ++++++++++++ flow/connectors/bigquery/bigquery.go | 18 ++++++------------ flow/connectors/clickhouse/clickhouse.go | 23 +++++++++-------------- flow/connectors/core.go | 13 +++++++++++++ flow/connectors/s3/s3.go | 18 ++++++------------ flow/connectors/snowflake/snowflake.go | 16 ++++++---------- 6 files changed, 52 insertions(+), 48 deletions(-) diff --git a/flow/cmd/validate_peer.go b/flow/cmd/validate_peer.go index 5bbc1cfb2c..6a77fbe5ca 100644 --- a/flow/cmd/validate_peer.go +++ b/flow/cmd/validate_peer.go @@ -55,6 +55,18 @@ func (h *FlowRequestHandler) ValidatePeer( } } + validationConn, ok := conn.(connectors.ValidationConnector) + if ok { + validErr := validationConn.ValidateCheck(ctx) + if validErr != nil { + return &protos.ValidatePeerResponse{ + Status: protos.ValidatePeerStatus_INVALID, + Message: fmt.Sprintf("failed to validate %s peer %s: %v", + req.Peer.Type, req.Peer.Name, validErr), + }, nil + } + } + connErr := conn.ConnectionActive(ctx) if connErr != nil { return &protos.ValidatePeerResponse{ diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 5b2b8b01f8..db56cfd9fc 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -131,14 +131,14 @@ func (bqsa *BigQueryServiceAccount) CreateStorageClient(ctx context.Context) (*s return client, nil } -// TableCheck: +// ValidateCheck: // 1. Creates a table // 2. Inserts one row into the table // 3. Deletes the table -func TableCheck(ctx context.Context, client *bigquery.Client, dataset string, project string) error { +func (c *BigQueryConnector) ValidateCheck(ctx context.Context) error { dummyTable := "peerdb_validate_dummy_" + shared.RandomString(4) - newTable := client.DatasetInProject(project, dataset).Table(dummyTable) + newTable := c.client.DatasetInProject(c.projectID, c.datasetID).Table(dummyTable) createErr := newTable.Create(ctx, &bigquery.TableMetadata{ Schema: []*bigquery.FieldSchema{ @@ -155,9 +155,9 @@ func TableCheck(ctx context.Context, client *bigquery.Client, dataset string, pr } var errs []error - insertQuery := client.Query(fmt.Sprintf("INSERT INTO %s VALUES(true)", dummyTable)) - insertQuery.DefaultDatasetID = dataset - insertQuery.DefaultProjectID = project + insertQuery := c.client.Query(fmt.Sprintf("INSERT INTO %s VALUES(true)", dummyTable)) + insertQuery.DefaultDatasetID = c.datasetID + insertQuery.DefaultProjectID = c.projectID _, insertErr := insertQuery.Run(ctx) if insertErr != nil { errs = append(errs, fmt.Errorf("unable to validate insertion into table: %w. 
", insertErr)) @@ -207,12 +207,6 @@ func NewBigQueryConnector(ctx context.Context, config *protos.BigqueryConfig) (* return nil, fmt.Errorf("failed to get dataset metadata: %v", datasetErr) } - permissionErr := TableCheck(ctx, client, datasetID, projectID) - if permissionErr != nil { - logger.Error("failed to get run mock table check", "error", permissionErr) - return nil, permissionErr - } - storageClient, err := bqsa.CreateStorageClient(ctx) if err != nil { return nil, fmt.Errorf("failed to create Storage client: %v", err) diff --git a/flow/connectors/clickhouse/clickhouse.go b/flow/connectors/clickhouse/clickhouse.go index 2390ad1109..a4e959b505 100644 --- a/flow/connectors/clickhouse/clickhouse.go +++ b/flow/connectors/clickhouse/clickhouse.go @@ -56,27 +56,32 @@ func ValidateS3(ctx context.Context, creds *utils.ClickhouseS3Credentials) error } // Creates and drops a dummy table to validate the peer -func ValidateClickhouse(ctx context.Context, conn *sql.DB) error { +func (c *ClickhouseConnector) ValidateCheck(ctx context.Context) error { validateDummyTableName := "peerdb_validation_" + shared.RandomString(4) // create a table - _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id UInt64) ENGINE = Memory", + _, err := c.database.ExecContext(ctx, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id UInt64) ENGINE = Memory", validateDummyTableName)) if err != nil { return fmt.Errorf("failed to create validation table %s: %w", validateDummyTableName, err) } // insert a row - _, err = conn.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s VALUES (1)", validateDummyTableName)) + _, err = c.database.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s VALUES (1)", validateDummyTableName)) if err != nil { return fmt.Errorf("failed to insert into validation table %s: %w", validateDummyTableName, err) } // drop the table - _, err = conn.ExecContext(ctx, "DROP TABLE IF EXISTS "+validateDummyTableName) + _, err = c.database.ExecContext(ctx, "DROP TABLE IF EXISTS "+validateDummyTableName) if err != nil { return fmt.Errorf("failed to drop validation table %s: %w", validateDummyTableName, err) } + validateErr := ValidateS3(ctx, c.creds) + if validateErr != nil { + return fmt.Errorf("failed to validate S3 bucket: %w", validateErr) + } + return nil } @@ -90,11 +95,6 @@ func NewClickhouseConnector( return nil, fmt.Errorf("failed to open connection to Clickhouse peer: %w", err) } - err = ValidateClickhouse(ctx, database) - if err != nil { - return nil, fmt.Errorf("invalidated Clickhouse peer: %w", err) - } - pgMetadata, err := metadataStore.NewPostgresMetadataStore(ctx) if err != nil { logger.Error("failed to create postgres metadata store", "error", err) @@ -122,11 +122,6 @@ func NewClickhouseConnector( clickhouseS3Creds = utils.GetClickhouseAWSSecrets(bucketPathSuffix) } - validateErr := ValidateS3(ctx, clickhouseS3Creds) - if validateErr != nil { - return nil, fmt.Errorf("failed to validate S3 bucket: %w", validateErr) - } - return &ClickhouseConnector{ database: database, pgMetadata: pgMetadata, diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 988b8e3f28..39e31f8171 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -27,6 +27,14 @@ type Connector interface { ConnectionActive(context.Context) error } +type ValidationConnector interface { + Connector + + // ValidationCheck performs validation for the connectors, + // usually includes permissions to create and use objects (tables, schema etc). 
+ ValidateCheck(context.Context) error +} + type GetTableSchemaConnector interface { Connector @@ -279,4 +287,9 @@ var ( _ QRepConsolidateConnector = &connsnowflake.SnowflakeConnector{} _ QRepConsolidateConnector = &connclickhouse.ClickhouseConnector{} + + _ ValidationConnector = &connsnowflake.SnowflakeConnector{} + _ ValidationConnector = &connclickhouse.ClickhouseConnector{} + _ ValidationConnector = &connbigquery.BigQueryConnector{} + _ ValidationConnector = &conns3.S3Connector{} ) diff --git a/flow/connectors/s3/s3.go b/flow/connectors/s3/s3.go index de24f7e090..ddb6061b0d 100644 --- a/flow/connectors/s3/s3.go +++ b/flow/connectors/s3/s3.go @@ -89,10 +89,10 @@ func (c *S3Connector) Close() error { return nil } -func ValidCheck(ctx context.Context, s3Client *s3.Client, bucketURL string, metadataDB *metadataStore.PostgresMetadataStore) error { +func (c *S3Connector) ValidateCheck(ctx context.Context) error { reader := strings.NewReader(time.Now().Format(time.RFC3339)) - bucketPrefix, parseErr := utils.NewS3BucketAndPrefix(bucketURL) + bucketPrefix, parseErr := utils.NewS3BucketAndPrefix(c.url) if parseErr != nil { return fmt.Errorf("failed to parse bucket url: %w", parseErr) } @@ -100,7 +100,7 @@ func ValidCheck(ctx context.Context, s3Client *s3.Client, bucketURL string, meta // Write an empty file and then delete it // to check if we have write permissions bucketName := aws.String(bucketPrefix.Bucket) - _, putErr := s3Client.PutObject(ctx, &s3.PutObjectInput{ + _, putErr := c.client.PutObject(ctx, &s3.PutObjectInput{ Bucket: bucketName, Key: aws.String(_peerDBCheck), Body: reader, @@ -109,7 +109,7 @@ func ValidCheck(ctx context.Context, s3Client *s3.Client, bucketURL string, meta return fmt.Errorf("failed to write to bucket: %w", putErr) } - _, delErr := s3Client.DeleteObject(ctx, &s3.DeleteObjectInput{ + _, delErr := c.client.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: bucketName, Key: aws.String(_peerDBCheck), }) @@ -118,8 +118,8 @@ func ValidCheck(ctx context.Context, s3Client *s3.Client, bucketURL string, meta } // check if we can ping external metadata - if metadataDB != nil { - err := metadataDB.Ping(ctx) + if c.pgMetadata != nil { + err := c.pgMetadata.Ping(ctx) if err != nil { return fmt.Errorf("failed to ping external metadata: %w", err) } @@ -129,12 +129,6 @@ func ValidCheck(ctx context.Context, s3Client *s3.Client, bucketURL string, meta } func (c *S3Connector) ConnectionActive(ctx context.Context) error { - validErr := ValidCheck(ctx, &c.client, c.url, c.pgMetadata) - if validErr != nil { - c.logger.Error("failed to validate s3 connector:", "error", validErr) - return validErr - } - return nil } diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index b1877320e1..f67fe59df0 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -104,10 +104,11 @@ type UnchangedToastColumnResult struct { UnchangedToastColumns ArrayString } -func ValidationCheck(ctx context.Context, database *sql.DB, schemaName string) error { +func (c *SnowflakeConnector) ValidateCheck(ctx context.Context) error { + schemaName := c.rawSchema // check if schema exists var schemaExists sql.NullBool - err := database.QueryRowContext(ctx, checkIfSchemaExistsSQL, schemaName).Scan(&schemaExists) + err := c.database.QueryRowContext(ctx, checkIfSchemaExistsSQL, schemaName).Scan(&schemaExists) if err != nil { return fmt.Errorf("error while checking if schema exists: %w", err) } @@ -116,9 +117,9 @@ func ValidationCheck(ctx 
context.Context, database *sql.DB, schemaName string) e // In a transaction, create a table, insert a row into the table and then drop the table // If any of these steps fail, the transaction will be rolled back - tx, err := database.BeginTx(ctx, nil) + tx, err := c.database.BeginTx(ctx, nil) if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) + return fmt.Errorf("failed to begin transaction for table check: %w", err) } // in case we return after error, ensure transaction is rolled back defer func() { @@ -158,7 +159,7 @@ func ValidationCheck(ctx context.Context, database *sql.DB, schemaName string) e // commit transaction err = tx.Commit() if err != nil { - return fmt.Errorf("failed to commit transaction: %w", err) + return fmt.Errorf("failed to commit transaction for table check: %w", err) } return nil @@ -212,11 +213,6 @@ func NewSnowflakeConnector( rawSchema = *snowflakeProtoConfig.MetadataSchema } - err = ValidationCheck(ctx, database, rawSchema) - if err != nil { - return nil, fmt.Errorf("could not validate snowflake peer: %w", err) - } - pgMetadata, err := metadataStore.NewPostgresMetadataStore(ctx) if err != nil { return nil, fmt.Errorf("could not connect to metadata store: %w", err)