Skip to content

Commit

Permalink
more avro consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Dec 2, 2024
1 parent dd43178 commit 72d819a
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 33 deletions.
2 changes: 1 addition & 1 deletion flow/connectors/bigquery/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (c *BigQueryConnector) SyncQRepRecords(
partition.PartitionId, destTable))

avroSync := NewQRepAvroSyncMethod(c, config.StagingPath, config.FlowJobName)
return avroSync.SyncQRepRecords(ctx, config.FlowJobName, destTable, partition,
return avroSync.SyncQRepRecords(ctx, config.Env, config.FlowJobName, destTable, partition,
tblMetadata, stream, config.SyncedAtColName, config.SoftDeleteColName)
}

Expand Down
10 changes: 6 additions & 4 deletions flow/connectors/bigquery/qrep_avro_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (s *QRepAvroSyncMethod) SyncRecords(
}

stagingTable := fmt.Sprintf("%s_%s_staging", rawTableName, strconv.FormatInt(syncBatchID, 10))
numRecords, err := s.writeToStage(ctx, strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema,
numRecords, err := s.writeToStage(ctx, req.Env, strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema,
&datasetTable{
project: s.connector.projectID,
dataset: s.connector.datasetID,
Expand Down Expand Up @@ -139,6 +139,7 @@ func getTransformedColumns(dstSchema *bigquery.Schema, syncedAtCol string, softD

func (s *QRepAvroSyncMethod) SyncQRepRecords(
ctx context.Context,
env map[string]string,
flowJobName string,
dstTableName string,
partition *protos.QRepPartition,
Expand Down Expand Up @@ -167,7 +168,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
table: fmt.Sprintf("%s_%s_staging", dstDatasetTable.table,
strings.ReplaceAll(partition.PartitionId, "-", "_")),
}
numRecords, err := s.writeToStage(ctx, partition.PartitionId, flowJobName, avroSchema,
numRecords, err := s.writeToStage(ctx, env, partition.PartitionId, flowJobName, avroSchema,
stagingDatasetTable, stream, flowJobName)
if err != nil {
return -1, fmt.Errorf("failed to push to avro stage: %w", err)
Expand Down Expand Up @@ -389,6 +390,7 @@ func GetAvroField(bqField *bigquery.FieldSchema) (AvroField, error) {

func (s *QRepAvroSyncMethod) writeToStage(
ctx context.Context,
env map[string]string,
syncID string,
objectFolder string,
avroSchema *model.QRecordAvroSchemaDefinition,
Expand All @@ -408,7 +410,7 @@ func (s *QRepAvroSyncMethod) writeToStage(
obj := bucket.Object(avroFilePath)
w := obj.NewWriter(ctx)

numRecords, err := ocfWriter.WriteOCF(ctx, w)
numRecords, err := ocfWriter.WriteOCF(ctx, env, w)
if err != nil {
return 0, fmt.Errorf("failed to write records to Avro file on GCS: %w", err)
}
Expand All @@ -426,7 +428,7 @@ func (s *QRepAvroSyncMethod) writeToStage(

avroFilePath := fmt.Sprintf("%s/%s.avro", tmpDir, syncID)
s.connector.logger.Info("writing records to local file", idLog)
avroFile, err = ocfWriter.WriteRecordsToAvroFile(ctx, avroFilePath)
avroFile, err = ocfWriter.WriteRecordsToAvroFile(ctx, env, avroFilePath)
if err != nil {
return 0, fmt.Errorf("failed to write records to local Avro file: %w", err)
}
Expand Down
10 changes: 5 additions & 5 deletions flow/connectors/snowflake/avro_file_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestWriteRecordsToAvroFileHappyPath(t *testing.T) {

// Call function
writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE)
_, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
_, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")

// Check file is not empty
Expand All @@ -178,7 +178,7 @@ func TestWriteRecordsToZstdAvroFileHappyPath(t *testing.T) {

// Call function
writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressZstd, protos.DBType_SNOWFLAKE)
_, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
_, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")

// Check file is not empty
Expand All @@ -205,7 +205,7 @@ func TestWriteRecordsToDeflateAvroFileHappyPath(t *testing.T) {

// Call function
writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressDeflate, protos.DBType_SNOWFLAKE)
_, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
_, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")

// Check file is not empty
Expand All @@ -231,7 +231,7 @@ func TestWriteRecordsToAvroFileNonNull(t *testing.T) {

// Call function
writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE)
_, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
_, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")

// Check file is not empty
Expand All @@ -258,7 +258,7 @@ func TestWriteRecordsToAvroFileAllNulls(t *testing.T) {

// Call function
writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE)
_, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
_, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")

// Check file is not empty
Expand Down
2 changes: 1 addition & 1 deletion flow/connectors/snowflake/qrep_avro_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (s *SnowflakeAvroSyncHandler) writeToAvroFile(

localFilePath := fmt.Sprintf("%s/%s.avro.zst", tmpDir, partitionID)
s.logger.Info("writing records to local file " + localFilePath)
avroFile, err := ocfWriter.WriteRecordsToAvroFile(ctx, localFilePath)
avroFile, err := ocfWriter.WriteRecordsToAvroFile(ctx, env, localFilePath)
if err != nil {
return nil, fmt.Errorf("failed to write records to Avro file: %w", err)
}
Expand Down
14 changes: 7 additions & 7 deletions flow/connectors/utils/avro/avro_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (p *peerDBOCFWriter) createOCFWriter(w io.Writer) (*goavro.OCFWriter, error
return ocfWriter, nil
}

func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter *goavro.OCFWriter) (int64, error) {
func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, env map[string]string, ocfWriter *goavro.OCFWriter) (int64, error) {
logger := shared.LoggerFromCtx(ctx)
schema := p.stream.Schema()

Expand All @@ -149,7 +149,7 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter
if err := ctx.Err(); err != nil {
return numRows.Load(), err
} else {
avroMap, err := avroConverter.Convert(qrecord)
avroMap, err := avroConverter.Convert(ctx, env, qrecord)
if err != nil {
logger.Error("Failed to convert QRecord to Avro compatible map", slog.Any("error", err))
return numRows.Load(), fmt.Errorf("failed to convert QRecord to Avro compatible map: %w", err)
Expand All @@ -172,15 +172,15 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter
return numRows.Load(), nil
}

func (p *peerDBOCFWriter) WriteOCF(ctx context.Context, w io.Writer) (int, error) {
func (p *peerDBOCFWriter) WriteOCF(ctx context.Context, env map[string]string, w io.Writer) (int, error) {
ocfWriter, err := p.createOCFWriter(w)
if err != nil {
return 0, fmt.Errorf("failed to create OCF writer: %w", err)
}
// we have to keep a reference to the underlying writer as goavro doesn't provide any access to it
defer p.writer.Close()

numRows, err := p.writeRecordsToOCFWriter(ctx, ocfWriter)
numRows, err := p.writeRecordsToOCFWriter(ctx, env, ocfWriter)
if err != nil {
return 0, fmt.Errorf("failed to write records to OCF writer: %w", err)
}
Expand Down Expand Up @@ -217,7 +217,7 @@ func (p *peerDBOCFWriter) WriteRecordsToS3(
}
w.Close()
}()
numRows, writeOcfError = p.WriteOCF(ctx, w)
numRows, writeOcfError = p.WriteOCF(ctx, env, w)
}()

partSize, err := peerdbenv.PeerDBS3PartSize(ctx, env)
Expand Down Expand Up @@ -254,7 +254,7 @@ func (p *peerDBOCFWriter) WriteRecordsToS3(
}, nil
}

func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, filePath string) (*AvroFile, error) {
func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, env map[string]string, filePath string) (*AvroFile, error) {
file, err := os.Create(filePath)
if err != nil {
return nil, fmt.Errorf("failed to create temporary Avro file: %w", err)
Expand All @@ -275,7 +275,7 @@ func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, filePath s
bufferedWriter := bufio.NewWriterSize(file, buffSizeBytes)
defer bufferedWriter.Flush()

numRecords, err := p.WriteOCF(ctx, bufferedWriter)
numRecords, err := p.WriteOCF(ctx, env, bufferedWriter)
if err != nil {
return nil, fmt.Errorf("failed to write records to temporary Avro file: %w", err)
}
Expand Down
6 changes: 4 additions & 2 deletions flow/model/conversion_avro.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ func NewQRecordAvroConverter(
}
}

func (qac *QRecordAvroConverter) Convert(qrecord []qvalue.QValue) (map[string]interface{}, error) {
m := make(map[string]interface{}, len(qrecord))
func (qac *QRecordAvroConverter) Convert(ctx context.Context, env map[string]string, qrecord []qvalue.QValue) (map[string]any, error) {
m := make(map[string]any, len(qrecord))

for idx, val := range qrecord {
avroVal, err := qvalue.QValueToAvro(
ctx,
env,
val,
&qac.Schema.Fields[idx],
qac.TargetDWH,
Expand Down
45 changes: 32 additions & 13 deletions flow/model/qvalue/avro_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,17 @@ func GetAvroSchemaFromQValueKind(
}
return "bytes", nil
case QValueKindNumeric:
if targetDWH == protos.DBType_CLICKHOUSE && precision == 0 && scale == 0 {
asString, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env)
if err != nil {
return nil, err
if targetDWH == protos.DBType_CLICKHOUSE {
if precision == 0 && scale == 0 {
asString, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env)
if err != nil {
return nil, err
}
if asString {
return "string", nil
}
}
if asString {
if precision > datatypes.PeerDBClickHouseMaxPrecision {
return "string", nil
}
}
Expand Down Expand Up @@ -230,7 +235,9 @@ type QValueAvroConverter struct {
TargetDWH protos.DBType
}

func QValueToAvro(value QValue, field *QField, targetDWH protos.DBType, logger log.Logger) (interface{}, error) {
func QValueToAvro(ctx context.Context, env map[string]string,
value QValue, field *QField, targetDWH protos.DBType, logger log.Logger,
) (any, error) {
if value.Value() == nil {
return nil, nil
}
Expand Down Expand Up @@ -357,7 +364,7 @@ func QValueToAvro(value QValue, field *QField, targetDWH protos.DBType, logger l
case QValueStruct:
return nil, errors.New("QValueStruct not supported")
case QValueNumeric:
return c.processNumeric(v.Val), nil
return c.processNumeric(ctx, env, v.Val), nil
case QValueBytes:
return c.processBytes(v.Val), nil
case QValueJSON:
Expand Down Expand Up @@ -470,18 +477,30 @@ func (c *QValueAvroConverter) processNullableUnion(
return value, nil
}

func (c *QValueAvroConverter) processNumeric(num decimal.Decimal) interface{} {
func (c *QValueAvroConverter) processNumeric(ctx context.Context, env map[string]string, num decimal.Decimal) any {
num, err := TruncateOrLogNumeric(num, c.Precision, c.Scale, c.TargetDWH)
if err != nil {
return nil
}

if c.TargetDWH == protos.DBType_CLICKHOUSE &&
c.Precision > datatypes.PeerDBClickHouseMaxPrecision {
// no error returned
numStr, _ := c.processNullableUnion("string", num.String())
return numStr
if c.TargetDWH == protos.DBType_CLICKHOUSE {
if c.Precision == 0 && c.Scale == 0 {
asString, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env)
if err != nil {
// TODO lift
return err
}
if asString {
numStr, _ := c.processNullableUnion("string", num.String())
return numStr
}
}
if c.Precision > datatypes.PeerDBClickHouseMaxPrecision {
numStr, _ := c.processNullableUnion("string", num.String())
return numStr
}
}

rat := num.Rat()
if c.Nullable {
return goavro.Union("bytes.decimal", rat)
Expand Down

0 comments on commit 72d819a

Please sign in to comment.