Skip to content

Commit

Permalink
switch to using channel to distribute queries across connections
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Nov 14, 2024
1 parent 404fe3e commit e7b92bf
Showing 1 changed file with 172 additions and 165 deletions.
337 changes: 172 additions & 165 deletions flow/connectors/clickhouse/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,20 @@ import (
"database/sql"
"errors"
"fmt"
"iter"
"log/slog"
"slices"
"strconv"
"strings"
"time"

"github.com/ClickHouse/clickhouse-go/v2"
"golang.org/x/sync/errgroup"

"github.com/PeerDB-io/peer-flow/datatypes"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/PeerDB-io/peer-flow/peerdbenv"
"github.com/PeerDB-io/peer-flow/shared"
"golang.org/x/sync/errgroup"
)

const (
Expand Down Expand Up @@ -280,200 +280,207 @@ func (c *ClickHouseConnector) NormalizeRecords(
return nil, err
}

parallelNormalize, err := peerdbenv.PeerDBClickHouseParallelNormalize(ctx, req.Env)
enablePrimaryUpdate, err := peerdbenv.PeerDBEnableClickHousePrimaryUpdate(ctx, req.Env)
if err != nil {
return nil, err
}

var destinationTableNameChunks iter.Seq[[]string]
if parallelNormalize > 1 {
var chunkSize int
if parallelNormalize >= len(destinationTableNames) {
chunkSize = 1
} else {
chunkSize = shared.DivCeil(len(destinationTableNames), parallelNormalize)
}
destinationTableNameChunks = slices.Chunk(destinationTableNames, chunkSize)
} else {
destinationTableNameChunks = func(yield func([]string) bool) {
yield(destinationTableNames)
}
}

enablePrimaryUpdate, err := peerdbenv.PeerDBEnableClickHousePrimaryUpdate(ctx, req.Env)
parallelNormalize, err := peerdbenv.PeerDBClickHouseParallelNormalize(ctx, req.Env)
if err != nil {
return nil, err
}
parallelNormalize = min(max(parallelNormalize, 1), len(destinationTableNames))
if parallelNormalize > 1 {
c.logger.Info("normalizing in parallel", slog.Int("connections", parallelNormalize))
}

queries := make(chan string)
rawTbl := c.getRawTableName(req.FlowJobName)

group, errCtx := errgroup.WithContext(ctx)
for tableNames := range destinationTableNameChunks {
for i := range parallelNormalize {
group.Go(func() error {
chConn, err := Connect(errCtx, c.config)
if err != nil {
return err
var chConn clickhouse.Conn
if i == 0 {
chConn = c.database
} else {
var err error
chConn, err = Connect(errCtx, c.config)
if err != nil {
return err
}
defer chConn.Close()
}
defer chConn.Close()

for _, tbl := range tableNames {
// SELECT projection FROM raw_table WHERE _peerdb_batch_id > normalize_batch_id AND _peerdb_batch_id <= sync_batch_id
selectQuery := strings.Builder{}
selectQuery.WriteString("SELECT ")
for query := range queries {
c.logger.Info("normalizing batch", slog.String("query", query))
if err := chConn.Exec(errCtx, query); err != nil {
return fmt.Errorf("error while inserting into normalized table: %w", err)
}
}
return nil
})
}

colSelector := strings.Builder{}
colSelector.WriteRune('(')
for _, tbl := range destinationTableNames {
// SELECT projection FROM raw_table WHERE _peerdb_batch_id > normalize_batch_id AND _peerdb_batch_id <= sync_batch_id
selectQuery := strings.Builder{}
selectQuery.WriteString("SELECT ")

schema := req.TableNameSchemaMapping[tbl]
colSelector := strings.Builder{}
colSelector.WriteRune('(')

var tableMapping *protos.TableMapping
for _, tm := range req.TableMappings {
if tm.DestinationTableIdentifier == tbl {
tableMapping = tm
break
}
}
schema := req.TableNameSchemaMapping[tbl]

projection := strings.Builder{}
projectionUpdate := strings.Builder{}

for _, column := range schema.Columns {
colName := column.Name
dstColName := colName
colType := qvalue.QValueKind(column.Type)

var clickHouseType string
var columnNullableEnabled bool
if tableMapping != nil {
for _, col := range tableMapping.Columns {
if col.SourceName == colName {
if col.DestinationName != "" {
dstColName = col.DestinationName
}
if col.DestinationType != "" {
// TODO can we restrict this to avoid injection?
clickHouseType = col.DestinationType
}
columnNullableEnabled = col.NullableEnabled
break
}
}
}
var tableMapping *protos.TableMapping
for _, tm := range req.TableMappings {
if tm.DestinationTableIdentifier == tbl {
tableMapping = tm
break
}
}

colSelector.WriteString(fmt.Sprintf("`%s`,", dstColName))
if clickHouseType == "" {
if colType == qvalue.QValueKindNumeric {
clickHouseType = getClickhouseTypeForNumericColumn(column)
} else {
var err error
clickHouseType, err = colType.ToDWHColumnType(protos.DBType_CLICKHOUSE)
if err != nil {
return fmt.Errorf("error while converting column type to clickhouse type: %w", err)
}
projection := strings.Builder{}
projectionUpdate := strings.Builder{}

for _, column := range schema.Columns {
colName := column.Name
dstColName := colName
colType := qvalue.QValueKind(column.Type)

var clickHouseType string
var columnNullableEnabled bool
if tableMapping != nil {
for _, col := range tableMapping.Columns {
if col.SourceName == colName {
if col.DestinationName != "" {
dstColName = col.DestinationName
}
if (schema.NullableEnabled || columnNullableEnabled) && column.Nullable && !colType.IsArray() {
clickHouseType = fmt.Sprintf("Nullable(%s)", clickHouseType)
if col.DestinationType != "" {
// TODO can we restrict this to avoid injection?
clickHouseType = col.DestinationType
}
columnNullableEnabled = col.NullableEnabled
break
}
}
}

switch clickHouseType {
case "Date32", "Nullable(Date32)":
projection.WriteString(fmt.Sprintf(
"toDate32(parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_data, '%s'),6)) AS `%s`,",
colName,
dstColName,
))
if enablePrimaryUpdate {
projectionUpdate.WriteString(fmt.Sprintf(
"toDate32(parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_match_data, '%s'),6)) AS `%s`,",
colName,
dstColName,
))
}
case "DateTime64(6)", "Nullable(DateTime64(6))":
projection.WriteString(fmt.Sprintf(
"parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_data, '%s'),6) AS `%s`,",
colName,
dstColName,
))
if enablePrimaryUpdate {
projectionUpdate.WriteString(fmt.Sprintf(
"parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_match_data, '%s'),6) AS `%s`,",
colName,
dstColName,
))
}
default:
projection.WriteString(fmt.Sprintf("JSONExtract(_peerdb_data, '%s', '%s') AS `%s`,", colName, clickHouseType, dstColName))
if enablePrimaryUpdate {
projectionUpdate.WriteString(fmt.Sprintf(
"JSONExtract(_peerdb_match_data, '%s', '%s') AS `%s`,",
colName,
clickHouseType,
dstColName,
))
}
colSelector.WriteString(fmt.Sprintf("`%s`,", dstColName))
if clickHouseType == "" {
if colType == qvalue.QValueKindNumeric {
clickHouseType = getClickhouseTypeForNumericColumn(column)
} else {
var err error
clickHouseType, err = colType.ToDWHColumnType(protos.DBType_CLICKHOUSE)
if err != nil {
close(queries)
return nil, fmt.Errorf("error while converting column type to clickhouse type: %w", err)
}
}
if (schema.NullableEnabled || columnNullableEnabled) && column.Nullable && !colType.IsArray() {
clickHouseType = fmt.Sprintf("Nullable(%s)", clickHouseType)
}
}

// add _peerdb_sign as _peerdb_record_type / 2
projection.WriteString(fmt.Sprintf("intDiv(_peerdb_record_type, 2) AS `%s`,", signColName))
colSelector.WriteString(fmt.Sprintf("`%s`,", signColName))

// add _peerdb_timestamp as _peerdb_version
projection.WriteString(fmt.Sprintf("_peerdb_timestamp AS `%s`", versionColName))
colSelector.WriteString(versionColName)
colSelector.WriteString(") ")

selectQuery.WriteString(projection.String())
selectQuery.WriteString(" FROM ")
selectQuery.WriteString(rawTbl)
selectQuery.WriteString(" WHERE _peerdb_batch_id > ")
selectQuery.WriteString(strconv.FormatInt(normBatchID, 10))
selectQuery.WriteString(" AND _peerdb_batch_id <= ")
selectQuery.WriteString(strconv.FormatInt(req.SyncBatchID, 10))
selectQuery.WriteString(" AND _peerdb_destination_table_name = '")
selectQuery.WriteString(tbl)
selectQuery.WriteString("'")

switch clickHouseType {
case "Date32", "Nullable(Date32)":
projection.WriteString(fmt.Sprintf(
"toDate32(parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_data, '%s'),6)) AS `%s`,",
colName,
dstColName,
))
if enablePrimaryUpdate {
// projectionUpdate generates delete on previous record, so _peerdb_record_type is filled in as 2
projectionUpdate.WriteString(fmt.Sprintf("1 AS `%s`,", signColName))
// decrement timestamp by 1 so delete is ordered before latest data,
// could be same if deletion records were only generated when ordering updated
projectionUpdate.WriteString(fmt.Sprintf("_peerdb_timestamp - 1 AS `%s`", versionColName))

selectQuery.WriteString("UNION ALL SELECT ")
selectQuery.WriteString(projectionUpdate.String())
selectQuery.WriteString(" FROM ")
selectQuery.WriteString(rawTbl)
selectQuery.WriteString(" WHERE _peerdb_batch_id > ")
selectQuery.WriteString(strconv.FormatInt(normBatchID, 10))
selectQuery.WriteString(" AND _peerdb_batch_id <= ")
selectQuery.WriteString(strconv.FormatInt(req.SyncBatchID, 10))
selectQuery.WriteString(" AND _peerdb_destination_table_name = '")
selectQuery.WriteString(tbl)
selectQuery.WriteString("' AND _peerdb_record_type = 1")
projectionUpdate.WriteString(fmt.Sprintf(
"toDate32(parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_match_data, '%s'),6)) AS `%s`,",
colName,
dstColName,
))
}

insertIntoSelectQuery := strings.Builder{}
insertIntoSelectQuery.WriteString("INSERT INTO `")
insertIntoSelectQuery.WriteString(tbl)
insertIntoSelectQuery.WriteString("` ")
insertIntoSelectQuery.WriteString(colSelector.String())
insertIntoSelectQuery.WriteString(selectQuery.String())

q := insertIntoSelectQuery.String()
c.logger.Info("normalizing batch", slog.String("query", q))
if err := chConn.Exec(errCtx, q); err != nil {
return fmt.Errorf("error while inserting into normalized table: %w", err)
case "DateTime64(6)", "Nullable(DateTime64(6))":
projection.WriteString(fmt.Sprintf(
"parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_data, '%s'),6) AS `%s`,",
colName,
dstColName,
))
if enablePrimaryUpdate {
projectionUpdate.WriteString(fmt.Sprintf(
"parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_match_data, '%s'),6) AS `%s`,",
colName,
dstColName,
))
}
default:
projection.WriteString(fmt.Sprintf("JSONExtract(_peerdb_data, '%s', '%s') AS `%s`,", colName, clickHouseType, dstColName))
if enablePrimaryUpdate {
projectionUpdate.WriteString(fmt.Sprintf(
"JSONExtract(_peerdb_match_data, '%s', '%s') AS `%s`,",
colName,
clickHouseType,
dstColName,
))
}
}
}

return nil
})
// add _peerdb_sign as _peerdb_record_type / 2
projection.WriteString(fmt.Sprintf("intDiv(_peerdb_record_type, 2) AS `%s`,", signColName))
colSelector.WriteString(fmt.Sprintf("`%s`,", signColName))

// add _peerdb_timestamp as _peerdb_version
projection.WriteString(fmt.Sprintf("_peerdb_timestamp AS `%s`", versionColName))
colSelector.WriteString(versionColName)
colSelector.WriteString(") ")

selectQuery.WriteString(projection.String())
selectQuery.WriteString(" FROM ")
selectQuery.WriteString(rawTbl)
selectQuery.WriteString(" WHERE _peerdb_batch_id > ")
selectQuery.WriteString(strconv.FormatInt(normBatchID, 10))
selectQuery.WriteString(" AND _peerdb_batch_id <= ")
selectQuery.WriteString(strconv.FormatInt(req.SyncBatchID, 10))
selectQuery.WriteString(" AND _peerdb_destination_table_name = '")
selectQuery.WriteString(tbl)
selectQuery.WriteString("'")

if enablePrimaryUpdate {
// projectionUpdate generates delete on previous record, so _peerdb_record_type is filled in as 2
projectionUpdate.WriteString(fmt.Sprintf("1 AS `%s`,", signColName))
// decrement timestamp by 1 so delete is ordered before latest data,
// could be same if deletion records were only generated when ordering updated
projectionUpdate.WriteString(fmt.Sprintf("_peerdb_timestamp - 1 AS `%s`", versionColName))

selectQuery.WriteString("UNION ALL SELECT ")
selectQuery.WriteString(projectionUpdate.String())
selectQuery.WriteString(" FROM ")
selectQuery.WriteString(rawTbl)
selectQuery.WriteString(" WHERE _peerdb_batch_id > ")
selectQuery.WriteString(strconv.FormatInt(normBatchID, 10))
selectQuery.WriteString(" AND _peerdb_batch_id <= ")
selectQuery.WriteString(strconv.FormatInt(req.SyncBatchID, 10))
selectQuery.WriteString(" AND _peerdb_destination_table_name = '")
selectQuery.WriteString(tbl)
selectQuery.WriteString("' AND _peerdb_record_type = 1")
}

insertIntoSelectQuery := strings.Builder{}
insertIntoSelectQuery.WriteString("INSERT INTO `")
insertIntoSelectQuery.WriteString(tbl)
insertIntoSelectQuery.WriteString("` ")
insertIntoSelectQuery.WriteString(colSelector.String())
insertIntoSelectQuery.WriteString(selectQuery.String())

select {
case queries <- insertIntoSelectQuery.String():
case <-errCtx.Done():
close(queries)
return nil, ctx.Err()
}
}
close(queries)
if err := group.Wait(); err != nil {
return nil, err
}
group.Wait()

if err := c.UpdateNormalizeBatchID(ctx, req.FlowJobName, req.SyncBatchID); err != nil {
c.logger.Error("[clickhouse] error while updating normalize batch id", slog.Int64("BatchID", req.SyncBatchID), slog.Any("error", err))
Expand Down

0 comments on commit e7b92bf

Please sign in to comment.