From d33b56e355688e8c87f9274df3471957d197b617 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20Dub=C3=A9?=
Date: Mon, 2 Dec 2024 20:21:34 +0000
Subject: [PATCH] more avro consistency

---
 flow/connectors/bigquery/qrep.go            |  2 +-
 flow/connectors/bigquery/qrep_avro_sync.go  | 10 +++--
 .../snowflake/avro_file_writer_test.go      | 10 ++---
 flow/connectors/snowflake/qrep_avro_sync.go |  2 +-
 flow/connectors/utils/avro/avro_writer.go   | 21 +++++----
 flow/model/conversion_avro.go               | 44 ++++++++++++-------
 flow/model/qvalue/avro_converter.go         | 42 +++++++++++-------
 7 files changed, 81 insertions(+), 50 deletions(-)

diff --git a/flow/connectors/bigquery/qrep.go b/flow/connectors/bigquery/qrep.go
index f4dd19397..b184cc62a 100644
--- a/flow/connectors/bigquery/qrep.go
+++ b/flow/connectors/bigquery/qrep.go
@@ -35,7 +35,7 @@ func (c *BigQueryConnector) SyncQRepRecords(
 		partition.PartitionId, destTable))
 
 	avroSync := NewQRepAvroSyncMethod(c, config.StagingPath, config.FlowJobName)
-	return avroSync.SyncQRepRecords(ctx, config.FlowJobName, destTable, partition,
+	return avroSync.SyncQRepRecords(ctx, config.Env, config.FlowJobName, destTable, partition,
 		tblMetadata, stream, config.SyncedAtColName, config.SoftDeleteColName)
 }
 
diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go
index 99322361e..07285eb99 100644
--- a/flow/connectors/bigquery/qrep_avro_sync.go
+++ b/flow/connectors/bigquery/qrep_avro_sync.go
@@ -55,7 +55,7 @@ func (s *QRepAvroSyncMethod) SyncRecords(
 	}
 
 	stagingTable := fmt.Sprintf("%s_%s_staging", rawTableName, strconv.FormatInt(syncBatchID, 10))
-	numRecords, err := s.writeToStage(ctx, strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema,
+	numRecords, err := s.writeToStage(ctx, req.Env, strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema,
 		&datasetTable{
 			project: s.connector.projectID,
 			dataset: s.connector.datasetID,
@@ -139,6 +139,7 @@ func getTransformedColumns(dstSchema *bigquery.Schema, syncedAtCol string, softD
 
 func (s *QRepAvroSyncMethod) SyncQRepRecords(
 	ctx context.Context,
+	env map[string]string,
 	flowJobName string,
 	dstTableName string,
 	partition *protos.QRepPartition,
@@ -167,7 +168,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
 		table: fmt.Sprintf("%s_%s_staging", dstDatasetTable.table,
 			strings.ReplaceAll(partition.PartitionId, "-", "_")),
 	}
-	numRecords, err := s.writeToStage(ctx, partition.PartitionId, flowJobName, avroSchema,
+	numRecords, err := s.writeToStage(ctx, env, partition.PartitionId, flowJobName, avroSchema,
 		stagingDatasetTable, stream, flowJobName)
 	if err != nil {
 		return -1, fmt.Errorf("failed to push to avro stage: %w", err)
@@ -389,6 +390,7 @@ func GetAvroField(bqField *bigquery.FieldSchema) (AvroField, error) {
 
 func (s *QRepAvroSyncMethod) writeToStage(
 	ctx context.Context,
+	env map[string]string,
 	syncID string,
 	objectFolder string,
 	avroSchema *model.QRecordAvroSchemaDefinition,
@@ -408,7 +410,7 @@ func (s *QRepAvroSyncMethod) writeToStage(
 		obj := bucket.Object(avroFilePath)
 		w := obj.NewWriter(ctx)
 
-		numRecords, err := ocfWriter.WriteOCF(ctx, w)
+		numRecords, err := ocfWriter.WriteOCF(ctx, env, w)
 		if err != nil {
 			return 0, fmt.Errorf("failed to write records to Avro file on GCS: %w", err)
 		}
@@ -426,7 +428,7 @@ func (s *QRepAvroSyncMethod) writeToStage(
 
 		avroFilePath := fmt.Sprintf("%s/%s.avro", tmpDir, syncID)
 		s.connector.logger.Info("writing records to local file", idLog)
-		avroFile, err = ocfWriter.WriteRecordsToAvroFile(ctx, avroFilePath)
+		avroFile, err = ocfWriter.WriteRecordsToAvroFile(ctx, env, avroFilePath)
 		if err != nil {
 			return 0, fmt.Errorf("failed to write records to local Avro file: %w", err)
 		}
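The BigQuery changes above are plumbing: the per-flow env map (config.Env / req.Env) now rides along from SyncQRepRecords down through writeToStage into the OCF writer, so Avro encoding can consult per-mirror dynamic settings rather than only the process environment. A minimal sketch of that lookup pattern, assuming a hypothetical helper name and setting key (the real peerdbenv helpers may differ in details):

package main

import (
	"fmt"
	"os"
	"strconv"
)

// lookupDynamicBool checks the per-flow env map first, then the process
// environment, then falls back to a default. Hypothetical helper mirroring
// the peerdbenv lookup style; not PeerDB's actual implementation.
func lookupDynamicBool(env map[string]string, name string, def bool) (bool, error) {
	val, ok := env[name] // safe even when env is nil
	if !ok {
		val, ok = os.LookupEnv(name)
	}
	if !ok {
		return def, nil
	}
	b, err := strconv.ParseBool(val)
	if err != nil {
		return def, fmt.Errorf("parse %s: %w", name, err)
	}
	return b, nil
}

func main() {
	env := map[string]string{"PEERDB_EXAMPLE_FLAG": "true"} // hypothetical key
	fmt.Println(lookupDynamicBool(env, "PEERDB_EXAMPLE_FLAG", false))
}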
diff --git a/flow/connectors/snowflake/avro_file_writer_test.go b/flow/connectors/snowflake/avro_file_writer_test.go
index f3736b373..4a76fccd0 100644
--- a/flow/connectors/snowflake/avro_file_writer_test.go
+++ b/flow/connectors/snowflake/avro_file_writer_test.go
@@ -151,7 +151,7 @@ func TestWriteRecordsToAvroFileHappyPath(t *testing.T) {
 
 	// Call function
 	writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE)
-	_, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
+	_, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
 	require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")
 
 	// Check file is not empty
@@ -178,7 +178,7 @@ func TestWriteRecordsToZstdAvroFileHappyPath(t *testing.T) {
 
 	// Call function
 	writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressZstd, protos.DBType_SNOWFLAKE)
-	_, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
+	_, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
 	require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")
 
 	// Check file is not empty
@@ -205,7 +205,7 @@ func TestWriteRecordsToDeflateAvroFileHappyPath(t *testing.T) {
 
 	// Call function
 	writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressDeflate, protos.DBType_SNOWFLAKE)
-	_, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
+	_, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
 	require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")
 
 	// Check file is not empty
@@ -231,7 +231,7 @@ func TestWriteRecordsToAvroFileNonNull(t *testing.T) {
 
 	// Call function
 	writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE)
-	_, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
+	_, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
 	require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")
 
 	// Check file is not empty
@@ -258,7 +258,7 @@ func TestWriteRecordsToAvroFileAllNulls(t *testing.T) {
 
 	// Call function
 	writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE)
-	_, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
+	_, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
 	require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")
 
 	// Check file is not empty
diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go
index 787a87487..728d393e6 100644
--- a/flow/connectors/snowflake/qrep_avro_sync.go
+++ b/flow/connectors/snowflake/qrep_avro_sync.go
@@ -225,7 +225,7 @@ func (s *SnowflakeAvroSyncHandler) writeToAvroFile(
 
 	localFilePath := fmt.Sprintf("%s/%s.avro.zst", tmpDir, partitionID)
 	s.logger.Info("writing records to local file " + localFilePath)
-	avroFile, err := ocfWriter.WriteRecordsToAvroFile(ctx, localFilePath)
+	avroFile, err := ocfWriter.WriteRecordsToAvroFile(ctx, env, localFilePath)
 	if err != nil {
 		return nil, fmt.Errorf("failed to write records to Avro file: %w", err)
 	}
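The Snowflake tests simply pass nil for the new env argument. That is idiomatic Go: reading from a nil map never panics, it just reports a miss, so every dynamic setting falls back to the process environment or its default. For illustration:

package main

import "fmt"

func main() {
	var env map[string]string // nil, as in the updated tests
	v, ok := env["SOME_SETTING"]
	fmt.Printf("value=%q present=%v\n", v, ok) // value="" present=false
}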
diff --git a/flow/connectors/utils/avro/avro_writer.go b/flow/connectors/utils/avro/avro_writer.go
index 13a722f61..21648bc45 100644
--- a/flow/connectors/utils/avro/avro_writer.go
+++ b/flow/connectors/utils/avro/avro_writer.go
@@ -127,16 +127,21 @@ func (p *peerDBOCFWriter) createOCFWriter(w io.Writer) (*goavro.OCFWriter, error
 	return ocfWriter, nil
 }
 
-func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter *goavro.OCFWriter) (int64, error) {
+func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, env map[string]string, ocfWriter *goavro.OCFWriter) (int64, error) {
 	logger := shared.LoggerFromCtx(ctx)
 
 	schema := p.stream.Schema()
-	avroConverter := model.NewQRecordAvroConverter(
+	avroConverter, err := model.NewQRecordAvroConverter(
+		ctx,
+		env,
 		p.avroSchema,
 		p.targetDWH,
 		schema.GetColumnNames(),
 		logger,
 	)
+	if err != nil {
+		return 0, err
+	}
 
 	numRows := atomic.Int64{}
 
@@ -149,7 +154,7 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter
 		if err := ctx.Err(); err != nil {
 			return numRows.Load(), err
 		} else {
-			avroMap, err := avroConverter.Convert(qrecord)
+			avroMap, err := avroConverter.Convert(ctx, env, qrecord)
 			if err != nil {
 				logger.Error("Failed to convert QRecord to Avro compatible map", slog.Any("error", err))
 				return numRows.Load(), fmt.Errorf("failed to convert QRecord to Avro compatible map: %w", err)
@@ -172,7 +177,7 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter
 	return numRows.Load(), nil
 }
 
-func (p *peerDBOCFWriter) WriteOCF(ctx context.Context, w io.Writer) (int, error) {
+func (p *peerDBOCFWriter) WriteOCF(ctx context.Context, env map[string]string, w io.Writer) (int, error) {
 	ocfWriter, err := p.createOCFWriter(w)
 	if err != nil {
 		return 0, fmt.Errorf("failed to create OCF writer: %w", err)
@@ -180,7 +185,7 @@ func (p *peerDBOCFWriter) WriteOCF(ctx context.Context, w io.Writer) (int, error
 	// we have to keep a reference to the underlying writer as goavro doesn't provide any access to it
 	defer p.writer.Close()
 
-	numRows, err := p.writeRecordsToOCFWriter(ctx, ocfWriter)
+	numRows, err := p.writeRecordsToOCFWriter(ctx, env, ocfWriter)
 	if err != nil {
 		return 0, fmt.Errorf("failed to write records to OCF writer: %w", err)
 	}
@@ -217,7 +222,7 @@ func (p *peerDBOCFWriter) WriteRecordsToS3(
 			}
 			w.Close()
 		}()
-		numRows, writeOcfError = p.WriteOCF(ctx, w)
+		numRows, writeOcfError = p.WriteOCF(ctx, env, w)
 	}()
 
 	partSize, err := peerdbenv.PeerDBS3PartSize(ctx, env)
@@ -254,7 +259,7 @@ func (p *peerDBOCFWriter) WriteRecordsToS3(
 	}, nil
 }
 
-func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, filePath string) (*AvroFile, error) {
+func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, env map[string]string, filePath string) (*AvroFile, error) {
 	file, err := os.Create(filePath)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create temporary Avro file: %w", err)
@@ -275,7 +280,7 @@ func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, filePath s
 	bufferedWriter := bufio.NewWriterSize(file, buffSizeBytes)
 	defer bufferedWriter.Flush()
 
-	numRecords, err := p.WriteOCF(ctx, bufferedWriter)
+	numRecords, err := p.WriteOCF(ctx, env, bufferedWriter)
 	if err != nil {
 		return nil, fmt.Errorf("failed to write records to temporary Avro file: %w", err)
 	}
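Note the shape of the avro_writer.go change: NewQRecordAvroConverter now returns an error because it resolves the ClickHouse numeric-as-string setting once, at construction, rather than re-reading it for every record inside the hot loop. A sketch of that hoisting pattern, with illustrative names and setting key rather than PeerDB's:

package main

import (
	"context"
	"fmt"
	"strconv"
)

// converter caches a setting resolved once up front so the per-row
// path does no map lookups or parsing.
type converter struct {
	numericAsString bool
}

func newConverter(ctx context.Context, env map[string]string) (*converter, error) {
	raw, ok := env["NUMERIC_AS_STRING"] // hypothetical setting key
	if !ok {
		return &converter{}, nil
	}
	b, err := strconv.ParseBool(raw)
	if err != nil {
		// surfaced to the caller, like NewQRecordAvroConverter's error
		return nil, fmt.Errorf("bad NUMERIC_AS_STRING: %w", err)
	}
	return &converter{numericAsString: b}, nil
}

func main() {
	c, err := newConverter(context.Background(), map[string]string{"NUMERIC_AS_STRING": "true"})
	fmt.Println(c.numericAsString, err) // true <nil>
}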
"github.com/PeerDB-io/peer-flow/peerdbenv" ) type QRecordAvroConverter struct { - logger log.Logger - Schema *QRecordAvroSchemaDefinition - ColNames []string - TargetDWH protos.DBType + logger log.Logger + Schema *QRecordAvroSchemaDefinition + ColNames []string + TargetDWH protos.DBType + UnboundedNumericAsString bool } func NewQRecordAvroConverter( + ctx context.Context, + env map[string]string, schema *QRecordAvroSchemaDefinition, targetDWH protos.DBType, colNames []string, logger log.Logger, -) *QRecordAvroConverter { - return &QRecordAvroConverter{ - Schema: schema, - TargetDWH: targetDWH, - ColNames: colNames, - logger: logger, +) (*QRecordAvroConverter, error) { + var unboundedNumericAsString bool + if targetDWH == protos.DBType_CLICKHOUSE { + var err error + unboundedNumericAsString, err = peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env) + if err != nil { + return nil, err + } } -} -func (qac *QRecordAvroConverter) Convert(qrecord []qvalue.QValue) (map[string]interface{}, error) { - m := make(map[string]interface{}, len(qrecord)) + return &QRecordAvroConverter{ + Schema: schema, + TargetDWH: targetDWH, + ColNames: colNames, + logger: logger, + UnboundedNumericAsString: unboundedNumericAsString, + }, nil +} +func (qac *QRecordAvroConverter) Convert(ctx context.Context, env map[string]string, qrecord []qvalue.QValue) (map[string]any, error) { + m := make(map[string]any, len(qrecord)) for idx, val := range qrecord { avroVal, err := qvalue.QValueToAvro( val, &qac.Schema.Fields[idx], qac.TargetDWH, qac.logger, + qac.UnboundedNumericAsString, ) if err != nil { return nil, fmt.Errorf("failed to convert QValue to Avro-compatible value: %w", err) @@ -53,8 +67,8 @@ func (qac *QRecordAvroConverter) Convert(qrecord []qvalue.QValue) (map[string]in } type QRecordAvroField struct { - Type interface{} `json:"type"` - Name string `json:"name"` + Type any `json:"type"` + Name string `json:"name"` } type QRecordAvroSchema struct { diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go index 024b3e31e..b75bcdc21 100644 --- a/flow/model/qvalue/avro_converter.go +++ b/flow/model/qvalue/avro_converter.go @@ -112,12 +112,17 @@ func GetAvroSchemaFromQValueKind( } return "bytes", nil case QValueKindNumeric: - if targetDWH == protos.DBType_CLICKHOUSE && precision == 0 && scale == 0 { - asString, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env) - if err != nil { - return nil, err + if targetDWH == protos.DBType_CLICKHOUSE { + if precision == 0 && scale == 0 { + asString, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env) + if err != nil { + return nil, err + } + if asString { + return "string", nil + } } - if asString { + if precision > datatypes.PeerDBClickHouseMaxPrecision { return "string", nil } } @@ -226,19 +231,24 @@ func GetAvroSchemaFromQValueKind( type QValueAvroConverter struct { *QField - logger log.Logger - TargetDWH protos.DBType + logger log.Logger + TargetDWH protos.DBType + UnboundedNumericAsString bool } -func QValueToAvro(value QValue, field *QField, targetDWH protos.DBType, logger log.Logger) (interface{}, error) { +func QValueToAvro( + value QValue, field *QField, targetDWH protos.DBType, logger log.Logger, + unboundedNumericAsString bool, +) (any, error) { if value.Value() == nil { return nil, nil } - c := &QValueAvroConverter{ - QField: field, - TargetDWH: targetDWH, - logger: logger, + c := QValueAvroConverter{ + QField: field, + TargetDWH: targetDWH, + logger: logger, + UnboundedNumericAsString: 
diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go
index 024b3e31e..b75bcdc21 100644
--- a/flow/model/qvalue/avro_converter.go
+++ b/flow/model/qvalue/avro_converter.go
@@ -112,12 +112,17 @@ func GetAvroSchemaFromQValueKind(
 		}
 		return "bytes", nil
 	case QValueKindNumeric:
-		if targetDWH == protos.DBType_CLICKHOUSE && precision == 0 && scale == 0 {
-			asString, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env)
-			if err != nil {
-				return nil, err
+		if targetDWH == protos.DBType_CLICKHOUSE {
+			if precision == 0 && scale == 0 {
+				asString, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env)
+				if err != nil {
+					return nil, err
+				}
+				if asString {
+					return "string", nil
+				}
 			}
-			if asString {
+			if precision > datatypes.PeerDBClickHouseMaxPrecision {
 				return "string", nil
 			}
 		}
@@ -226,19 +231,24 @@ func GetAvroSchemaFromQValueKind(
 
 type QValueAvroConverter struct {
 	*QField
-	logger    log.Logger
-	TargetDWH protos.DBType
+	logger                   log.Logger
+	TargetDWH                protos.DBType
+	UnboundedNumericAsString bool
 }
 
-func QValueToAvro(value QValue, field *QField, targetDWH protos.DBType, logger log.Logger) (interface{}, error) {
+func QValueToAvro(
+	value QValue, field *QField, targetDWH protos.DBType, logger log.Logger,
+	unboundedNumericAsString bool,
+) (any, error) {
 	if value.Value() == nil {
 		return nil, nil
 	}
 
-	c := &QValueAvroConverter{
-		QField:    field,
-		TargetDWH: targetDWH,
-		logger:    logger,
+	c := QValueAvroConverter{
+		QField:                   field,
+		TargetDWH:                targetDWH,
+		logger:                   logger,
+		UnboundedNumericAsString: unboundedNumericAsString,
 	}
 
 	switch v := value.(type) {
@@ -470,18 +480,18 @@ func (c *QValueAvroConverter) processNullableUnion(
 	return value, nil
 }
 
-func (c *QValueAvroConverter) processNumeric(num decimal.Decimal) interface{} {
+func (c *QValueAvroConverter) processNumeric(num decimal.Decimal) any {
 	num, err := TruncateOrLogNumeric(num, c.Precision, c.Scale, c.TargetDWH)
 	if err != nil {
 		return nil
 	}
 
-	if c.TargetDWH == protos.DBType_CLICKHOUSE &&
-		c.Precision > datatypes.PeerDBClickHouseMaxPrecision {
-		// no error returned
+	if (c.UnboundedNumericAsString && c.Precision == 0 && c.Scale == 0) ||
+		(c.TargetDWH == protos.DBType_CLICKHOUSE && c.Precision > datatypes.PeerDBClickHouseMaxPrecision) {
 		numStr, _ := c.processNullableUnion("string", num.String())
 		return numStr
 	}
+
 	rat := num.Rat()
 	if c.Nullable {
 		return goavro.Union("bytes.decimal", rat)
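The condition added to processNumeric routes a numeric to a plain Avro string in two cases: the column is unbounded (precision and scale both zero) and the numeric-as-string setting is on, or the target is ClickHouse and the precision exceeds datatypes.PeerDBClickHouseMaxPrecision; otherwise the value stays a decimal built from a *big.Rat. A self-contained sketch of that decision with shopspring/decimal, assuming a 76-digit ClickHouse cap (mirroring Decimal256; the real constant lives in the datatypes package):

package main

import (
	"fmt"

	"github.com/shopspring/decimal"
)

// maxPrecision stands in for datatypes.PeerDBClickHouseMaxPrecision;
// 76 (ClickHouse Decimal256) is an assumption for this sketch.
const maxPrecision int16 = 76

// numericAvroValue mirrors the branch added in processNumeric: string for
// unbounded-as-string or over-wide numerics, *big.Rat otherwise.
func numericAvroValue(num decimal.Decimal, precision, scale int16, clickhouse, unboundedAsString bool) any {
	if (unboundedAsString && precision == 0 && scale == 0) ||
		(clickhouse && precision > maxPrecision) {
		return num.String()
	}
	return num.Rat() // goavro encodes bytes.decimal from *big.Rat
}

func main() {
	n := decimal.RequireFromString("12345.6789")
	fmt.Println(numericAvroValue(n, 0, 0, true, true))   // string "12345.6789"
	fmt.Println(numericAvroValue(n, 10, 4, true, false)) // *big.Rat 123456789/10000
}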