diff --git a/flow/connectors/clickhouse/normalize.go b/flow/connectors/clickhouse/normalize.go index 87176b4cbd..c8d2f0a098 100644 --- a/flow/connectors/clickhouse/normalize.go +++ b/flow/connectors/clickhouse/normalize.go @@ -6,20 +6,20 @@ import ( "database/sql" "errors" "fmt" - "iter" "log/slog" "slices" "strconv" "strings" "time" + "github.com/ClickHouse/clickhouse-go/v2" + "golang.org/x/sync/errgroup" + "github.com/PeerDB-io/peer-flow/datatypes" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/peerdbenv" - "github.com/PeerDB-io/peer-flow/shared" - "golang.org/x/sync/errgroup" ) const ( @@ -280,200 +280,207 @@ func (c *ClickHouseConnector) NormalizeRecords( return nil, err } - parallelNormalize, err := peerdbenv.PeerDBClickHouseParallelNormalize(ctx, req.Env) + enablePrimaryUpdate, err := peerdbenv.PeerDBEnableClickHousePrimaryUpdate(ctx, req.Env) if err != nil { return nil, err } - var destinationTableNameChunks iter.Seq[[]string] - if parallelNormalize > 1 { - var chunkSize int - if parallelNormalize >= len(destinationTableNames) { - chunkSize = 1 - } else { - chunkSize = shared.DivCeil(len(destinationTableNames), parallelNormalize) - } - destinationTableNameChunks = slices.Chunk(destinationTableNames, chunkSize) - } else { - destinationTableNameChunks = func(yield func([]string) bool) { - yield(destinationTableNames) - } - } - - enablePrimaryUpdate, err := peerdbenv.PeerDBEnableClickHousePrimaryUpdate(ctx, req.Env) + parallelNormalize, err := peerdbenv.PeerDBClickHouseParallelNormalize(ctx, req.Env) if err != nil { return nil, err } + parallelNormalize = min(max(parallelNormalize, 1), len(destinationTableNames)) + if parallelNormalize > 1 { + c.logger.Info("normalizing in parallel", slog.Int("connections", parallelNormalize)) + } + queries := make(chan string) rawTbl := c.getRawTableName(req.FlowJobName) group, errCtx := errgroup.WithContext(ctx) - for tableNames := range destinationTableNameChunks { + for i := range parallelNormalize { group.Go(func() error { - chConn, err := Connect(errCtx, c.config) - if err != nil { - return err + var chConn clickhouse.Conn + if i == 0 { + chConn = c.database + } else { + var err error + chConn, err = Connect(errCtx, c.config) + if err != nil { + return err + } + defer chConn.Close() } - defer chConn.Close() - for _, tbl := range tableNames { - // SELECT projection FROM raw_table WHERE _peerdb_batch_id > normalize_batch_id AND _peerdb_batch_id <= sync_batch_id - selectQuery := strings.Builder{} - selectQuery.WriteString("SELECT ") + for query := range queries { + c.logger.Info("normalizing batch", slog.String("query", query)) + if err := chConn.Exec(errCtx, query); err != nil { + return fmt.Errorf("error while inserting into normalized table: %w", err) + } + } + return nil + }) + } - colSelector := strings.Builder{} - colSelector.WriteRune('(') + for _, tbl := range destinationTableNames { + // SELECT projection FROM raw_table WHERE _peerdb_batch_id > normalize_batch_id AND _peerdb_batch_id <= sync_batch_id + selectQuery := strings.Builder{} + selectQuery.WriteString("SELECT ") - schema := req.TableNameSchemaMapping[tbl] + colSelector := strings.Builder{} + colSelector.WriteRune('(') - var tableMapping *protos.TableMapping - for _, tm := range req.TableMappings { - if tm.DestinationTableIdentifier == tbl { - tableMapping = tm - break - } - } + schema := req.TableNameSchemaMapping[tbl] - projection := strings.Builder{} - projectionUpdate := strings.Builder{} - - for _, column := range schema.Columns { - colName := column.Name - dstColName := colName - colType := qvalue.QValueKind(column.Type) - - var clickHouseType string - var columnNullableEnabled bool - if tableMapping != nil { - for _, col := range tableMapping.Columns { - if col.SourceName == colName { - if col.DestinationName != "" { - dstColName = col.DestinationName - } - if col.DestinationType != "" { - // TODO can we restrict this to avoid injection? - clickHouseType = col.DestinationType - } - columnNullableEnabled = col.NullableEnabled - break - } - } - } + var tableMapping *protos.TableMapping + for _, tm := range req.TableMappings { + if tm.DestinationTableIdentifier == tbl { + tableMapping = tm + break + } + } - colSelector.WriteString(fmt.Sprintf("`%s`,", dstColName)) - if clickHouseType == "" { - if colType == qvalue.QValueKindNumeric { - clickHouseType = getClickhouseTypeForNumericColumn(column) - } else { - var err error - clickHouseType, err = colType.ToDWHColumnType(protos.DBType_CLICKHOUSE) - if err != nil { - return fmt.Errorf("error while converting column type to clickhouse type: %w", err) - } + projection := strings.Builder{} + projectionUpdate := strings.Builder{} + + for _, column := range schema.Columns { + colName := column.Name + dstColName := colName + colType := qvalue.QValueKind(column.Type) + + var clickHouseType string + var columnNullableEnabled bool + if tableMapping != nil { + for _, col := range tableMapping.Columns { + if col.SourceName == colName { + if col.DestinationName != "" { + dstColName = col.DestinationName } - if (schema.NullableEnabled || columnNullableEnabled) && column.Nullable && !colType.IsArray() { - clickHouseType = fmt.Sprintf("Nullable(%s)", clickHouseType) + if col.DestinationType != "" { + // TODO can we restrict this to avoid injection? + clickHouseType = col.DestinationType } + columnNullableEnabled = col.NullableEnabled + break } + } + } - switch clickHouseType { - case "Date32", "Nullable(Date32)": - projection.WriteString(fmt.Sprintf( - "toDate32(parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_data, '%s'),6)) AS `%s`,", - colName, - dstColName, - )) - if enablePrimaryUpdate { - projectionUpdate.WriteString(fmt.Sprintf( - "toDate32(parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_match_data, '%s'),6)) AS `%s`,", - colName, - dstColName, - )) - } - case "DateTime64(6)", "Nullable(DateTime64(6))": - projection.WriteString(fmt.Sprintf( - "parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_data, '%s'),6) AS `%s`,", - colName, - dstColName, - )) - if enablePrimaryUpdate { - projectionUpdate.WriteString(fmt.Sprintf( - "parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_match_data, '%s'),6) AS `%s`,", - colName, - dstColName, - )) - } - default: - projection.WriteString(fmt.Sprintf("JSONExtract(_peerdb_data, '%s', '%s') AS `%s`,", colName, clickHouseType, dstColName)) - if enablePrimaryUpdate { - projectionUpdate.WriteString(fmt.Sprintf( - "JSONExtract(_peerdb_match_data, '%s', '%s') AS `%s`,", - colName, - clickHouseType, - dstColName, - )) - } + colSelector.WriteString(fmt.Sprintf("`%s`,", dstColName)) + if clickHouseType == "" { + if colType == qvalue.QValueKindNumeric { + clickHouseType = getClickhouseTypeForNumericColumn(column) + } else { + var err error + clickHouseType, err = colType.ToDWHColumnType(protos.DBType_CLICKHOUSE) + if err != nil { + close(queries) + return nil, fmt.Errorf("error while converting column type to clickhouse type: %w", err) } } + if (schema.NullableEnabled || columnNullableEnabled) && column.Nullable && !colType.IsArray() { + clickHouseType = fmt.Sprintf("Nullable(%s)", clickHouseType) + } + } - // add _peerdb_sign as _peerdb_record_type / 2 - projection.WriteString(fmt.Sprintf("intDiv(_peerdb_record_type, 2) AS `%s`,", signColName)) - colSelector.WriteString(fmt.Sprintf("`%s`,", signColName)) - - // add _peerdb_timestamp as _peerdb_version - projection.WriteString(fmt.Sprintf("_peerdb_timestamp AS `%s`", versionColName)) - colSelector.WriteString(versionColName) - colSelector.WriteString(") ") - - selectQuery.WriteString(projection.String()) - selectQuery.WriteString(" FROM ") - selectQuery.WriteString(rawTbl) - selectQuery.WriteString(" WHERE _peerdb_batch_id > ") - selectQuery.WriteString(strconv.FormatInt(normBatchID, 10)) - selectQuery.WriteString(" AND _peerdb_batch_id <= ") - selectQuery.WriteString(strconv.FormatInt(req.SyncBatchID, 10)) - selectQuery.WriteString(" AND _peerdb_destination_table_name = '") - selectQuery.WriteString(tbl) - selectQuery.WriteString("'") - + switch clickHouseType { + case "Date32", "Nullable(Date32)": + projection.WriteString(fmt.Sprintf( + "toDate32(parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_data, '%s'),6)) AS `%s`,", + colName, + dstColName, + )) if enablePrimaryUpdate { - // projectionUpdate generates delete on previous record, so _peerdb_record_type is filled in as 2 - projectionUpdate.WriteString(fmt.Sprintf("1 AS `%s`,", signColName)) - // decrement timestamp by 1 so delete is ordered before latest data, - // could be same if deletion records were only generated when ordering updated - projectionUpdate.WriteString(fmt.Sprintf("_peerdb_timestamp - 1 AS `%s`", versionColName)) - - selectQuery.WriteString("UNION ALL SELECT ") - selectQuery.WriteString(projectionUpdate.String()) - selectQuery.WriteString(" FROM ") - selectQuery.WriteString(rawTbl) - selectQuery.WriteString(" WHERE _peerdb_batch_id > ") - selectQuery.WriteString(strconv.FormatInt(normBatchID, 10)) - selectQuery.WriteString(" AND _peerdb_batch_id <= ") - selectQuery.WriteString(strconv.FormatInt(req.SyncBatchID, 10)) - selectQuery.WriteString(" AND _peerdb_destination_table_name = '") - selectQuery.WriteString(tbl) - selectQuery.WriteString("' AND _peerdb_record_type = 1") + projectionUpdate.WriteString(fmt.Sprintf( + "toDate32(parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_match_data, '%s'),6)) AS `%s`,", + colName, + dstColName, + )) } - - insertIntoSelectQuery := strings.Builder{} - insertIntoSelectQuery.WriteString("INSERT INTO `") - insertIntoSelectQuery.WriteString(tbl) - insertIntoSelectQuery.WriteString("` ") - insertIntoSelectQuery.WriteString(colSelector.String()) - insertIntoSelectQuery.WriteString(selectQuery.String()) - - q := insertIntoSelectQuery.String() - c.logger.Info("normalizing batch", slog.String("query", q)) - if err := chConn.Exec(errCtx, q); err != nil { - return fmt.Errorf("error while inserting into normalized table: %w", err) + case "DateTime64(6)", "Nullable(DateTime64(6))": + projection.WriteString(fmt.Sprintf( + "parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_data, '%s'),6) AS `%s`,", + colName, + dstColName, + )) + if enablePrimaryUpdate { + projectionUpdate.WriteString(fmt.Sprintf( + "parseDateTime64BestEffortOrNull(JSONExtractString(_peerdb_match_data, '%s'),6) AS `%s`,", + colName, + dstColName, + )) + } + default: + projection.WriteString(fmt.Sprintf("JSONExtract(_peerdb_data, '%s', '%s') AS `%s`,", colName, clickHouseType, dstColName)) + if enablePrimaryUpdate { + projectionUpdate.WriteString(fmt.Sprintf( + "JSONExtract(_peerdb_match_data, '%s', '%s') AS `%s`,", + colName, + clickHouseType, + dstColName, + )) } } + } - return nil - }) + // add _peerdb_sign as _peerdb_record_type / 2 + projection.WriteString(fmt.Sprintf("intDiv(_peerdb_record_type, 2) AS `%s`,", signColName)) + colSelector.WriteString(fmt.Sprintf("`%s`,", signColName)) + + // add _peerdb_timestamp as _peerdb_version + projection.WriteString(fmt.Sprintf("_peerdb_timestamp AS `%s`", versionColName)) + colSelector.WriteString(versionColName) + colSelector.WriteString(") ") + + selectQuery.WriteString(projection.String()) + selectQuery.WriteString(" FROM ") + selectQuery.WriteString(rawTbl) + selectQuery.WriteString(" WHERE _peerdb_batch_id > ") + selectQuery.WriteString(strconv.FormatInt(normBatchID, 10)) + selectQuery.WriteString(" AND _peerdb_batch_id <= ") + selectQuery.WriteString(strconv.FormatInt(req.SyncBatchID, 10)) + selectQuery.WriteString(" AND _peerdb_destination_table_name = '") + selectQuery.WriteString(tbl) + selectQuery.WriteString("'") + + if enablePrimaryUpdate { + // projectionUpdate generates delete on previous record, so _peerdb_record_type is filled in as 2 + projectionUpdate.WriteString(fmt.Sprintf("1 AS `%s`,", signColName)) + // decrement timestamp by 1 so delete is ordered before latest data, + // could be same if deletion records were only generated when ordering updated + projectionUpdate.WriteString(fmt.Sprintf("_peerdb_timestamp - 1 AS `%s`", versionColName)) + + selectQuery.WriteString("UNION ALL SELECT ") + selectQuery.WriteString(projectionUpdate.String()) + selectQuery.WriteString(" FROM ") + selectQuery.WriteString(rawTbl) + selectQuery.WriteString(" WHERE _peerdb_batch_id > ") + selectQuery.WriteString(strconv.FormatInt(normBatchID, 10)) + selectQuery.WriteString(" AND _peerdb_batch_id <= ") + selectQuery.WriteString(strconv.FormatInt(req.SyncBatchID, 10)) + selectQuery.WriteString(" AND _peerdb_destination_table_name = '") + selectQuery.WriteString(tbl) + selectQuery.WriteString("' AND _peerdb_record_type = 1") + } + + insertIntoSelectQuery := strings.Builder{} + insertIntoSelectQuery.WriteString("INSERT INTO `") + insertIntoSelectQuery.WriteString(tbl) + insertIntoSelectQuery.WriteString("` ") + insertIntoSelectQuery.WriteString(colSelector.String()) + insertIntoSelectQuery.WriteString(selectQuery.String()) + + select { + case queries <- insertIntoSelectQuery.String(): + case <-errCtx.Done(): + close(queries) + return nil, ctx.Err() + } + } + close(queries) + if err := group.Wait(); err != nil { + return nil, err } - group.Wait() if err := c.UpdateNormalizeBatchID(ctx, req.FlowJobName, req.SyncBatchID); err != nil { c.logger.Error("[clickhouse] error while updating normalize batch id", slog.Int64("BatchID", req.SyncBatchID), slog.Any("error", err))