diff --git a/flow/activities/flowable_core.go b/flow/activities/flowable_core.go index 2d1f7e1f3e..daf3e5f0a6 100644 --- a/flow/activities/flowable_core.go +++ b/flow/activities/flowable_core.go @@ -225,7 +225,7 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon } defer connectors.CloseConnector(ctx, dstConn) - if err := dstConn.ReplayTableSchemaDeltas(ctx, flowName, recordBatchSync.SchemaDeltas); err != nil { + if err := dstConn.ReplayTableSchemaDeltas(ctx, config.Env, flowName, recordBatchSync.SchemaDeltas); err != nil { return nil, fmt.Errorf("failed to sync schema: %w", err) } diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index f990b2f19d..d6504322ca 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -203,6 +203,7 @@ func (c *BigQueryConnector) waitForTableReady(ctx context.Context, datasetTable // This could involve adding or dropping multiple columns. func (c *BigQueryConnector) ReplayTableSchemaDeltas( ctx context.Context, + env map[string]string, flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { diff --git a/flow/connectors/bigquery/qrep.go b/flow/connectors/bigquery/qrep.go index 3da50c8e8f..f4dd19397b 100644 --- a/flow/connectors/bigquery/qrep.go +++ b/flow/connectors/bigquery/qrep.go @@ -80,7 +80,7 @@ func (c *BigQueryConnector) replayTableSchemaDeltasQRep( } } - err = c.ReplayTableSchemaDeltas(ctx, config.FlowJobName, []*protos.TableSchemaDelta{tableSchemaDelta}) + err = c.ReplayTableSchemaDeltas(ctx, config.Env, config.FlowJobName, []*protos.TableSchemaDelta{tableSchemaDelta}) if err != nil { return nil, fmt.Errorf("failed to add columns to destination table: %w", err) } diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index da3b15c37f..99322361e8 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -97,7 +97,7 @@ func (s *QRepAvroSyncMethod) SyncRecords( slog.String(string(shared.FlowNameKey), req.FlowJobName), slog.String("dstTableName", rawTableName)) - err = s.connector.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas) + err = s.connector.ReplayTableSchemaDeltas(ctx, req.Env, req.FlowJobName, req.Records.SchemaDeltas) if err != nil { return nil, fmt.Errorf("failed to sync schema changes: %w", err) } diff --git a/flow/connectors/clickhouse/cdc.go b/flow/connectors/clickhouse/cdc.go index d3eb883b46..5dc8a14628 100644 --- a/flow/connectors/clickhouse/cdc.go +++ b/flow/connectors/clickhouse/cdc.go @@ -93,7 +93,7 @@ func (c *ClickHouseConnector) syncRecordsViaAvro( return nil, err } - if err := c.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas); err != nil { + if err := c.ReplayTableSchemaDeltas(ctx, req.Env, req.FlowJobName, req.Records.SchemaDeltas); err != nil { return nil, fmt.Errorf("failed to sync schema changes: %w", err) } @@ -120,7 +120,10 @@ func (c *ClickHouseConnector) SyncRecords(ctx context.Context, req *model.SyncRe return res, nil } -func (c *ClickHouseConnector) ReplayTableSchemaDeltas(ctx context.Context, flowJobName string, +func (c *ClickHouseConnector) ReplayTableSchemaDeltas( + ctx context.Context, + env map[string]string, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { if len(schemaDeltas) == 0 { @@ -133,7 +136,7 @@ func (c *ClickHouseConnector) ReplayTableSchemaDeltas(ctx context.Context, flowJ } for _, addedColumn := range schemaDelta.AddedColumns { - clickHouseColType, err := qvalue.QValueKind(addedColumn.Type).ToDWHColumnType(protos.DBType_CLICKHOUSE) + clickHouseColType, err := qvalue.QValueKind(addedColumn.Type).ToDWHColumnType(ctx, env, protos.DBType_CLICKHOUSE, addedColumn) if err != nil { return fmt.Errorf("failed to convert column type %s to ClickHouse type: %w", addedColumn.Type, err) } diff --git a/flow/connectors/clickhouse/normalize.go b/flow/connectors/clickhouse/normalize.go index 86d7e65dd1..fabe07a35f 100644 --- a/flow/connectors/clickhouse/normalize.go +++ b/flow/connectors/clickhouse/normalize.go @@ -15,7 +15,6 @@ import ( "github.com/ClickHouse/clickhouse-go/v2" "golang.org/x/sync/errgroup" - "github.com/PeerDB-io/peer-flow/datatypes" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -81,22 +80,6 @@ func getColName(overrides map[string]string, name string) string { return name } -func getClickhouseTypeForNumericColumn(ctx context.Context, column *protos.FieldDescription, env map[string]string) (string, error) { - if column.TypeModifier == -1 { - numericAsStringEnabled, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env) - if err != nil { - return "", err - } - if numericAsStringEnabled { - return "String", nil - } - } else if rawPrecision, _ := datatypes.ParseNumericTypmod(column.TypeModifier); rawPrecision > datatypes.PeerDBClickHouseMaxPrecision { - return "String", nil - } - precision, scale := datatypes.GetNumericTypeForWarehouse(column.TypeModifier, datatypes.ClickHouseNumericCompatibility{}) - return fmt.Sprintf("Decimal(%d, %d)", precision, scale), nil -} - func generateCreateTableSQLForNormalizedTable( ctx context.Context, config *protos.SetupNormalizedTableBatchInput, @@ -148,18 +131,10 @@ func generateCreateTableSQLForNormalizedTable( } if clickHouseType == "" { - if colType == qvalue.QValueKindNumeric { - var numericErr error - clickHouseType, numericErr = getClickhouseTypeForNumericColumn(ctx, column, config.Env) - if numericErr != nil { - return "", fmt.Errorf("error while mapping numeric column type to ClickHouse type: %w", numericErr) - } - } else { - var err error - clickHouseType, err = colType.ToDWHColumnType(protos.DBType_CLICKHOUSE) - if err != nil { - return "", fmt.Errorf("error while converting column type to ClickHouse type: %w", err) - } + var err error + clickHouseType, err = colType.ToDWHColumnType(ctx, config.Env, protos.DBType_CLICKHOUSE, column) + if err != nil { + return "", fmt.Errorf("error while converting column type to ClickHouse type: %w", err) } } if (tableSchema.NullableEnabled || columnNullableEnabled) && column.Nullable && !colType.IsArray() { @@ -378,21 +353,13 @@ func (c *ClickHouseConnector) NormalizeRecords( colSelector.WriteString(fmt.Sprintf("`%s`,", dstColName)) if clickHouseType == "" { - if colType == qvalue.QValueKindNumeric { - var numericErr error - clickHouseType, numericErr = getClickhouseTypeForNumericColumn(ctx, column, req.Env) - if numericErr != nil { - close(queries) - return nil, fmt.Errorf("error while mapping numeric column type to clickhouse type: %w", numericErr) - } - } else { - var err error - clickHouseType, err = colType.ToDWHColumnType(protos.DBType_CLICKHOUSE) - if err != nil { - close(queries) - return nil, fmt.Errorf("error while converting column type to clickhouse type: %w", err) - } + var err error + clickHouseType, err = colType.ToDWHColumnType(ctx, req.Env, protos.DBType_CLICKHOUSE, column) + if err != nil { + close(queries) + return nil, fmt.Errorf("error while converting column type to clickhouse type: %w", err) } + if (schema.NullableEnabled || columnNullableEnabled) && column.Nullable && !colType.IsArray() { clickHouseType = fmt.Sprintf("Nullable(%s)", clickHouseType) } diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 073d9d82b4..0991a50978 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -173,7 +173,7 @@ type CDCSyncConnectorCore interface { // ReplayTableSchemaDelta changes a destination table to match the schema at source // This could involve adding or dropping multiple columns. // Connectors which are non-normalizing should implement this as a nop. - ReplayTableSchemaDeltas(ctx context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error + ReplayTableSchemaDeltas(ctx context.Context, env map[string]string, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error } type CDCSyncConnector interface { @@ -463,8 +463,6 @@ var ( _ CDCSyncConnector = &connclickhouse.ClickHouseConnector{} _ CDCSyncConnector = &connelasticsearch.ElasticsearchConnector{} - _ CDCSyncPgConnector = &connpostgres.PostgresConnector{} - _ CDCNormalizeConnector = &connpostgres.PostgresConnector{} _ CDCNormalizeConnector = &connbigquery.BigQueryConnector{} _ CDCNormalizeConnector = &connsnowflake.SnowflakeConnector{} diff --git a/flow/connectors/elasticsearch/elasticsearch.go b/flow/connectors/elasticsearch/elasticsearch.go index e675168051..30279fd74e 100644 --- a/flow/connectors/elasticsearch/elasticsearch.go +++ b/flow/connectors/elasticsearch/elasticsearch.go @@ -95,7 +95,7 @@ func (esc *ElasticsearchConnector) CreateRawTable(ctx context.Context, } // we handle schema changes by not handling them since no mapping is being enforced right now -func (esc *ElasticsearchConnector) ReplayTableSchemaDeltas(ctx context.Context, +func (esc *ElasticsearchConnector) ReplayTableSchemaDeltas(ctx context.Context, env map[string]string, flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { return nil diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index 01982bf713..0f175233ef 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -380,7 +380,9 @@ func (c *EventHubConnector) CreateRawTable(ctx context.Context, req *protos.Crea }, nil } -func (c *EventHubConnector) ReplayTableSchemaDeltas(_ context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { +func (c *EventHubConnector) ReplayTableSchemaDeltas(_ context.Context, _ map[string]string, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, +) error { c.logger.Info("ReplayTableSchemaDeltas for event hub is a no-op") return nil } diff --git a/flow/connectors/kafka/kafka.go b/flow/connectors/kafka/kafka.go index ea0805b84b..ee78093fe6 100644 --- a/flow/connectors/kafka/kafka.go +++ b/flow/connectors/kafka/kafka.go @@ -149,7 +149,9 @@ func (c *KafkaConnector) SetupMetadataTables(_ context.Context) error { return nil } -func (c *KafkaConnector) ReplayTableSchemaDeltas(_ context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { +func (c *KafkaConnector) ReplayTableSchemaDeltas(_ context.Context, _ map[string]string, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, +) error { return nil } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 14b827cc89..8f49545fff 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -592,7 +592,7 @@ func syncRecordsCore[Items model.Items]( return nil, err } - err = c.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas) + err = c.ReplayTableSchemaDeltas(ctx, req.Env, req.FlowJobName, req.Records.SchemaDeltas) if err != nil { return nil, fmt.Errorf("failed to sync schema changes: %w", err) } @@ -941,6 +941,7 @@ func (c *PostgresConnector) SetupNormalizedTable( // This could involve adding or dropping multiple columns. func (c *PostgresConnector) ReplayTableSchemaDeltas( ctx context.Context, + _ map[string]string, flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { diff --git a/flow/connectors/postgres/postgres_schema_delta_test.go b/flow/connectors/postgres/postgres_schema_delta_test.go index 946b20eb3e..0b6668a5a2 100644 --- a/flow/connectors/postgres/postgres_schema_delta_test.go +++ b/flow/connectors/postgres/postgres_schema_delta_test.go @@ -58,7 +58,7 @@ func (s PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() { fmt.Sprintf("CREATE TABLE %s(id INT PRIMARY KEY)", tableName)) require.NoError(s.t, err) - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: []*protos.FieldDescription{ @@ -113,7 +113,7 @@ func (s PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, @@ -144,7 +144,7 @@ func (s PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, @@ -175,7 +175,7 @@ func (s PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, diff --git a/flow/connectors/pubsub/pubsub.go b/flow/connectors/pubsub/pubsub.go index 49aed379c4..537cda7241 100644 --- a/flow/connectors/pubsub/pubsub.go +++ b/flow/connectors/pubsub/pubsub.go @@ -67,7 +67,9 @@ func (c *PubSubConnector) CreateRawTable(ctx context.Context, req *protos.Create return &protos.CreateRawTableOutput{TableIdentifier: "n/a"}, nil } -func (c *PubSubConnector) ReplayTableSchemaDeltas(_ context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { +func (c *PubSubConnector) ReplayTableSchemaDeltas(_ context.Context, _ map[string]string, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, +) error { return nil } diff --git a/flow/connectors/s3/s3.go b/flow/connectors/s3/s3.go index eac37cd7c8..7d16a20af0 100644 --- a/flow/connectors/s3/s3.go +++ b/flow/connectors/s3/s3.go @@ -118,7 +118,9 @@ func (c *S3Connector) SyncRecords(ctx context.Context, req *model.SyncRecordsReq }, nil } -func (c *S3Connector) ReplayTableSchemaDeltas(_ context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { +func (c *S3Connector) ReplayTableSchemaDeltas(_ context.Context, _ map[string]string, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, +) error { c.logger.Info("ReplayTableSchemaDeltas for S3 is a no-op") return nil } diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go index 37b4ed7bdb..d87d3004f7 100644 --- a/flow/connectors/snowflake/merge_stmt_generator.go +++ b/flow/connectors/snowflake/merge_stmt_generator.go @@ -1,6 +1,7 @@ package connsnowflake import ( + "context" "fmt" "strings" @@ -24,7 +25,7 @@ type mergeStmtGenerator struct { mergeBatchId int64 } -func (m *mergeStmtGenerator) generateMergeStmt(dstTable string) (string, error) { +func (m *mergeStmtGenerator) generateMergeStmt(ctx context.Context, env map[string]string, dstTable string) (string, error) { parsedDstTable, _ := utils.ParseSchemaTable(dstTable) normalizedTableSchema := m.tableSchemaMapping[dstTable] unchangedToastColumns := m.unchangedToastColumnsMap[dstTable] @@ -34,7 +35,7 @@ func (m *mergeStmtGenerator) generateMergeStmt(dstTable string) (string, error) for _, column := range columns { genericColumnType := column.Type qvKind := qvalue.QValueKind(genericColumnType) - sfType, err := qvKind.ToDWHColumnType(protos.DBType_SNOWFLAKE) + sfType, err := qvKind.ToDWHColumnType(ctx, env, protos.DBType_SNOWFLAKE, column) if err != nil { return "", fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err) } diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go index 0fea54b027..3fa9b861d9 100644 --- a/flow/connectors/snowflake/qrep_avro_sync.go +++ b/flow/connectors/snowflake/qrep_avro_sync.go @@ -98,7 +98,7 @@ func (s *SnowflakeAvroSyncHandler) SyncQRepRecords( schema := stream.Schema() s.logger.Info("sync function called and schema acquired", partitionLog) - err := s.addMissingColumns(ctx, schema, dstTableSchema, dstTableName, partition) + err := s.addMissingColumns(ctx, config.Env, schema, dstTableSchema, dstTableName, partition) if err != nil { return 0, err } @@ -130,6 +130,7 @@ func (s *SnowflakeAvroSyncHandler) SyncQRepRecords( func (s *SnowflakeAvroSyncHandler) addMissingColumns( ctx context.Context, + env map[string]string, schema qvalue.QRecordSchema, dstTableSchema []*sql.ColumnType, dstTableName string, @@ -138,7 +139,7 @@ func (s *SnowflakeAvroSyncHandler) addMissingColumns( partitionLog := slog.String(string(shared.PartitionIDKey), partition.PartitionId) // check if avro schema has additional columns compared to destination table // if so, we need to add those columns to the destination table - colsToTypes := map[string]qvalue.QValueKind{} + var newColumns []qvalue.QField for _, col := range schema.Fields { hasColumn := false // check ignoring case @@ -152,24 +153,23 @@ func (s *SnowflakeAvroSyncHandler) addMissingColumns( if !hasColumn { s.logger.Info(fmt.Sprintf("adding column %s to destination table %s", col.Name, dstTableName), partitionLog) - colsToTypes[col.Name] = col.Type + newColumns = append(newColumns, col) } } - if len(colsToTypes) > 0 { + if len(newColumns) > 0 { tx, err := s.database.Begin() if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } - for colName, colType := range colsToTypes { - sfColType, err := colType.ToDWHColumnType(protos.DBType_SNOWFLAKE) + for _, column := range newColumns { + sfColType, err := column.ToDWHColumnType(ctx, env, protos.DBType_SNOWFLAKE) if err != nil { return fmt.Errorf("failed to convert QValueKind to Snowflake column type: %w", err) } - upperCasedColName := strings.ToUpper(colName) - alterTableCmd := fmt.Sprintf("ALTER TABLE %s ", dstTableName) - alterTableCmd += fmt.Sprintf("ADD COLUMN IF NOT EXISTS \"%s\" %s;", upperCasedColName, sfColType) + upperCasedColName := strings.ToUpper(column.Name) + alterTableCmd := fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s;", dstTableName, upperCasedColName, sfColType) s.logger.Info(fmt.Sprintf("altering destination table %s with command `%s`", dstTableName, alterTableCmd), partitionLog) diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 06e3fb881e..518b01ff2b 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -19,7 +19,6 @@ import ( metadataStore "github.com/PeerDB-io/peer-flow/connectors/external_metadata" "github.com/PeerDB-io/peer-flow/connectors/utils" - numeric "github.com/PeerDB-io/peer-flow/datatypes" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -338,7 +337,7 @@ func (c *SnowflakeConnector) SetupNormalizedTable( return true, nil } - normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(config, normalizedSchemaTable, tableSchema) + normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(ctx, config, normalizedSchemaTable, tableSchema) if _, err := c.execWithLogging(ctx, normalizedTableCreateSQL); err != nil { return false, fmt.Errorf("[sf] error while creating normalized table: %w", err) } @@ -349,6 +348,7 @@ func (c *SnowflakeConnector) SetupNormalizedTable( // This could involve adding or dropping multiple columns. func (c *SnowflakeConnector) ReplayTableSchemaDeltas( ctx context.Context, + env map[string]string, flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { @@ -374,17 +374,12 @@ func (c *SnowflakeConnector) ReplayTableSchemaDeltas( } for _, addedColumn := range schemaDelta.AddedColumns { - sfColtype, err := qvalue.QValueKind(addedColumn.Type).ToDWHColumnType(protos.DBType_SNOWFLAKE) + sfColtype, err := qvalue.QValueKind(addedColumn.Type).ToDWHColumnType(ctx, env, protos.DBType_SNOWFLAKE, addedColumn) if err != nil { return fmt.Errorf("failed to convert column type %s to snowflake type: %w", addedColumn.Type, err) } - if addedColumn.Type == string(qvalue.QValueKindNumeric) { - precision, scale := numeric.GetNumericTypeForWarehouse(addedColumn.TypeModifier, numeric.SnowflakeNumericCompatibility{}) - sfColtype = fmt.Sprintf("NUMERIC(%d,%d)", precision, scale) - } - _, err = tableSchemaModifyTx.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s", schemaDelta.DstTableName, strings.ToUpper(addedColumn.Name), sfColtype)) @@ -460,7 +455,7 @@ func (c *SnowflakeConnector) syncRecordsViaAvro( return nil, err } - err = c.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas) + err = c.ReplayTableSchemaDeltas(ctx, req.Env, req.FlowJobName, req.Records.SchemaDeltas) if err != nil { return nil, fmt.Errorf("failed to sync schema changes: %w", err) } @@ -557,7 +552,7 @@ func (c *SnowflakeConnector) mergeTablesForBatch( } g.Go(func() error { - mergeStatement, err := mergeGen.generateMergeStmt(tableName) + mergeStatement, err := mergeGen.generateMergeStmt(gCtx, env, tableName) if err != nil { return err } @@ -666,6 +661,7 @@ func (c *SnowflakeConnector) checkIfTableExists( } func generateCreateTableSQLForNormalizedTable( + ctx context.Context, config *protos.SetupNormalizedTableBatchInput, dstSchemaTable *utils.SchemaTable, tableSchema *protos.TableSchema, @@ -674,18 +670,13 @@ func generateCreateTableSQLForNormalizedTable( for _, column := range tableSchema.Columns { genericColumnType := column.Type normalizedColName := SnowflakeIdentifierNormalize(column.Name) - sfColType, err := qvalue.QValueKind(genericColumnType).ToDWHColumnType(protos.DBType_SNOWFLAKE) + sfColType, err := qvalue.QValueKind(genericColumnType).ToDWHColumnType(ctx, config.Env, protos.DBType_SNOWFLAKE, column) if err != nil { slog.Warn(fmt.Sprintf("failed to convert column type %s to snowflake type", genericColumnType), slog.Any("error", err)) continue } - if genericColumnType == "numeric" { - precision, scale := numeric.GetNumericTypeForWarehouse(column.TypeModifier, numeric.SnowflakeNumericCompatibility{}) - sfColType = fmt.Sprintf("NUMERIC(%d,%d)", precision, scale) - } - var notNull string if tableSchema.NullableEnabled && !column.Nullable { notNull = " NOT NULL" diff --git a/flow/e2e/clickhouse/peer_flow_ch_test.go b/flow/e2e/clickhouse/peer_flow_ch_test.go index 964f23784c..b9dd7c625c 100644 --- a/flow/e2e/clickhouse/peer_flow_ch_test.go +++ b/flow/e2e/clickhouse/peer_flow_ch_test.go @@ -4,6 +4,7 @@ import ( "context" "embed" "fmt" + "strconv" "strings" "testing" "time" @@ -624,7 +625,7 @@ func (s ClickHouseSuite) testNumericFF(ffValue bool) { } flowConnConfig := connectionGen.GenerateFlowConnectionConfigs(s.t) flowConnConfig.DoInitialSnapshot = true - flowConnConfig.Env = map[string]string{"PEERDB_CLICKHOUSE_NUMERIC_AS_STRING": fmt.Sprint(ffValue)} + flowConnConfig.Env = map[string]string{"PEERDB_CLICKHOUSE_NUMERIC_AS_STRING": strconv.FormatBool(ffValue)} tc := e2e.NewTemporalClient(s.t) env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil) e2e.SetupCDCFlowStatusQuery(s.t, env, flowConnConfig) diff --git a/flow/e2e/snowflake/snowflake_schema_delta_test.go b/flow/e2e/snowflake/snowflake_schema_delta_test.go index 32cb03b644..ada2b10f6a 100644 --- a/flow/e2e/snowflake/snowflake_schema_delta_test.go +++ b/flow/e2e/snowflake/snowflake_schema_delta_test.go @@ -53,7 +53,7 @@ func (s SnowflakeSchemaDeltaTestSuite) TestSimpleAddColumn() { err := s.sfTestHelper.RunCommand(fmt.Sprintf("CREATE TABLE %s(ID TEXT PRIMARY KEY)", tableName)) require.NoError(s.t, err) - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: []*protos.FieldDescription{ @@ -167,7 +167,7 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddAllColumnTypes() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, @@ -246,7 +246,7 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddTrickyColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, @@ -301,7 +301,7 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddWhitespaceColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go index 91ab867a0e..3cffcc274a 100644 --- a/flow/model/qvalue/kind.go +++ b/flow/model/qvalue/kind.go @@ -1,10 +1,13 @@ package qvalue import ( + "context" "fmt" "strings" + "github.com/PeerDB-io/peer-flow/datatypes" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/peerdbenv" ) type QValueKind string @@ -68,7 +71,6 @@ var QValueKindToSnowflakeTypeMap = map[QValueKind]string{ QValueKindInt64: "INTEGER", QValueKindFloat32: "FLOAT", QValueKindFloat64: "FLOAT", - QValueKindNumeric: "NUMBER(38, 9)", QValueKindQChar: "CHAR", QValueKindString: "STRING", QValueKindJSON: "VARIANT", @@ -110,7 +112,6 @@ var QValueKindToClickHouseTypeMap = map[QValueKind]string{ QValueKindInt64: "Int64", QValueKindFloat32: "Float32", QValueKindFloat64: "Float64", - QValueKindNumeric: "Decimal128(9)", QValueKindQChar: "FixedString(1)", QValueKindString: "String", QValueKindJSON: "String", @@ -140,16 +141,39 @@ var QValueKindToClickHouseTypeMap = map[QValueKind]string{ QValueKindArrayJSONB: "String", } -func (kind QValueKind) ToDWHColumnType(dwhType protos.DBType) (string, error) { +func getClickHouseTypeForNumericColumn(ctx context.Context, env map[string]string, column *protos.FieldDescription) (string, error) { + if column.TypeModifier == -1 { + numericAsStringEnabled, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env) + if err != nil { + return "", err + } + if numericAsStringEnabled { + return "String", nil + } + } else if rawPrecision, _ := datatypes.ParseNumericTypmod(column.TypeModifier); rawPrecision > datatypes.PeerDBClickHouseMaxPrecision { + return "String", nil + } + precision, scale := datatypes.GetNumericTypeForWarehouse(column.TypeModifier, datatypes.ClickHouseNumericCompatibility{}) + return fmt.Sprintf("Decimal(%d, %d)", precision, scale), nil +} + +// SEE ALSO: QField ToDWHColumnType +func (kind QValueKind) ToDWHColumnType(ctx context.Context, env map[string]string, dwhType protos.DBType, column *protos.FieldDescription, +) (string, error) { switch dwhType { case protos.DBType_SNOWFLAKE: - if val, ok := QValueKindToSnowflakeTypeMap[kind]; ok { + if kind == QValueKindNumeric { + precision, scale := datatypes.GetNumericTypeForWarehouse(column.TypeModifier, datatypes.SnowflakeNumericCompatibility{}) + return fmt.Sprintf("NUMERIC(%d,%d)", precision, scale), nil + } else if val, ok := QValueKindToSnowflakeTypeMap[kind]; ok { return val, nil } else { return "STRING", nil } case protos.DBType_CLICKHOUSE: - if val, ok := QValueKindToClickHouseTypeMap[kind]; ok { + if kind == QValueKindNumeric { + return getClickHouseTypeForNumericColumn(ctx, env, column) + } else if val, ok := QValueKindToClickHouseTypeMap[kind]; ok { return val, nil } else { return "String", nil diff --git a/flow/model/qvalue/qschema.go b/flow/model/qvalue/qschema.go index a956968ac1..a6632fdf5f 100644 --- a/flow/model/qvalue/qschema.go +++ b/flow/model/qvalue/qschema.go @@ -1,7 +1,13 @@ package qvalue import ( + "context" + "fmt" "strings" + + "github.com/PeerDB-io/peer-flow/datatypes" + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/peerdbenv" ) type QField struct { @@ -47,3 +53,42 @@ func (q QRecordSchema) GetColumnNames() []string { } return names } + +func (q QField) getClickHouseTypeForNumericField(ctx context.Context, env map[string]string) (string, error) { + if q.Precision == 0 && q.Scale == 0 { + numericAsStringEnabled, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env) + if err != nil { + return "", err + } + if numericAsStringEnabled { + return "String", nil + } + } else if q.Precision > datatypes.PeerDBClickHouseMaxPrecision { + return "String", nil + } + return fmt.Sprintf("Decimal(%d, %d)", q.Precision, q.Scale), nil +} + +// SEE ALSO: qvalue/kind.go ToDWHColumnType +func (q QField) ToDWHColumnType(ctx context.Context, env map[string]string, dwhType protos.DBType) (string, error) { + switch dwhType { + case protos.DBType_SNOWFLAKE: + if val, ok := QValueKindToSnowflakeTypeMap[q.Type]; ok { + return val, nil + } else if q.Type == QValueKindNumeric { + return fmt.Sprintf("NUMERIC(%d,%d)", q.Precision, q.Scale), nil + } else { + return "STRING", nil + } + case protos.DBType_CLICKHOUSE: + if val, ok := QValueKindToClickHouseTypeMap[q.Type]; ok { + return q.getClickHouseTypeForNumericField(ctx, env) + } else if q.Type == QValueKindNumeric { + return val, nil + } else { + return "String", nil + } + default: + return "", fmt.Errorf("unknown dwh type: %v", dwhType) + } +}