diff --git a/flow/activities/flowable_core.go b/flow/activities/flowable_core.go index 2d1f7e1f3e..a9f58a5f3f 100644 --- a/flow/activities/flowable_core.go +++ b/flow/activities/flowable_core.go @@ -225,7 +225,7 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon } defer connectors.CloseConnector(ctx, dstConn) - if err := dstConn.ReplayTableSchemaDeltas(ctx, flowName, recordBatchSync.SchemaDeltas); err != nil { + if err := dstConn.ReplayTableSchemaDeltas(ctx, config.Env, flowName, recordBatchSync.SchemaDeltas); err != nil { return nil, fmt.Errorf("failed to sync schema: %w", err) } @@ -440,6 +440,7 @@ func replicateQRepPartition[TRead any, TWrite any, TSync connectors.QRepSyncConn }) errGroup.Go(func() error { + var err error rowsSynced, err = syncRecords(dstConn, errCtx, config, partition, outstream) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index f990b2f19d..d6504322ca 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -203,6 +203,7 @@ func (c *BigQueryConnector) waitForTableReady(ctx context.Context, datasetTable // This could involve adding or dropping multiple columns. func (c *BigQueryConnector) ReplayTableSchemaDeltas( ctx context.Context, + env map[string]string, flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { diff --git a/flow/connectors/bigquery/qrep.go b/flow/connectors/bigquery/qrep.go index 3da50c8e8f..b184cc62a9 100644 --- a/flow/connectors/bigquery/qrep.go +++ b/flow/connectors/bigquery/qrep.go @@ -35,7 +35,7 @@ func (c *BigQueryConnector) SyncQRepRecords( partition.PartitionId, destTable)) avroSync := NewQRepAvroSyncMethod(c, config.StagingPath, config.FlowJobName) - return avroSync.SyncQRepRecords(ctx, config.FlowJobName, destTable, partition, + return avroSync.SyncQRepRecords(ctx, config.Env, config.FlowJobName, destTable, partition, tblMetadata, stream, config.SyncedAtColName, config.SoftDeleteColName) } @@ -80,7 +80,7 @@ func (c *BigQueryConnector) replayTableSchemaDeltasQRep( } } - err = c.ReplayTableSchemaDeltas(ctx, config.FlowJobName, []*protos.TableSchemaDelta{tableSchemaDelta}) + err = c.ReplayTableSchemaDeltas(ctx, config.Env, config.FlowJobName, []*protos.TableSchemaDelta{tableSchemaDelta}) if err != nil { return nil, fmt.Errorf("failed to add columns to destination table: %w", err) } diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index da3b15c37f..07285eb997 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -55,7 +55,7 @@ func (s *QRepAvroSyncMethod) SyncRecords( } stagingTable := fmt.Sprintf("%s_%s_staging", rawTableName, strconv.FormatInt(syncBatchID, 10)) - numRecords, err := s.writeToStage(ctx, strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema, + numRecords, err := s.writeToStage(ctx, req.Env, strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema, &datasetTable{ project: s.connector.projectID, dataset: s.connector.datasetID, @@ -97,7 +97,7 @@ func (s *QRepAvroSyncMethod) SyncRecords( slog.String(string(shared.FlowNameKey), req.FlowJobName), slog.String("dstTableName", rawTableName)) - err = s.connector.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas) + err = s.connector.ReplayTableSchemaDeltas(ctx, req.Env, req.FlowJobName, req.Records.SchemaDeltas) if err != nil { return nil, fmt.Errorf("failed 
to sync schema changes: %w", err) } @@ -139,6 +139,7 @@ func getTransformedColumns(dstSchema *bigquery.Schema, syncedAtCol string, softD func (s *QRepAvroSyncMethod) SyncQRepRecords( ctx context.Context, + env map[string]string, flowJobName string, dstTableName string, partition *protos.QRepPartition, @@ -167,7 +168,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( table: fmt.Sprintf("%s_%s_staging", dstDatasetTable.table, strings.ReplaceAll(partition.PartitionId, "-", "_")), } - numRecords, err := s.writeToStage(ctx, partition.PartitionId, flowJobName, avroSchema, + numRecords, err := s.writeToStage(ctx, env, partition.PartitionId, flowJobName, avroSchema, stagingDatasetTable, stream, flowJobName) if err != nil { return -1, fmt.Errorf("failed to push to avro stage: %w", err) @@ -389,6 +390,7 @@ func GetAvroField(bqField *bigquery.FieldSchema) (AvroField, error) { func (s *QRepAvroSyncMethod) writeToStage( ctx context.Context, + env map[string]string, syncID string, objectFolder string, avroSchema *model.QRecordAvroSchemaDefinition, @@ -408,7 +410,7 @@ func (s *QRepAvroSyncMethod) writeToStage( obj := bucket.Object(avroFilePath) w := obj.NewWriter(ctx) - numRecords, err := ocfWriter.WriteOCF(ctx, w) + numRecords, err := ocfWriter.WriteOCF(ctx, env, w) if err != nil { return 0, fmt.Errorf("failed to write records to Avro file on GCS: %w", err) } @@ -426,7 +428,7 @@ func (s *QRepAvroSyncMethod) writeToStage( avroFilePath := fmt.Sprintf("%s/%s.avro", tmpDir, syncID) s.connector.logger.Info("writing records to local file", idLog) - avroFile, err = ocfWriter.WriteRecordsToAvroFile(ctx, avroFilePath) + avroFile, err = ocfWriter.WriteRecordsToAvroFile(ctx, env, avroFilePath) if err != nil { return 0, fmt.Errorf("failed to write records to local Avro file: %w", err) } diff --git a/flow/connectors/clickhouse/cdc.go b/flow/connectors/clickhouse/cdc.go index d3eb883b46..5dc8a14628 100644 --- a/flow/connectors/clickhouse/cdc.go +++ b/flow/connectors/clickhouse/cdc.go @@ -93,7 +93,7 @@ func (c *ClickHouseConnector) syncRecordsViaAvro( return nil, err } - if err := c.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas); err != nil { + if err := c.ReplayTableSchemaDeltas(ctx, req.Env, req.FlowJobName, req.Records.SchemaDeltas); err != nil { return nil, fmt.Errorf("failed to sync schema changes: %w", err) } @@ -120,7 +120,10 @@ func (c *ClickHouseConnector) SyncRecords(ctx context.Context, req *model.SyncRe return res, nil } -func (c *ClickHouseConnector) ReplayTableSchemaDeltas(ctx context.Context, flowJobName string, +func (c *ClickHouseConnector) ReplayTableSchemaDeltas( + ctx context.Context, + env map[string]string, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { if len(schemaDeltas) == 0 { @@ -133,7 +136,7 @@ func (c *ClickHouseConnector) ReplayTableSchemaDeltas(ctx context.Context, flowJ } for _, addedColumn := range schemaDelta.AddedColumns { - clickHouseColType, err := qvalue.QValueKind(addedColumn.Type).ToDWHColumnType(protos.DBType_CLICKHOUSE) + clickHouseColType, err := qvalue.QValueKind(addedColumn.Type).ToDWHColumnType(ctx, env, protos.DBType_CLICKHOUSE, addedColumn) if err != nil { return fmt.Errorf("failed to convert column type %s to ClickHouse type: %w", addedColumn.Type, err) } diff --git a/flow/connectors/clickhouse/normalize.go b/flow/connectors/clickhouse/normalize.go index 2debe0f4d5..fabe07a35f 100644 --- a/flow/connectors/clickhouse/normalize.go +++ b/flow/connectors/clickhouse/normalize.go @@ -15,7 +15,6 @@ import ( 
"github.com/ClickHouse/clickhouse-go/v2" "golang.org/x/sync/errgroup" - "github.com/PeerDB-io/peer-flow/datatypes" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -81,16 +80,6 @@ func getColName(overrides map[string]string, name string) string { return name } -func getClickhouseTypeForNumericColumn(column *protos.FieldDescription) string { - rawPrecision, _ := datatypes.ParseNumericTypmod(column.TypeModifier) - if rawPrecision > datatypes.PeerDBClickHouseMaxPrecision { - return "String" - } else { - precision, scale := datatypes.GetNumericTypeForWarehouse(column.TypeModifier, datatypes.ClickHouseNumericCompatibility{}) - return fmt.Sprintf("Decimal(%d, %d)", precision, scale) - } -} - func generateCreateTableSQLForNormalizedTable( ctx context.Context, config *protos.SetupNormalizedTableBatchInput, @@ -142,14 +131,10 @@ func generateCreateTableSQLForNormalizedTable( } if clickHouseType == "" { - if colType == qvalue.QValueKindNumeric { - clickHouseType = getClickhouseTypeForNumericColumn(column) - } else { - var err error - clickHouseType, err = colType.ToDWHColumnType(protos.DBType_CLICKHOUSE) - if err != nil { - return "", fmt.Errorf("error while converting column type to ClickHouse type: %w", err) - } + var err error + clickHouseType, err = colType.ToDWHColumnType(ctx, config.Env, protos.DBType_CLICKHOUSE, column) + if err != nil { + return "", fmt.Errorf("error while converting column type to ClickHouse type: %w", err) } } if (tableSchema.NullableEnabled || columnNullableEnabled) && column.Nullable && !colType.IsArray() { @@ -368,16 +353,13 @@ func (c *ClickHouseConnector) NormalizeRecords( colSelector.WriteString(fmt.Sprintf("`%s`,", dstColName)) if clickHouseType == "" { - if colType == qvalue.QValueKindNumeric { - clickHouseType = getClickhouseTypeForNumericColumn(column) - } else { - var err error - clickHouseType, err = colType.ToDWHColumnType(protos.DBType_CLICKHOUSE) - if err != nil { - close(queries) - return nil, fmt.Errorf("error while converting column type to clickhouse type: %w", err) - } + var err error + clickHouseType, err = colType.ToDWHColumnType(ctx, req.Env, protos.DBType_CLICKHOUSE, column) + if err != nil { + close(queries) + return nil, fmt.Errorf("error while converting column type to clickhouse type: %w", err) } + if (schema.NullableEnabled || columnNullableEnabled) && column.Nullable && !colType.IsArray() { clickHouseType = fmt.Sprintf("Nullable(%s)", clickHouseType) } diff --git a/flow/connectors/clickhouse/qrep_avro_sync.go b/flow/connectors/clickhouse/qrep_avro_sync.go index fa2cfe1034..61450dd55c 100644 --- a/flow/connectors/clickhouse/qrep_avro_sync.go +++ b/flow/connectors/clickhouse/qrep_avro_sync.go @@ -71,7 +71,7 @@ func (s *ClickHouseAvroSyncMethod) SyncRecords( s.logger.Info("sync function called and schema acquired", slog.String("dstTable", dstTableName)) - avroSchema, err := s.getAvroSchema(dstTableName, schema) + avroSchema, err := s.getAvroSchema(ctx, env, dstTableName, schema) if err != nil { return 0, err } @@ -106,7 +106,7 @@ func (s *ClickHouseAvroSyncMethod) SyncQRepRecords( stagingPath := s.credsProvider.BucketPath startTime := time.Now() - avroSchema, err := s.getAvroSchema(dstTableName, stream.Schema()) + avroSchema, err := s.getAvroSchema(ctx, config.Env, dstTableName, stream.Schema()) if err != nil { return 0, err } @@ -165,10 +165,12 @@ func (s *ClickHouseAvroSyncMethod) SyncQRepRecords( } func (s *ClickHouseAvroSyncMethod) getAvroSchema( + ctx 
context.Context, + env map[string]string, dstTableName string, schema qvalue.QRecordSchema, ) (*model.QRecordAvroSchemaDefinition, error) { - avroSchema, err := model.GetAvroSchemaDefinition(dstTableName, schema, protos.DBType_CLICKHOUSE) + avroSchema, err := model.GetAvroSchemaDefinition(ctx, env, dstTableName, schema, protos.DBType_CLICKHOUSE) if err != nil { return nil, fmt.Errorf("failed to define Avro schema: %w", err) } diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 073d9d82b4..0991a50978 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -173,7 +173,7 @@ type CDCSyncConnectorCore interface { // ReplayTableSchemaDelta changes a destination table to match the schema at source // This could involve adding or dropping multiple columns. // Connectors which are non-normalizing should implement this as a nop. - ReplayTableSchemaDeltas(ctx context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error + ReplayTableSchemaDeltas(ctx context.Context, env map[string]string, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error } type CDCSyncConnector interface { @@ -463,8 +463,6 @@ var ( _ CDCSyncConnector = &connclickhouse.ClickHouseConnector{} _ CDCSyncConnector = &connelasticsearch.ElasticsearchConnector{} - _ CDCSyncPgConnector = &connpostgres.PostgresConnector{} - _ CDCNormalizeConnector = &connpostgres.PostgresConnector{} _ CDCNormalizeConnector = &connbigquery.BigQueryConnector{} _ CDCNormalizeConnector = &connsnowflake.SnowflakeConnector{} diff --git a/flow/connectors/elasticsearch/elasticsearch.go b/flow/connectors/elasticsearch/elasticsearch.go index e675168051..30279fd74e 100644 --- a/flow/connectors/elasticsearch/elasticsearch.go +++ b/flow/connectors/elasticsearch/elasticsearch.go @@ -95,7 +95,7 @@ func (esc *ElasticsearchConnector) CreateRawTable(ctx context.Context, } // we handle schema changes by not handling them since no mapping is being enforced right now -func (esc *ElasticsearchConnector) ReplayTableSchemaDeltas(ctx context.Context, +func (esc *ElasticsearchConnector) ReplayTableSchemaDeltas(ctx context.Context, env map[string]string, flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { return nil diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index 01982bf713..0f175233ef 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -380,7 +380,9 @@ func (c *EventHubConnector) CreateRawTable(ctx context.Context, req *protos.Crea }, nil } -func (c *EventHubConnector) ReplayTableSchemaDeltas(_ context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { +func (c *EventHubConnector) ReplayTableSchemaDeltas(_ context.Context, _ map[string]string, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, +) error { c.logger.Info("ReplayTableSchemaDeltas for event hub is a no-op") return nil } diff --git a/flow/connectors/kafka/kafka.go b/flow/connectors/kafka/kafka.go index ea0805b84b..ee78093fe6 100644 --- a/flow/connectors/kafka/kafka.go +++ b/flow/connectors/kafka/kafka.go @@ -149,7 +149,9 @@ func (c *KafkaConnector) SetupMetadataTables(_ context.Context) error { return nil } -func (c *KafkaConnector) ReplayTableSchemaDeltas(_ context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { +func (c *KafkaConnector) ReplayTableSchemaDeltas(_ context.Context, _ map[string]string, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, +) error { return 
nil } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 14b827cc89..8f49545fff 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -592,7 +592,7 @@ func syncRecordsCore[Items model.Items]( return nil, err } - err = c.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas) + err = c.ReplayTableSchemaDeltas(ctx, req.Env, req.FlowJobName, req.Records.SchemaDeltas) if err != nil { return nil, fmt.Errorf("failed to sync schema changes: %w", err) } @@ -941,6 +941,7 @@ func (c *PostgresConnector) SetupNormalizedTable( // This could involve adding or dropping multiple columns. func (c *PostgresConnector) ReplayTableSchemaDeltas( ctx context.Context, + _ map[string]string, flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { diff --git a/flow/connectors/postgres/postgres_schema_delta_test.go b/flow/connectors/postgres/postgres_schema_delta_test.go index 946b20eb3e..0b6668a5a2 100644 --- a/flow/connectors/postgres/postgres_schema_delta_test.go +++ b/flow/connectors/postgres/postgres_schema_delta_test.go @@ -58,7 +58,7 @@ func (s PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() { fmt.Sprintf("CREATE TABLE %s(id INT PRIMARY KEY)", tableName)) require.NoError(s.t, err) - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: []*protos.FieldDescription{ @@ -113,7 +113,7 @@ func (s PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, @@ -144,7 +144,7 @@ func (s PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, @@ -175,7 +175,7 @@ func (s PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, diff --git a/flow/connectors/pubsub/pubsub.go b/flow/connectors/pubsub/pubsub.go index 49aed379c4..537cda7241 100644 --- a/flow/connectors/pubsub/pubsub.go +++ b/flow/connectors/pubsub/pubsub.go @@ -67,7 +67,9 @@ func (c *PubSubConnector) CreateRawTable(ctx context.Context, req *protos.Create return &protos.CreateRawTableOutput{TableIdentifier: "n/a"}, nil } -func (c *PubSubConnector) ReplayTableSchemaDeltas(_ context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { +func (c *PubSubConnector) ReplayTableSchemaDeltas(_ context.Context, _ map[string]string, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, +) error { return nil } diff --git 
a/flow/connectors/s3/qrep.go b/flow/connectors/s3/qrep.go index 9fbb485ab8..968c956aab 100644 --- a/flow/connectors/s3/qrep.go +++ b/flow/connectors/s3/qrep.go @@ -20,7 +20,7 @@ func (c *S3Connector) SyncQRepRecords( schema := stream.Schema() dstTableName := config.DestinationTableIdentifier - avroSchema, err := getAvroSchema(dstTableName, schema) + avroSchema, err := getAvroSchema(ctx, config.Env, dstTableName, schema) if err != nil { return 0, err } @@ -34,10 +34,12 @@ func (c *S3Connector) SyncQRepRecords( } func getAvroSchema( + ctx context.Context, + env map[string]string, dstTableName string, schema qvalue.QRecordSchema, ) (*model.QRecordAvroSchemaDefinition, error) { - avroSchema, err := model.GetAvroSchemaDefinition(dstTableName, schema, protos.DBType_S3) + avroSchema, err := model.GetAvroSchemaDefinition(ctx, env, dstTableName, schema, protos.DBType_S3) if err != nil { return nil, fmt.Errorf("failed to define Avro schema: %w", err) } diff --git a/flow/connectors/s3/s3.go b/flow/connectors/s3/s3.go index eac37cd7c8..7d16a20af0 100644 --- a/flow/connectors/s3/s3.go +++ b/flow/connectors/s3/s3.go @@ -118,7 +118,9 @@ func (c *S3Connector) SyncRecords(ctx context.Context, req *model.SyncRecordsReq }, nil } -func (c *S3Connector) ReplayTableSchemaDeltas(_ context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { +func (c *S3Connector) ReplayTableSchemaDeltas(_ context.Context, _ map[string]string, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, +) error { c.logger.Info("ReplayTableSchemaDeltas for S3 is a no-op") return nil } diff --git a/flow/connectors/snowflake/avro_file_writer_test.go b/flow/connectors/snowflake/avro_file_writer_test.go index ac6f253517..4a76fccd01 100644 --- a/flow/connectors/snowflake/avro_file_writer_test.go +++ b/flow/connectors/snowflake/avro_file_writer_test.go @@ -144,14 +144,14 @@ func TestWriteRecordsToAvroFileHappyPath(t *testing.T) { // Define sample data records, schema := generateRecords(t, true, 10, false) - avroSchema, err := model.GetAvroSchemaDefinition("not_applicable", schema, protos.DBType_SNOWFLAKE) + avroSchema, err := model.GetAvroSchemaDefinition(context.Background(), nil, "not_applicable", schema, protos.DBType_SNOWFLAKE) require.NoError(t, err) t.Logf("[test] avroSchema: %v", avroSchema) // Call function writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE) - _, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name()) + _, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name()) require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors") // Check file is not empty @@ -171,14 +171,14 @@ func TestWriteRecordsToZstdAvroFileHappyPath(t *testing.T) { // Define sample data records, schema := generateRecords(t, true, 10, false) - avroSchema, err := model.GetAvroSchemaDefinition("not_applicable", schema, protos.DBType_SNOWFLAKE) + avroSchema, err := model.GetAvroSchemaDefinition(context.Background(), nil, "not_applicable", schema, protos.DBType_SNOWFLAKE) require.NoError(t, err) t.Logf("[test] avroSchema: %v", avroSchema) // Call function writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressZstd, protos.DBType_SNOWFLAKE) - _, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name()) + _, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name()) require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors") // Check file is 
not empty @@ -198,14 +198,14 @@ func TestWriteRecordsToDeflateAvroFileHappyPath(t *testing.T) { // Define sample data records, schema := generateRecords(t, true, 10, false) - avroSchema, err := model.GetAvroSchemaDefinition("not_applicable", schema, protos.DBType_SNOWFLAKE) + avroSchema, err := model.GetAvroSchemaDefinition(context.Background(), nil, "not_applicable", schema, protos.DBType_SNOWFLAKE) require.NoError(t, err) t.Logf("[test] avroSchema: %v", avroSchema) // Call function writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressDeflate, protos.DBType_SNOWFLAKE) - _, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name()) + _, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name()) require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors") // Check file is not empty @@ -224,14 +224,14 @@ func TestWriteRecordsToAvroFileNonNull(t *testing.T) { records, schema := generateRecords(t, false, 10, false) - avroSchema, err := model.GetAvroSchemaDefinition("not_applicable", schema, protos.DBType_SNOWFLAKE) + avroSchema, err := model.GetAvroSchemaDefinition(context.Background(), nil, "not_applicable", schema, protos.DBType_SNOWFLAKE) require.NoError(t, err) t.Logf("[test] avroSchema: %v", avroSchema) // Call function writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE) - _, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name()) + _, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name()) require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors") // Check file is not empty @@ -251,14 +251,14 @@ func TestWriteRecordsToAvroFileAllNulls(t *testing.T) { // Define sample data records, schema := generateRecords(t, true, 10, true) - avroSchema, err := model.GetAvroSchemaDefinition("not_applicable", schema, protos.DBType_SNOWFLAKE) + avroSchema, err := model.GetAvroSchemaDefinition(context.Background(), nil, "not_applicable", schema, protos.DBType_SNOWFLAKE) require.NoError(t, err) t.Logf("[test] avroSchema: %v", avroSchema) // Call function writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE) - _, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name()) + _, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name()) require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors") // Check file is not empty diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go index 37b4ed7bdb..d87d3004f7 100644 --- a/flow/connectors/snowflake/merge_stmt_generator.go +++ b/flow/connectors/snowflake/merge_stmt_generator.go @@ -1,6 +1,7 @@ package connsnowflake import ( + "context" "fmt" "strings" @@ -24,7 +25,7 @@ type mergeStmtGenerator struct { mergeBatchId int64 } -func (m *mergeStmtGenerator) generateMergeStmt(dstTable string) (string, error) { +func (m *mergeStmtGenerator) generateMergeStmt(ctx context.Context, env map[string]string, dstTable string) (string, error) { parsedDstTable, _ := utils.ParseSchemaTable(dstTable) normalizedTableSchema := m.tableSchemaMapping[dstTable] unchangedToastColumns := m.unchangedToastColumnsMap[dstTable] @@ -34,7 +35,7 @@ func (m *mergeStmtGenerator) generateMergeStmt(dstTable string) (string, error) for _, column := range columns { genericColumnType := column.Type qvKind := qvalue.QValueKind(genericColumnType) - sfType, err 
:= qvKind.ToDWHColumnType(protos.DBType_SNOWFLAKE) + sfType, err := qvKind.ToDWHColumnType(ctx, env, protos.DBType_SNOWFLAKE, column) if err != nil { return "", fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err) } diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go index 0fea54b027..728d393e62 100644 --- a/flow/connectors/snowflake/qrep_avro_sync.go +++ b/flow/connectors/snowflake/qrep_avro_sync.go @@ -48,7 +48,7 @@ func (s *SnowflakeAvroSyncHandler) SyncRecords( s.logger.Info("sync function called and schema acquired", tableLog) - avroSchema, err := s.getAvroSchema(dstTableName, schema) + avroSchema, err := s.getAvroSchema(ctx, env, dstTableName, schema) if err != nil { return 0, err } @@ -98,12 +98,12 @@ func (s *SnowflakeAvroSyncHandler) SyncQRepRecords( schema := stream.Schema() s.logger.Info("sync function called and schema acquired", partitionLog) - err := s.addMissingColumns(ctx, schema, dstTableSchema, dstTableName, partition) + err := s.addMissingColumns(ctx, config.Env, schema, dstTableSchema, dstTableName, partition) if err != nil { return 0, err } - avroSchema, err := s.getAvroSchema(dstTableName, schema) + avroSchema, err := s.getAvroSchema(ctx, config.Env, dstTableName, schema) if err != nil { return 0, err } @@ -130,6 +130,7 @@ func (s *SnowflakeAvroSyncHandler) SyncQRepRecords( func (s *SnowflakeAvroSyncHandler) addMissingColumns( ctx context.Context, + env map[string]string, schema qvalue.QRecordSchema, dstTableSchema []*sql.ColumnType, dstTableName string, @@ -138,7 +139,7 @@ func (s *SnowflakeAvroSyncHandler) addMissingColumns( partitionLog := slog.String(string(shared.PartitionIDKey), partition.PartitionId) // check if avro schema has additional columns compared to destination table // if so, we need to add those columns to the destination table - colsToTypes := map[string]qvalue.QValueKind{} + var newColumns []qvalue.QField for _, col := range schema.Fields { hasColumn := false // check ignoring case @@ -152,24 +153,23 @@ func (s *SnowflakeAvroSyncHandler) addMissingColumns( if !hasColumn { s.logger.Info(fmt.Sprintf("adding column %s to destination table %s", col.Name, dstTableName), partitionLog) - colsToTypes[col.Name] = col.Type + newColumns = append(newColumns, col) } } - if len(colsToTypes) > 0 { + if len(newColumns) > 0 { tx, err := s.database.Begin() if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } - for colName, colType := range colsToTypes { - sfColType, err := colType.ToDWHColumnType(protos.DBType_SNOWFLAKE) + for _, column := range newColumns { + sfColType, err := column.ToDWHColumnType(ctx, env, protos.DBType_SNOWFLAKE) if err != nil { return fmt.Errorf("failed to convert QValueKind to Snowflake column type: %w", err) } - upperCasedColName := strings.ToUpper(colName) - alterTableCmd := fmt.Sprintf("ALTER TABLE %s ", dstTableName) - alterTableCmd += fmt.Sprintf("ADD COLUMN IF NOT EXISTS \"%s\" %s;", upperCasedColName, sfColType) + upperCasedColName := strings.ToUpper(column.Name) + alterTableCmd := fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s;", dstTableName, upperCasedColName, sfColType) s.logger.Info(fmt.Sprintf("altering destination table %s with command `%s`", dstTableName, alterTableCmd), partitionLog) @@ -193,10 +193,12 @@ func (s *SnowflakeAvroSyncHandler) addMissingColumns( } func (s *SnowflakeAvroSyncHandler) getAvroSchema( + ctx context.Context, + env map[string]string, dstTableName string, schema 
qvalue.QRecordSchema, ) (*model.QRecordAvroSchemaDefinition, error) { - avroSchema, err := model.GetAvroSchemaDefinition(dstTableName, schema, protos.DBType_SNOWFLAKE) + avroSchema, err := model.GetAvroSchemaDefinition(ctx, env, dstTableName, schema, protos.DBType_SNOWFLAKE) if err != nil { return nil, fmt.Errorf("failed to define Avro schema: %w", err) } @@ -223,7 +225,7 @@ func (s *SnowflakeAvroSyncHandler) writeToAvroFile( localFilePath := fmt.Sprintf("%s/%s.avro.zst", tmpDir, partitionID) s.logger.Info("writing records to local file " + localFilePath) - avroFile, err := ocfWriter.WriteRecordsToAvroFile(ctx, localFilePath) + avroFile, err := ocfWriter.WriteRecordsToAvroFile(ctx, env, localFilePath) if err != nil { return nil, fmt.Errorf("failed to write records to Avro file: %w", err) } diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 06e3fb881e..518b01ff2b 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -19,7 +19,6 @@ import ( metadataStore "github.com/PeerDB-io/peer-flow/connectors/external_metadata" "github.com/PeerDB-io/peer-flow/connectors/utils" - numeric "github.com/PeerDB-io/peer-flow/datatypes" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -338,7 +337,7 @@ func (c *SnowflakeConnector) SetupNormalizedTable( return true, nil } - normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(config, normalizedSchemaTable, tableSchema) + normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable(ctx, config, normalizedSchemaTable, tableSchema) if _, err := c.execWithLogging(ctx, normalizedTableCreateSQL); err != nil { return false, fmt.Errorf("[sf] error while creating normalized table: %w", err) } @@ -349,6 +348,7 @@ func (c *SnowflakeConnector) SetupNormalizedTable( // This could involve adding or dropping multiple columns. 
func (c *SnowflakeConnector) ReplayTableSchemaDeltas( ctx context.Context, + env map[string]string, flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { @@ -374,17 +374,12 @@ func (c *SnowflakeConnector) ReplayTableSchemaDeltas( } for _, addedColumn := range schemaDelta.AddedColumns { - sfColtype, err := qvalue.QValueKind(addedColumn.Type).ToDWHColumnType(protos.DBType_SNOWFLAKE) + sfColtype, err := qvalue.QValueKind(addedColumn.Type).ToDWHColumnType(ctx, env, protos.DBType_SNOWFLAKE, addedColumn) if err != nil { return fmt.Errorf("failed to convert column type %s to snowflake type: %w", addedColumn.Type, err) } - if addedColumn.Type == string(qvalue.QValueKindNumeric) { - precision, scale := numeric.GetNumericTypeForWarehouse(addedColumn.TypeModifier, numeric.SnowflakeNumericCompatibility{}) - sfColtype = fmt.Sprintf("NUMERIC(%d,%d)", precision, scale) - } - _, err = tableSchemaModifyTx.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s", schemaDelta.DstTableName, strings.ToUpper(addedColumn.Name), sfColtype)) @@ -460,7 +455,7 @@ func (c *SnowflakeConnector) syncRecordsViaAvro( return nil, err } - err = c.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas) + err = c.ReplayTableSchemaDeltas(ctx, req.Env, req.FlowJobName, req.Records.SchemaDeltas) if err != nil { return nil, fmt.Errorf("failed to sync schema changes: %w", err) } @@ -557,7 +552,7 @@ func (c *SnowflakeConnector) mergeTablesForBatch( } g.Go(func() error { - mergeStatement, err := mergeGen.generateMergeStmt(tableName) + mergeStatement, err := mergeGen.generateMergeStmt(gCtx, env, tableName) if err != nil { return err } @@ -666,6 +661,7 @@ func (c *SnowflakeConnector) checkIfTableExists( } func generateCreateTableSQLForNormalizedTable( + ctx context.Context, config *protos.SetupNormalizedTableBatchInput, dstSchemaTable *utils.SchemaTable, tableSchema *protos.TableSchema, @@ -674,18 +670,13 @@ func generateCreateTableSQLForNormalizedTable( for _, column := range tableSchema.Columns { genericColumnType := column.Type normalizedColName := SnowflakeIdentifierNormalize(column.Name) - sfColType, err := qvalue.QValueKind(genericColumnType).ToDWHColumnType(protos.DBType_SNOWFLAKE) + sfColType, err := qvalue.QValueKind(genericColumnType).ToDWHColumnType(ctx, config.Env, protos.DBType_SNOWFLAKE, column) if err != nil { slog.Warn(fmt.Sprintf("failed to convert column type %s to snowflake type", genericColumnType), slog.Any("error", err)) continue } - if genericColumnType == "numeric" { - precision, scale := numeric.GetNumericTypeForWarehouse(column.TypeModifier, numeric.SnowflakeNumericCompatibility{}) - sfColType = fmt.Sprintf("NUMERIC(%d,%d)", precision, scale) - } - var notNull string if tableSchema.NullableEnabled && !column.Nullable { notNull = " NOT NULL" diff --git a/flow/connectors/utils/avro/avro_writer.go b/flow/connectors/utils/avro/avro_writer.go index ee72e2c28b..75bc9f4358 100644 --- a/flow/connectors/utils/avro/avro_writer.go +++ b/flow/connectors/utils/avro/avro_writer.go @@ -127,16 +127,21 @@ func (p *peerDBOCFWriter) createOCFWriter(w io.Writer) (*goavro.OCFWriter, error return ocfWriter, nil } -func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter *goavro.OCFWriter) (int64, error) { +func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, env map[string]string, ocfWriter *goavro.OCFWriter) (int64, error) { logger := shared.LoggerFromCtx(ctx) schema := p.stream.Schema() - avroConverter := 
model.NewQRecordAvroConverter( + avroConverter, err := model.NewQRecordAvroConverter( + ctx, + env, p.avroSchema, p.targetDWH, schema.GetColumnNames(), logger, ) + if err != nil { + return 0, err + } numRows := atomic.Int64{} @@ -147,7 +152,7 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter for qrecord := range p.stream.Records { if err := ctx.Err(); err != nil { - return numRows.Load(), ctx.Err() + return numRows.Load(), err } else { avroMap, err := avroConverter.Convert(qrecord) if err != nil { @@ -172,7 +177,7 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter return numRows.Load(), nil } -func (p *peerDBOCFWriter) WriteOCF(ctx context.Context, w io.Writer) (int, error) { +func (p *peerDBOCFWriter) WriteOCF(ctx context.Context, env map[string]string, w io.Writer) (int, error) { ocfWriter, err := p.createOCFWriter(w) if err != nil { return 0, fmt.Errorf("failed to create OCF writer: %w", err) @@ -180,7 +185,7 @@ func (p *peerDBOCFWriter) WriteOCF(ctx context.Context, w io.Writer) (int, error // we have to keep a reference to the underlying writer as goavro doesn't provide any access to it defer p.writer.Close() - numRows, err := p.writeRecordsToOCFWriter(ctx, ocfWriter) + numRows, err := p.writeRecordsToOCFWriter(ctx, env, ocfWriter) if err != nil { return 0, fmt.Errorf("failed to write records to OCF writer: %w", err) } @@ -217,7 +222,7 @@ func (p *peerDBOCFWriter) WriteRecordsToS3( } w.Close() }() - numRows, writeOcfError = p.WriteOCF(ctx, w) + numRows, writeOcfError = p.WriteOCF(ctx, env, w) }() partSize, err := peerdbenv.PeerDBS3PartSize(ctx, env) @@ -254,7 +259,7 @@ func (p *peerDBOCFWriter) WriteRecordsToS3( }, nil } -func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, filePath string) (*AvroFile, error) { +func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, env map[string]string, filePath string) (*AvroFile, error) { file, err := os.Create(filePath) if err != nil { return nil, fmt.Errorf("failed to create temporary Avro file: %w", err) @@ -275,7 +280,7 @@ func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, filePath s bufferedWriter := bufio.NewWriterSize(file, buffSizeBytes) defer bufferedWriter.Flush() - numRecords, err := p.WriteOCF(ctx, bufferedWriter) + numRecords, err := p.WriteOCF(ctx, env, bufferedWriter) if err != nil { return nil, fmt.Errorf("failed to write records to temporary Avro file: %w", err) } diff --git a/flow/datatypes/numeric.go b/flow/datatypes/numeric.go index 56c1b17839..8b942e4f67 100644 --- a/flow/datatypes/numeric.go +++ b/flow/datatypes/numeric.go @@ -90,6 +90,10 @@ func MakeNumericTypmod(precision int32, scale int32) int32 { // This is to reverse what make_numeric_typmod of Postgres does: // https://github.com/postgres/postgres/blob/21912e3c0262e2cfe64856e028799d6927862563/src/backend/utils/adt/numeric.c#L897 func ParseNumericTypmod(typmod int32) (int16, int16) { + if typmod == -1 { + return 0, 0 + } + offsetMod := typmod - VARHDRSZ precision := int16((offsetMod >> 16) & 0x7FFF) scale := int16(offsetMod & 0x7FFF) @@ -102,6 +106,14 @@ func GetNumericTypeForWarehouse(typmod int32, warehouseNumeric WarehouseNumericC } precision, scale := ParseNumericTypmod(typmod) + return GetNumericTypeForWarehousePrecisionScale(precision, scale, warehouseNumeric) +} + +func GetNumericTypeForWarehousePrecisionScale(precision int16, scale int16, warehouseNumeric WarehouseNumericCompatibility) (int16, int16) { + if precision == 0 && scale == 0 { + return 
warehouseNumeric.DefaultPrecisionAndScale() + } + if !IsValidPrecision(precision, warehouseNumeric) { precision = warehouseNumeric.MaxPrecision() } diff --git a/flow/e2e/clickhouse/peer_flow_ch_test.go b/flow/e2e/clickhouse/peer_flow_ch_test.go index 9c4fa2a167..a19e69c8c7 100644 --- a/flow/e2e/clickhouse/peer_flow_ch_test.go +++ b/flow/e2e/clickhouse/peer_flow_ch_test.go @@ -4,6 +4,7 @@ import ( "context" "embed" "fmt" + "strconv" "strings" "testing" "time" @@ -11,7 +12,7 @@ import ( "github.com/shopspring/decimal" "github.com/stretchr/testify/require" - "github.com/PeerDB-io/peer-flow/connectors/clickhouse" + connclickhouse "github.com/PeerDB-io/peer-flow/connectors/clickhouse" "github.com/PeerDB-io/peer-flow/e2e" "github.com/PeerDB-io/peer-flow/e2eshared" "github.com/PeerDB-io/peer-flow/generated/protos" @@ -557,8 +558,8 @@ func (s ClickHouseSuite) Test_Large_Numeric() { `, srcFullName)) require.NoError(s.t, err) - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(c1,c2) VALUES(%s,%s);`, srcFullName, strings.Repeat("7", 76), strings.Repeat("9", 78))) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf("INSERT INTO %s(c1,c2) VALUES($1,$2)", srcFullName), + strings.Repeat("7", 76), strings.Repeat("9", 78)) require.NoError(s.t, err) connectionGen := e2e.FlowConnectionGenerationConfig{ @@ -568,14 +569,15 @@ func (s ClickHouseSuite) Test_Large_Numeric() { } flowConnConfig := connectionGen.GenerateFlowConnectionConfigs(s.t) flowConnConfig.DoInitialSnapshot = true + tc := e2e.NewTemporalClient(s.t) env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil) e2e.SetupCDCFlowStatusQuery(s.t, env, flowConnConfig) e2e.EnvWaitForCount(env, s, "waiting for CDC count", dstTableName, "id,c1,c2", 1) - _, err = s.Conn().Exec(context.Background(), fmt.Sprintf(` - INSERT INTO %s(c1,c2) VALUES(%s,%s);`, srcFullName, strings.Repeat("7", 76), strings.Repeat("9", 78))) + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf("INSERT INTO %s(c1,c2) VALUES($1,$2)", srcFullName), + strings.Repeat("7", 76), strings.Repeat("9", 78)) require.NoError(s.t, err) e2e.EnvWaitForCount(env, s, "waiting for CDC count", dstTableName, "id,c1,c2", 2) @@ -598,3 +600,67 @@ func (s ClickHouseSuite) Test_Large_Numeric() { env.Cancel() e2e.RequireEnvCanceled(s.t, env) } + +// Unbounded NUMERICs (no precision, scale specified) are mapped to String on CH if FF enabled, Decimal if not +func (s ClickHouseSuite) testNumericFF(ffValue bool) { + nines := strings.Repeat("9", 38) + dstTableName := fmt.Sprintf("unumeric_ff_%v", ffValue) + srcFullName := s.attachSchemaSuffix(dstTableName) + + _, err := s.Conn().Exec(context.Background(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s( + id INT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, + c numeric + ); + `, srcFullName)) + require.NoError(s.t, err) + + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf("INSERT INTO %s(c) VALUES($1)", srcFullName), nines) + require.NoError(s.t, err) + + connectionGen := e2e.FlowConnectionGenerationConfig{ + FlowJobName: s.attachSuffix(fmt.Sprintf("clickhouse_test_unbounded_numerics_ff_%v", ffValue)), + TableNameMapping: map[string]string{srcFullName: dstTableName}, + Destination: s.Peer().Name, + } + flowConnConfig := connectionGen.GenerateFlowConnectionConfigs(s.t) + flowConnConfig.DoInitialSnapshot = true + flowConnConfig.Env = map[string]string{"PEERDB_CLICKHOUSE_UNBOUNDED_NUMERIC_AS_STRING": strconv.FormatBool(ffValue)} + tc := e2e.NewTemporalClient(s.t) + env := e2e.ExecutePeerflow(tc, 
peerflow.CDCFlowWorkflow, flowConnConfig, nil) + e2e.SetupCDCFlowStatusQuery(s.t, env, flowConnConfig) + + e2e.EnvWaitForCount(env, s, "waiting for CDC count", dstTableName, "id,c", 1) + + _, err = s.Conn().Exec(context.Background(), fmt.Sprintf("INSERT INTO %s(c) VALUES($1)", srcFullName), nines) + require.NoError(s.t, err) + + e2e.EnvWaitForCount(env, s, "waiting for CDC count", dstTableName, "id,c", 2) + + rows, err := s.GetRows(dstTableName, "c") + require.NoError(s.t, err) + require.Len(s.t, rows.Records, 2, "expected 2 rows") + for _, row := range rows.Records { + require.Len(s.t, row, 1, "expected 1 column") + if ffValue { + c, ok := row[0].Value().(string) + require.True(s.t, ok, "expected unbounded NUMERIC to be String") + require.Equal(s.t, nines, c, "expected unbounded NUMERIC to be 9s") + } else { + c, ok := row[0].Value().(decimal.Decimal) + require.True(s.t, ok, "expected unbounded NUMERIC to be Decimal") + require.Equal(s.t, nines, c.String(), "expected unbounded NUMERIC to be 9s") + } + } + + env.Cancel() + e2e.RequireEnvCanceled(s.t, env) +} + +func (s ClickHouseSuite) Test_Unbounded_Numeric_With_FF() { + s.testNumericFF(true) +} + +func (s ClickHouseSuite) Test_Unbounded_Numeric_Without_FF() { + s.testNumericFF(false) +} diff --git a/flow/e2e/snowflake/snowflake_schema_delta_test.go b/flow/e2e/snowflake/snowflake_schema_delta_test.go index 32cb03b644..ada2b10f6a 100644 --- a/flow/e2e/snowflake/snowflake_schema_delta_test.go +++ b/flow/e2e/snowflake/snowflake_schema_delta_test.go @@ -53,7 +53,7 @@ func (s SnowflakeSchemaDeltaTestSuite) TestSimpleAddColumn() { err := s.sfTestHelper.RunCommand(fmt.Sprintf("CREATE TABLE %s(ID TEXT PRIMARY KEY)", tableName)) require.NoError(s.t, err) - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: []*protos.FieldDescription{ @@ -167,7 +167,7 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddAllColumnTypes() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, @@ -246,7 +246,7 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddTrickyColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, @@ -301,7 +301,7 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddWhitespaceColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), nil, "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, diff --git a/flow/model/conversion_avro.go b/flow/model/conversion_avro.go index 8f52c44611..ec7cfc6e37 100644 --- a/flow/model/conversion_avro.go +++ b/flow/model/conversion_avro.go @@ -1,6 +1,7 @@ package model import ( + "context" "encoding/json" "fmt" @@ -8,38 +9,52 @@ 
import ( "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model/qvalue" + "github.com/PeerDB-io/peer-flow/peerdbenv" ) type QRecordAvroConverter struct { - logger log.Logger - Schema *QRecordAvroSchemaDefinition - ColNames []string - TargetDWH protos.DBType + logger log.Logger + Schema *QRecordAvroSchemaDefinition + ColNames []string + TargetDWH protos.DBType + UnboundedNumericAsString bool } func NewQRecordAvroConverter( + ctx context.Context, + env map[string]string, schema *QRecordAvroSchemaDefinition, targetDWH protos.DBType, colNames []string, logger log.Logger, -) *QRecordAvroConverter { - return &QRecordAvroConverter{ - Schema: schema, - TargetDWH: targetDWH, - ColNames: colNames, - logger: logger, +) (*QRecordAvroConverter, error) { + var unboundedNumericAsString bool + if targetDWH == protos.DBType_CLICKHOUSE { + var err error + unboundedNumericAsString, err = peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env) + if err != nil { + return nil, err + } } -} -func (qac *QRecordAvroConverter) Convert(qrecord []qvalue.QValue) (map[string]interface{}, error) { - m := make(map[string]interface{}, len(qrecord)) + return &QRecordAvroConverter{ + Schema: schema, + TargetDWH: targetDWH, + ColNames: colNames, + logger: logger, + UnboundedNumericAsString: unboundedNumericAsString, + }, nil +} +func (qac *QRecordAvroConverter) Convert(qrecord []qvalue.QValue) (map[string]any, error) { + m := make(map[string]any, len(qrecord)) for idx, val := range qrecord { avroVal, err := qvalue.QValueToAvro( val, &qac.Schema.Fields[idx], qac.TargetDWH, qac.logger, + qac.UnboundedNumericAsString, ) if err != nil { return nil, fmt.Errorf("failed to convert QValue to Avro-compatible value: %w", err) @@ -52,8 +67,8 @@ func (qac *QRecordAvroConverter) Convert(qrecord []qvalue.QValue) (map[string]in } type QRecordAvroField struct { - Type interface{} `json:"type"` - Name string `json:"name"` + Type any `json:"type"` + Name string `json:"name"` } type QRecordAvroSchema struct { @@ -68,6 +83,8 @@ type QRecordAvroSchemaDefinition struct { } func GetAvroSchemaDefinition( + ctx context.Context, + env map[string]string, dstTableName string, qRecordSchema qvalue.QRecordSchema, targetDWH protos.DBType, @@ -75,7 +92,7 @@ func GetAvroSchemaDefinition( avroFields := make([]QRecordAvroField, 0, len(qRecordSchema.Fields)) for _, qField := range qRecordSchema.Fields { - avroType, err := qvalue.GetAvroSchemaFromQValueKind(qField.Type, targetDWH, qField.Precision, qField.Scale) + avroType, err := qvalue.GetAvroSchemaFromQValueKind(ctx, env, qField.Type, targetDWH, qField.Precision, qField.Scale) if err != nil { return nil, err } diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index 97d9641b6b..db5bf4e2af 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -1,6 +1,7 @@ package qvalue import ( + "context" "encoding/base64" "errors" "fmt" @@ -14,6 +15,7 @@ import ( "github.com/PeerDB-io/peer-flow/datatypes" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/peerdbenv" ) type AvroSchemaField struct { @@ -74,7 +76,14 @@ func TruncateOrLogNumeric(num decimal.Decimal, precision int16, scale int16, tar // // For example, QValueKindInt64 would return an AvroLogicalSchema of "long". Unsupported QValueKinds // will return an error. 
-func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH protos.DBType, precision int16, scale int16) (interface{}, error) { +func GetAvroSchemaFromQValueKind( + ctx context.Context, + env map[string]string, + kind QValueKind, + targetDWH protos.DBType, + precision int16, + scale int16, +) (interface{}, error) { switch kind { case QValueKindString: return "string", nil @@ -103,9 +112,19 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH protos.DBType, preci } return "bytes", nil case QValueKindNumeric: - if targetDWH == protos.DBType_CLICKHOUSE && - precision > datatypes.PeerDBClickHouseMaxPrecision { - return "string", nil + if targetDWH == protos.DBType_CLICKHOUSE { + if precision == 0 && scale == 0 { + asString, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env) + if err != nil { + return nil, err + } + if asString { + return "string", nil + } + } + if precision > datatypes.PeerDBClickHouseMaxPrecision { + return "string", nil + } } avroNumericPrecision, avroNumericScale := DetermineNumericSettingForDWH(precision, scale, targetDWH) return AvroSchemaNumeric{ @@ -212,19 +231,24 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, targetDWH protos.DBType, preci type QValueAvroConverter struct { *QField - logger log.Logger - TargetDWH protos.DBType + logger log.Logger + TargetDWH protos.DBType + UnboundedNumericAsString bool } -func QValueToAvro(value QValue, field *QField, targetDWH protos.DBType, logger log.Logger) (interface{}, error) { +func QValueToAvro( + value QValue, field *QField, targetDWH protos.DBType, logger log.Logger, + unboundedNumericAsString bool, +) (any, error) { if value.Value() == nil { return nil, nil } - c := &QValueAvroConverter{ - QField: field, - TargetDWH: targetDWH, - logger: logger, + c := QValueAvroConverter{ + QField: field, + TargetDWH: targetDWH, + logger: logger, + UnboundedNumericAsString: unboundedNumericAsString, } switch v := value.(type) { @@ -456,18 +480,18 @@ func (c *QValueAvroConverter) processNullableUnion( return value, nil } -func (c *QValueAvroConverter) processNumeric(num decimal.Decimal) interface{} { +func (c *QValueAvroConverter) processNumeric(num decimal.Decimal) any { + if (c.UnboundedNumericAsString && c.Precision == 0 && c.Scale == 0) || + (c.TargetDWH == protos.DBType_CLICKHOUSE && c.Precision > datatypes.PeerDBClickHouseMaxPrecision) { + numStr, _ := c.processNullableUnion("string", num.String()) + return numStr + } + num, err := TruncateOrLogNumeric(num, c.Precision, c.Scale, c.TargetDWH) if err != nil { return nil } - if c.TargetDWH == protos.DBType_CLICKHOUSE && - c.Precision > datatypes.PeerDBClickHouseMaxPrecision { - // no error returned - numStr, _ := c.processNullableUnion("string", num.String()) - return numStr - } rat := num.Rat() if c.Nullable { return goavro.Union("bytes.decimal", rat) diff --git a/flow/model/qvalue/dwh.go b/flow/model/qvalue/dwh.go index 49c359b885..b2d085acb4 100644 --- a/flow/model/qvalue/dwh.go +++ b/flow/model/qvalue/dwh.go @@ -5,24 +5,24 @@ import ( "go.temporal.io/sdk/log" - numeric "github.com/PeerDB-io/peer-flow/datatypes" + "github.com/PeerDB-io/peer-flow/datatypes" "github.com/PeerDB-io/peer-flow/generated/protos" ) func DetermineNumericSettingForDWH(precision int16, scale int16, dwh protos.DBType) (int16, int16) { - var warehouseNumeric numeric.WarehouseNumericCompatibility + var warehouseNumeric datatypes.WarehouseNumericCompatibility switch dwh { case protos.DBType_CLICKHOUSE: - warehouseNumeric = numeric.ClickHouseNumericCompatibility{} + warehouseNumeric = 
datatypes.ClickHouseNumericCompatibility{} case protos.DBType_SNOWFLAKE: - warehouseNumeric = numeric.SnowflakeNumericCompatibility{} + warehouseNumeric = datatypes.SnowflakeNumericCompatibility{} case protos.DBType_BIGQUERY: - warehouseNumeric = numeric.BigQueryNumericCompatibility{} + warehouseNumeric = datatypes.BigQueryNumericCompatibility{} default: - warehouseNumeric = numeric.DefaultNumericCompatibility{} + warehouseNumeric = datatypes.DefaultNumericCompatibility{} } - return numeric.GetNumericTypeForWarehouse(numeric.MakeNumericTypmod(int32(precision), int32(scale)), warehouseNumeric) + return datatypes.GetNumericTypeForWarehousePrecisionScale(precision, scale, warehouseNumeric) } // Bigquery will not allow timestamp if it is less than 1AD and more than 9999AD diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go index 91ab867a0e..3cffcc274a 100644 --- a/flow/model/qvalue/kind.go +++ b/flow/model/qvalue/kind.go @@ -1,10 +1,13 @@ package qvalue import ( + "context" "fmt" "strings" + "github.com/PeerDB-io/peer-flow/datatypes" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/peerdbenv" ) type QValueKind string @@ -68,7 +71,6 @@ var QValueKindToSnowflakeTypeMap = map[QValueKind]string{ QValueKindInt64: "INTEGER", QValueKindFloat32: "FLOAT", QValueKindFloat64: "FLOAT", - QValueKindNumeric: "NUMBER(38, 9)", QValueKindQChar: "CHAR", QValueKindString: "STRING", QValueKindJSON: "VARIANT", @@ -110,7 +112,6 @@ var QValueKindToClickHouseTypeMap = map[QValueKind]string{ QValueKindInt64: "Int64", QValueKindFloat32: "Float32", QValueKindFloat64: "Float64", - QValueKindNumeric: "Decimal128(9)", QValueKindQChar: "FixedString(1)", QValueKindString: "String", QValueKindJSON: "String", @@ -140,16 +141,39 @@ var QValueKindToClickHouseTypeMap = map[QValueKind]string{ QValueKindArrayJSONB: "String", } -func (kind QValueKind) ToDWHColumnType(dwhType protos.DBType) (string, error) { +func getClickHouseTypeForNumericColumn(ctx context.Context, env map[string]string, column *protos.FieldDescription) (string, error) { + if column.TypeModifier == -1 { + numericAsStringEnabled, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env) + if err != nil { + return "", err + } + if numericAsStringEnabled { + return "String", nil + } + } else if rawPrecision, _ := datatypes.ParseNumericTypmod(column.TypeModifier); rawPrecision > datatypes.PeerDBClickHouseMaxPrecision { + return "String", nil + } + precision, scale := datatypes.GetNumericTypeForWarehouse(column.TypeModifier, datatypes.ClickHouseNumericCompatibility{}) + return fmt.Sprintf("Decimal(%d, %d)", precision, scale), nil +} + +// SEE ALSO: QField ToDWHColumnType +func (kind QValueKind) ToDWHColumnType(ctx context.Context, env map[string]string, dwhType protos.DBType, column *protos.FieldDescription, +) (string, error) { switch dwhType { case protos.DBType_SNOWFLAKE: - if val, ok := QValueKindToSnowflakeTypeMap[kind]; ok { + if kind == QValueKindNumeric { + precision, scale := datatypes.GetNumericTypeForWarehouse(column.TypeModifier, datatypes.SnowflakeNumericCompatibility{}) + return fmt.Sprintf("NUMERIC(%d,%d)", precision, scale), nil + } else if val, ok := QValueKindToSnowflakeTypeMap[kind]; ok { return val, nil } else { return "STRING", nil } case protos.DBType_CLICKHOUSE: - if val, ok := QValueKindToClickHouseTypeMap[kind]; ok { + if kind == QValueKindNumeric { + return getClickHouseTypeForNumericColumn(ctx, env, column) + } else if val, ok := QValueKindToClickHouseTypeMap[kind]; ok { return val, nil 
} else { return "String", nil diff --git a/flow/model/qvalue/qschema.go b/flow/model/qvalue/qschema.go index a956968ac1..a6632fdf5f 100644 --- a/flow/model/qvalue/qschema.go +++ b/flow/model/qvalue/qschema.go @@ -1,7 +1,13 @@ package qvalue import ( + "context" + "fmt" "strings" + + "github.com/PeerDB-io/peer-flow/datatypes" + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/peerdbenv" ) type QField struct { @@ -47,3 +53,42 @@ func (q QRecordSchema) GetColumnNames() []string { } return names } + +func (q QField) getClickHouseTypeForNumericField(ctx context.Context, env map[string]string) (string, error) { + if q.Precision == 0 && q.Scale == 0 { + numericAsStringEnabled, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env) + if err != nil { + return "", err + } + if numericAsStringEnabled { + return "String", nil + } + } else if q.Precision > datatypes.PeerDBClickHouseMaxPrecision { + return "String", nil + } + return fmt.Sprintf("Decimal(%d, %d)", q.Precision, q.Scale), nil +} + +// SEE ALSO: qvalue/kind.go ToDWHColumnType +func (q QField) ToDWHColumnType(ctx context.Context, env map[string]string, dwhType protos.DBType) (string, error) { + switch dwhType { + case protos.DBType_SNOWFLAKE: + if q.Type == QValueKindNumeric { + return fmt.Sprintf("NUMERIC(%d,%d)", q.Precision, q.Scale), nil + } else if val, ok := QValueKindToSnowflakeTypeMap[q.Type]; ok { + return val, nil + } else { + return "STRING", nil + } + case protos.DBType_CLICKHOUSE: + if q.Type == QValueKindNumeric { + return q.getClickHouseTypeForNumericField(ctx, env) + } else if val, ok := QValueKindToClickHouseTypeMap[q.Type]; ok { + return val, nil + } else { + return "String", nil + } + default: + return "", fmt.Errorf("unknown dwh type: %v", dwhType) + } +} diff --git a/flow/peerdbenv/dynamicconf.go b/flow/peerdbenv/dynamicconf.go index b0cbe05f51..98a47d8fdc 100644 --- a/flow/peerdbenv/dynamicconf.go +++ b/flow/peerdbenv/dynamicconf.go @@ -187,6 +187,14 @@ var DynamicSettings = [...]*protos.DynamicSetting{ ApplyMode: protos.DynconfApplyMode_APPLY_MODE_IMMEDIATE, TargetForSetting: protos.DynconfTarget_CLICKHOUSE, }, + { + Name: "PEERDB_CLICKHOUSE_UNBOUNDED_NUMERIC_AS_STRING", + Description: "Map unbounded numerics in Postgres to String in ClickHouse to preserve precision and scale", + DefaultValue: "false", + ValueType: protos.DynconfValueType_BOOL, + ApplyMode: protos.DynconfApplyMode_APPLY_MODE_NEW_MIRROR, + TargetForSetting: protos.DynconfTarget_CLICKHOUSE, + }, { Name: "PEERDB_INTERVAL_SINCE_LAST_NORMALIZE_THRESHOLD_MINUTES", Description: "Duration in minutes since last normalize to start alerting, 0 disables all alerting entirely", @@ -389,6 +397,10 @@ func PeerDBClickHouseParallelNormalize(ctx context.Context, env map[string]strin return dynamicConfSigned[int](ctx, env, "PEERDB_CLICKHOUSE_PARALLEL_NORMALIZE") } +func PeerDBEnableClickHouseNumericAsString(ctx context.Context, env map[string]string) (bool, error) { + return dynamicConfBool(ctx, env, "PEERDB_CLICKHOUSE_UNBOUNDED_NUMERIC_AS_STRING") +} + func PeerDBSnowflakeMergeParallelism(ctx context.Context, env map[string]string) (int64, error) { return dynamicConfSigned[int64](ctx, env, "PEERDB_SNOWFLAKE_MERGE_PARALLELISM") } diff --git a/flow/workflows/snapshot_flow.go b/flow/workflows/snapshot_flow.go index 9b21b7b384..1db3b6d60b 100644 --- a/flow/workflows/snapshot_flow.go +++ b/flow/workflows/snapshot_flow.go @@ -208,6 +208,7 @@ func (s *SnapshotFlowExecution) cloneTable( WriteMode: snapshotWriteMode, System:
s.config.System, Script: s.config.Script, + Env: s.config.Env, ParentMirrorName: flowName, }
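
A minimal standalone sketch of the numeric mapping rule this patch introduces, kept separate from the diff for reference. It does not use the PeerDB packages; the max-precision constant 76 and the Decimal(76, 38) default are assumptions standing in for datatypes.PeerDBClickHouseMaxPrecision and the ClickHouse DefaultPrecisionAndScale(), and the helper name is illustrative only.

```go
package main

import "fmt"

// Assumed stand-ins for datatypes.PeerDBClickHouseMaxPrecision and the
// ClickHouse compatibility defaults; not taken from this diff.
const (
	clickHouseMaxPrecision int16 = 76
	defaultPrecision       int16 = 76
	defaultScale           int16 = 38
)

// clickHouseNumericType mirrors the decision centralized in
// getClickHouseTypeForNumericColumn / getClickHouseTypeForNumericField:
// an unbounded NUMERIC (typmod -1, i.e. precision and scale both 0) maps to
// String when PEERDB_CLICKHOUSE_UNBOUNDED_NUMERIC_AS_STRING is enabled and to
// the default-sized Decimal otherwise; bounded NUMERICs above ClickHouse's
// maximum precision also fall back to String.
func clickHouseNumericType(precision, scale int16, unboundedAsString bool) string {
	if precision == 0 && scale == 0 { // unbounded NUMERIC
		if unboundedAsString {
			return "String"
		}
		return fmt.Sprintf("Decimal(%d, %d)", defaultPrecision, defaultScale)
	}
	if precision > clickHouseMaxPrecision {
		return "String"
	}
	return fmt.Sprintf("Decimal(%d, %d)", precision, scale)
}

func main() {
	fmt.Println(clickHouseNumericType(0, 0, true))   // String
	fmt.Println(clickHouseNumericType(0, 0, false))  // Decimal(76, 38)
	fmt.Println(clickHouseNumericType(78, 2, false)) // String
	fmt.Println(clickHouseNumericType(10, 2, false)) // Decimal(10, 2)
}
```

Because the setting uses ApplyMode_APPLY_MODE_NEW_MIRROR, it is read from each mirror's Env (config.Env / req.Env threaded through the call sites above), so existing mirrors keep their current numeric mapping.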