more avro consistency

serprex committed Dec 2, 2024
1 parent dd43178 commit d33b56e

Showing 7 changed files with 81 additions and 50 deletions.
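The seven diffs below share one mechanical change: the per-flow dynamic-settings map (`env map[string]string`) is threaded from each connector's config or sync request, through the Avro OCF writers, into the record converter, so a setting such as the ClickHouse numeric-as-string toggle (read via `peerdbenv.PeerDBEnableClickHouseNumericAsString`) can be resolved per flow instead of globally. A minimal sketch of the pattern under stated assumptions: `lookupBool`, `newConverter`, and the `numeric_as_string` key are illustrative stand-ins, not PeerDB's API.

```go
package main

import (
	"context"
	"fmt"
	"strconv"
)

// lookupBool is a hypothetical stand-in for helpers like
// peerdbenv.PeerDBEnableClickHouseNumericAsString: consult the per-flow
// env map first, fall back to a default when the key is absent.
func lookupBool(_ context.Context, env map[string]string, key string) (bool, error) {
	val, ok := env[key] // reading a nil map is safe: yields "", false
	if !ok {
		return false, nil
	}
	return strconv.ParseBool(val)
}

// converter mirrors the shape of the change: the setting is resolved once
// at construction, then reused for every record.
type converter struct {
	unboundedNumericAsString bool
}

func newConverter(ctx context.Context, env map[string]string) (*converter, error) {
	asString, err := lookupBool(ctx, env, "numeric_as_string")
	if err != nil {
		return nil, err
	}
	return &converter{unboundedNumericAsString: asString}, nil
}

func main() {
	c, err := newConverter(context.Background(), map[string]string{"numeric_as_string": "true"})
	if err != nil {
		panic(err)
	}
	fmt.Println(c.unboundedNumericAsString) // true
}
```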
2 changes: 1 addition & 1 deletion flow/connectors/bigquery/qrep.go

@@ -35,7 +35,7 @@ func (c *BigQueryConnector) SyncQRepRecords(
         partition.PartitionId, destTable))
 
     avroSync := NewQRepAvroSyncMethod(c, config.StagingPath, config.FlowJobName)
-    return avroSync.SyncQRepRecords(ctx, config.FlowJobName, destTable, partition,
+    return avroSync.SyncQRepRecords(ctx, config.Env, config.FlowJobName, destTable, partition,
         tblMetadata, stream, config.SyncedAtColName, config.SoftDeleteColName)
 }
10 changes: 6 additions & 4 deletions flow/connectors/bigquery/qrep_avro_sync.go

@@ -55,7 +55,7 @@ func (s *QRepAvroSyncMethod) SyncRecords(
     }
 
     stagingTable := fmt.Sprintf("%s_%s_staging", rawTableName, strconv.FormatInt(syncBatchID, 10))
-    numRecords, err := s.writeToStage(ctx, strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema,
+    numRecords, err := s.writeToStage(ctx, req.Env, strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema,
         &datasetTable{
             project: s.connector.projectID,
             dataset: s.connector.datasetID,
@@ -139,6 +139,7 @@ func getTransformedColumns(dstSchema *bigquery.Schema, syncedAtCol string, softD
 
 func (s *QRepAvroSyncMethod) SyncQRepRecords(
     ctx context.Context,
+    env map[string]string,
     flowJobName string,
     dstTableName string,
     partition *protos.QRepPartition,
@@ -167,7 +168,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
         table: fmt.Sprintf("%s_%s_staging", dstDatasetTable.table,
             strings.ReplaceAll(partition.PartitionId, "-", "_")),
     }
-    numRecords, err := s.writeToStage(ctx, partition.PartitionId, flowJobName, avroSchema,
+    numRecords, err := s.writeToStage(ctx, env, partition.PartitionId, flowJobName, avroSchema,
         stagingDatasetTable, stream, flowJobName)
     if err != nil {
         return -1, fmt.Errorf("failed to push to avro stage: %w", err)
@@ -389,6 +390,7 @@ func GetAvroField(bqField *bigquery.FieldSchema) (AvroField, error) {
 
 func (s *QRepAvroSyncMethod) writeToStage(
     ctx context.Context,
+    env map[string]string,
     syncID string,
     objectFolder string,
     avroSchema *model.QRecordAvroSchemaDefinition,
@@ -408,7 +410,7 @@ func (s *QRepAvroSyncMethod) writeToStage(
         obj := bucket.Object(avroFilePath)
         w := obj.NewWriter(ctx)
 
-        numRecords, err := ocfWriter.WriteOCF(ctx, w)
+        numRecords, err := ocfWriter.WriteOCF(ctx, env, w)
         if err != nil {
             return 0, fmt.Errorf("failed to write records to Avro file on GCS: %w", err)
         }
@@ -426,7 +428,7 @@ func (s *QRepAvroSyncMethod) writeToStage(
 
         avroFilePath := fmt.Sprintf("%s/%s.avro", tmpDir, syncID)
         s.connector.logger.Info("writing records to local file", idLog)
-        avroFile, err = ocfWriter.WriteRecordsToAvroFile(ctx, avroFilePath)
+        avroFile, err = ocfWriter.WriteRecordsToAvroFile(ctx, env, avroFilePath)
         if err != nil {
             return 0, fmt.Errorf("failed to write records to local Avro file: %w", err)
         }
10 changes: 5 additions & 5 deletions flow/connectors/snowflake/avro_file_writer_test.go

@@ -151,7 +151,7 @@ func TestWriteRecordsToAvroFileHappyPath(t *testing.T) {
 
     // Call function
     writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE)
-    _, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
+    _, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
     require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")
 
     // Check file is not empty
@@ -178,7 +178,7 @@ func TestWriteRecordsToZstdAvroFileHappyPath(t *testing.T) {
 
     // Call function
     writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressZstd, protos.DBType_SNOWFLAKE)
-    _, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
+    _, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
     require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")
 
     // Check file is not empty
@@ -205,7 +205,7 @@ func TestWriteRecordsToDeflateAvroFileHappyPath(t *testing.T) {
 
     // Call function
     writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressDeflate, protos.DBType_SNOWFLAKE)
-    _, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
+    _, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
     require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")
 
     // Check file is not empty
@@ -231,7 +231,7 @@ func TestWriteRecordsToAvroFileNonNull(t *testing.T) {
 
     // Call function
     writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE)
-    _, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
+    _, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
     require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")
 
     // Check file is not empty
@@ -258,7 +258,7 @@ func TestWriteRecordsToAvroFileAllNulls(t *testing.T) {
 
     // Call function
     writer := avro.NewPeerDBOCFWriter(records, avroSchema, avro.CompressNone, protos.DBType_SNOWFLAKE)
-    _, err = writer.WriteRecordsToAvroFile(context.Background(), tmpfile.Name())
+    _, err = writer.WriteRecordsToAvroFile(context.Background(), nil, tmpfile.Name())
     require.NoError(t, err, "expected WriteRecordsToAvroFile to complete without errors")
 
     // Check file is not empty
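All five test call sites pass `nil` for the new env argument. That is safe because Go allows lookups on a nil map (they simply miss), so the `peerdbenv` helpers fall back to their defaults. A tiny runnable sketch of that property; the key name here is made up:

```go
package main

import "fmt"

func main() {
	var env map[string]string // nil, like the env argument in these tests
	v, ok := env["SOME_SETTING"] // hypothetical key; lookup on a nil map is safe
	fmt.Println(v == "", ok)     // prints: true false; absent keys fall back to defaults
}
```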
2 changes: 1 addition & 1 deletion flow/connectors/snowflake/qrep_avro_sync.go

@@ -225,7 +225,7 @@ func (s *SnowflakeAvroSyncHandler) writeToAvroFile(
 
     localFilePath := fmt.Sprintf("%s/%s.avro.zst", tmpDir, partitionID)
    s.logger.Info("writing records to local file " + localFilePath)
-    avroFile, err := ocfWriter.WriteRecordsToAvroFile(ctx, localFilePath)
+    avroFile, err := ocfWriter.WriteRecordsToAvroFile(ctx, env, localFilePath)
     if err != nil {
         return nil, fmt.Errorf("failed to write records to Avro file: %w", err)
     }
21 changes: 13 additions & 8 deletions flow/connectors/utils/avro/avro_writer.go

@@ -127,16 +127,21 @@ func (p *peerDBOCFWriter) createOCFWriter(w io.Writer) (*goavro.OCFWriter, error
     return ocfWriter, nil
 }
 
-func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter *goavro.OCFWriter) (int64, error) {
+func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, env map[string]string, ocfWriter *goavro.OCFWriter) (int64, error) {
     logger := shared.LoggerFromCtx(ctx)
     schema := p.stream.Schema()
 
-    avroConverter := model.NewQRecordAvroConverter(
+    avroConverter, err := model.NewQRecordAvroConverter(
+        ctx,
+        env,
         p.avroSchema,
         p.targetDWH,
         schema.GetColumnNames(),
         logger,
     )
+    if err != nil {
+        return 0, err
+    }
 
     numRows := atomic.Int64{}
 
@@ -149,7 +154,7 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter
         if err := ctx.Err(); err != nil {
             return numRows.Load(), err
         } else {
-            avroMap, err := avroConverter.Convert(qrecord)
+            avroMap, err := avroConverter.Convert(ctx, env, qrecord)
             if err != nil {
                 logger.Error("Failed to convert QRecord to Avro compatible map", slog.Any("error", err))
                 return numRows.Load(), fmt.Errorf("failed to convert QRecord to Avro compatible map: %w", err)
@@ -172,15 +177,15 @@ func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, ocfWriter
     return numRows.Load(), nil
 }
 
-func (p *peerDBOCFWriter) WriteOCF(ctx context.Context, w io.Writer) (int, error) {
+func (p *peerDBOCFWriter) WriteOCF(ctx context.Context, env map[string]string, w io.Writer) (int, error) {
     ocfWriter, err := p.createOCFWriter(w)
     if err != nil {
         return 0, fmt.Errorf("failed to create OCF writer: %w", err)
     }
     // we have to keep a reference to the underlying writer as goavro doesn't provide any access to it
     defer p.writer.Close()
 
-    numRows, err := p.writeRecordsToOCFWriter(ctx, ocfWriter)
+    numRows, err := p.writeRecordsToOCFWriter(ctx, env, ocfWriter)
     if err != nil {
         return 0, fmt.Errorf("failed to write records to OCF writer: %w", err)
     }
@@ -217,7 +222,7 @@ func (p *peerDBOCFWriter) WriteRecordsToS3(
             }
             w.Close()
         }()
-        numRows, writeOcfError = p.WriteOCF(ctx, w)
+        numRows, writeOcfError = p.WriteOCF(ctx, env, w)
     }()
 
     partSize, err := peerdbenv.PeerDBS3PartSize(ctx, env)
@@ -254,7 +259,7 @@
     }, nil
 }
 
-func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, filePath string) (*AvroFile, error) {
+func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, env map[string]string, filePath string) (*AvroFile, error) {
     file, err := os.Create(filePath)
     if err != nil {
         return nil, fmt.Errorf("failed to create temporary Avro file: %w", err)
@@ -275,7 +280,7 @@ func (p *peerDBOCFWriter) WriteRecordsToAvroFile(ctx context.Context, filePath s
     bufferedWriter := bufio.NewWriterSize(file, buffSizeBytes)
     defer bufferedWriter.Flush()
 
-    numRecords, err := p.WriteOCF(ctx, bufferedWriter)
+    numRecords, err := p.WriteOCF(ctx, env, bufferedWriter)
     if err != nil {
         return nil, fmt.Errorf("failed to write records to temporary Avro file: %w", err)
     }
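One detail worth noting in the `WriteRecordsToS3` hunk above: `WriteOCF` now receives `env`, but it still writes into one end of a pipe from a goroutine (the `w.Close()` above it runs in the deferred cleanup) while the S3 upload consumes the other end, so the file never has to be buffered in full. A minimal sketch of that producer/consumer shape, with `io.ReadAll` standing in for the actual multipart upload:

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

func main() {
	r, w := io.Pipe()

	errCh := make(chan error, 1)
	go func() {
		defer w.Close() // closing unblocks the reader with EOF
		_, err := io.Copy(w, strings.NewReader("OCF bytes would stream here"))
		errCh <- err
	}()

	data, err := io.ReadAll(r) // stands in for the S3 multipart upload
	if err != nil {
		panic(err)
	}
	if err := <-errCh; err != nil { // surface the producer's error, as writeOcfError does
		panic(err)
	}
	fmt.Println(string(data))
}
```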
44 changes: 29 additions & 15 deletions flow/model/conversion_avro.go

@@ -9,38 +9,52 @@ import (
 
     "github.com/PeerDB-io/peer-flow/generated/protos"
     "github.com/PeerDB-io/peer-flow/model/qvalue"
+    "github.com/PeerDB-io/peer-flow/peerdbenv"
 )
 
 type QRecordAvroConverter struct {
-    logger    log.Logger
-    Schema    *QRecordAvroSchemaDefinition
-    ColNames  []string
-    TargetDWH protos.DBType
+    logger                   log.Logger
+    Schema                   *QRecordAvroSchemaDefinition
+    ColNames                 []string
+    TargetDWH                protos.DBType
+    UnboundedNumericAsString bool
 }
 
 func NewQRecordAvroConverter(
+    ctx context.Context,
+    env map[string]string,
     schema *QRecordAvroSchemaDefinition,
     targetDWH protos.DBType,
     colNames []string,
     logger log.Logger,
-) *QRecordAvroConverter {
-    return &QRecordAvroConverter{
-        Schema:    schema,
-        TargetDWH: targetDWH,
-        ColNames:  colNames,
-        logger:    logger,
-    }
-}
+) (*QRecordAvroConverter, error) {
+    var unboundedNumericAsString bool
+    if targetDWH == protos.DBType_CLICKHOUSE {
+        var err error
+        unboundedNumericAsString, err = peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env)
+        if err != nil {
+            return nil, err
+        }
+    }
+
+    return &QRecordAvroConverter{
+        Schema:                   schema,
+        TargetDWH:                targetDWH,
+        ColNames:                 colNames,
+        logger:                   logger,
+        UnboundedNumericAsString: unboundedNumericAsString,
+    }, nil
+}
 
-func (qac *QRecordAvroConverter) Convert(qrecord []qvalue.QValue) (map[string]interface{}, error) {
-    m := make(map[string]interface{}, len(qrecord))
+func (qac *QRecordAvroConverter) Convert(ctx context.Context, env map[string]string, qrecord []qvalue.QValue) (map[string]any, error) {
+    m := make(map[string]any, len(qrecord))
     for idx, val := range qrecord {
         avroVal, err := qvalue.QValueToAvro(
             val,
             &qac.Schema.Fields[idx],
             qac.TargetDWH,
             qac.logger,
+            qac.UnboundedNumericAsString,
         )
         if err != nil {
             return nil, fmt.Errorf("failed to convert QValue to Avro-compatible value: %w", err)
@@ -53,8 +67,8 @@ func (qac *QRecordAvroConverter) Convert(qrecord []qvalue.QValue) (map[string]in
 }
 
 type QRecordAvroField struct {
-    Type interface{} `json:"type"`
-    Name string      `json:"name"`
+    Type any    `json:"type"`
+    Name string `json:"name"`
 }
 
 type QRecordAvroSchema struct {
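Two kinds of change are mixed into this file: a functional one (the constructor now consults `peerdbenv` and can fail, hence the new `error` return) and a cosmetic one (`interface{}` becomes `any`). The latter is behavior-neutral: `any` has been a predeclared alias for `interface{}` since Go 1.18, as this sketch demonstrates:

```go
package main

import "fmt"

func main() {
	// any is an alias for interface{} (Go 1.18+); they are the same type.
	var a any = 42
	var i interface{} = a
	fmt.Println(a == i) // true
}
```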
42 changes: 26 additions & 16 deletions flow/model/qvalue/avro_converter.go

@@ -112,12 +112,17 @@ func GetAvroSchemaFromQValueKind(
         }
         return "bytes", nil
     case QValueKindNumeric:
-        if targetDWH == protos.DBType_CLICKHOUSE && precision == 0 && scale == 0 {
-            asString, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env)
-            if err != nil {
-                return nil, err
+        if targetDWH == protos.DBType_CLICKHOUSE {
+            if precision == 0 && scale == 0 {
+                asString, err := peerdbenv.PeerDBEnableClickHouseNumericAsString(ctx, env)
+                if err != nil {
+                    return nil, err
+                }
+                if asString {
+                    return "string", nil
+                }
             }
-            if asString {
+            if precision > datatypes.PeerDBClickHouseMaxPrecision {
                 return "string", nil
             }
         }
@@ -226,19 +231,24 @@
 
 type QValueAvroConverter struct {
     *QField
-    logger    log.Logger
-    TargetDWH protos.DBType
+    logger                   log.Logger
+    TargetDWH                protos.DBType
+    UnboundedNumericAsString bool
 }
 
-func QValueToAvro(value QValue, field *QField, targetDWH protos.DBType, logger log.Logger) (interface{}, error) {
+func QValueToAvro(
+    value QValue, field *QField, targetDWH protos.DBType, logger log.Logger,
+    unboundedNumericAsString bool,
+) (any, error) {
     if value.Value() == nil {
         return nil, nil
     }
 
-    c := &QValueAvroConverter{
-        QField:    field,
-        TargetDWH: targetDWH,
-        logger:    logger,
+    c := QValueAvroConverter{
+        QField:                   field,
+        TargetDWH:                targetDWH,
+        logger:                   logger,
+        UnboundedNumericAsString: unboundedNumericAsString,
     }
 
     switch v := value.(type) {
@@ -470,18 +480,18 @@ func (c *QValueAvroConverter) processNullableUnion(
     return value, nil
 }
 
-func (c *QValueAvroConverter) processNumeric(num decimal.Decimal) interface{} {
+func (c *QValueAvroConverter) processNumeric(num decimal.Decimal) any {
     num, err := TruncateOrLogNumeric(num, c.Precision, c.Scale, c.TargetDWH)
     if err != nil {
         return nil
     }
 
-    if c.TargetDWH == protos.DBType_CLICKHOUSE &&
-        c.Precision > datatypes.PeerDBClickHouseMaxPrecision {
+    if (c.UnboundedNumericAsString && c.Precision == 0 && c.Scale == 0) ||
+        (c.TargetDWH == protos.DBType_CLICKHOUSE && c.Precision > datatypes.PeerDBClickHouseMaxPrecision) {
         // no error returned
         numStr, _ := c.processNullableUnion("string", num.String())
         return numStr
     }
 
     rat := num.Rat()
     if c.Nullable {
         return goavro.Union("bytes.decimal", rat)
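Net effect of the last two hunks: a NUMERIC value is emitted as an Avro string either when it is unbounded (precision 0, scale 0) and the numeric-as-string setting is on (the flag is only ever populated for ClickHouse targets, per `NewQRecordAvroConverter` above), or when its declared precision exceeds what ClickHouse decimals can hold; otherwise it stays an Avro decimal. A sketch of just that predicate; the value 76 (the Decimal256 digit limit) is an assumption for `datatypes.PeerDBClickHouseMaxPrecision`:

```go
package main

import "fmt"

// clickHouseMaxPrecision stands in for datatypes.PeerDBClickHouseMaxPrecision;
// 76 digits (ClickHouse Decimal256) is assumed here.
const clickHouseMaxPrecision = 76

// numericAsString mirrors the new condition in processNumeric.
func numericAsString(precision, scale int16, isClickHouse, unboundedAsString bool) bool {
	return (unboundedAsString && precision == 0 && scale == 0) ||
		(isClickHouse && precision > clickHouseMaxPrecision)
}

func main() {
	fmt.Println(numericAsString(0, 0, true, true))    // true: unbounded, setting on
	fmt.Println(numericAsString(0, 0, true, false))   // false: unbounded, setting off
	fmt.Println(numericAsString(100, 2, true, false)) // true: wider than ClickHouse allows
	fmt.Println(numericAsString(38, 9, true, false))  // false: fits as Avro decimal
}
```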
