[GLUTEN-7028][CH] Support write parquet files with bucket
lwz9103 committed Nov 26, 2024
1 parent 4dfdfd7 commit 0e2853a
Showing 11 changed files with 165 additions and 38 deletions.
@@ -114,7 +114,9 @@ case class FileDeltaColumnarWrite(
val guidPattern =
""".*-([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})(?:-c(\d+)\..*)?$""".r
val fileNamePattern =
guidPattern.replaceAllIn(writeFileName, m => writeFileName.replace(m.group(1), "{}"))
guidPattern.replaceAllIn(
writeFileName,
m => writeFileName.replace(m.group(1), FileNamePlaceHolder.ID))

logDebug(s"Native staging write path: $writePath and with pattern: $fileNamePattern")
val settings =
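For reference, a minimal self-contained sketch (not part of the commit) of how the GUID rewrite above behaves: the sample part-file name is made up, and the literal "{id}" stands in for FileNamePlaceHolder.ID.

object GuidToPlaceholderSketch {
  private val guidPattern =
    """.*-([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})(?:-c(\d+)\..*)?$""".r

  // Same approach as the code above: replace the UUID inside the generated file name
  // with a placeholder, so the native writer can substitute its own id later.
  def toPattern(writeFileName: String): String =
    guidPattern.replaceAllIn(writeFileName, m => writeFileName.replace(m.group(1), "{id}"))

  def main(args: Array[String]): Unit =
    // part-00000-01234567-89ab-cdef-0123-456789abcdef-c000.snappy.parquet
    //   -> part-00000-{id}-c000.snappy.parquet
    println(toPattern("part-00000-01234567-89ab-cdef-0123-456789abcdef-c000.snappy.parquet"))
}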
@@ -246,20 +246,11 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
}
}

def validateBucketSpec(): Option[String] = {
if (bucketSpec.nonEmpty) {
Some("Unsupported native write: bucket write is not supported.")
} else {
None
}
}

validateCompressionCodec()
.orElse(validateFileFormat())
.orElse(validateFieldMetadata())
.orElse(validateDateTypes())
.orElse(validateWriteFilesOptions())
.orElse(validateBucketSpec()) match {
.orElse(validateWriteFilesOptions()) match {
case Some(reason) => ValidationResult.failed(reason)
case _ => ValidationResult.succeeded
}
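A minimal sketch (not from the commit) of the orElse chain used above: the first validator that returns Some(reason) fails native-write validation, so removing validateBucketSpec means a bucket spec by itself no longer forces a fallback.

object ValidationChainSketch {
  // Chain validators lazily; the first failure reason wins, just like Option.orElse above.
  def firstFailure(checks: (() => Option[String])*): Option[String] =
    checks.foldLeft(Option.empty[String])((acc, check) => acc.orElse(check()))

  def main(args: Array[String]): Unit = {
    val result = firstFailure(
      () => None,                             // e.g. a codec check that passes
      () => Some("unsupported file format"),  // first failure decides the result
      () => None)
    println(result) // Some(unsupported file format)
  }
}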
@@ -326,11 +326,11 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
* This function is used to inject the staging write path before initializing the native plan. Only
* used in the pipeline model (Spark 3.5) for writing parquet or orc files.
*/
override def injectWriteFilesTempPath(path: String, fileName: String): Unit = {
override def injectWriteFilesTempPath(path: String, filePattern: String): Unit = {
val settings =
Map(
RuntimeSettings.TASK_WRITE_TMP_DIR.key -> path,
RuntimeSettings.TASK_WRITE_FILENAME.key -> fileName)
RuntimeSettings.TASK_WRITE_FILENAME_PATTERN.key -> filePattern)
NativeExpressionEvaluator.updateQueryRuntimeSettings(settings)
}
}
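For illustration, a minimal sketch (not from the commit) of the kind of settings map now pushed to the native engine in pipeline mode: a staging directory plus a file-name pattern instead of a fixed file name. The key strings below are stand-ins, not the real RuntimeSettings constants.

object InjectWriteTempPathSketch {
  // Keys are placeholders for TASK_WRITE_TMP_DIR.key and TASK_WRITE_FILENAME_PATTERN.key.
  def writeSettings(path: String, filePattern: String): Map[String, String] =
    Map(
      "task.write.tmp_dir" -> path,
      "task.write.filename_pattern" -> filePattern)

  def main(args: Array[String]): Unit =
    println(writeSettings(
      "/tmp/gluten-stage/attempt_000001",
      "part-00000-{id}_{bucket}.c000.snappy.parquet"))
}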
@@ -79,6 +79,7 @@ object CHRuleApi {
injector.injectPreTransform(_ => RewriteSubqueryBroadcast())
injector.injectPreTransform(c => FallbackBroadcastHashJoin.apply(c.session))
injector.injectPreTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.session))
injector.injectPreTransform(_ => WriteFilesWithBucketValue)

// Legacy: The legacy transform rule.
val validatorBuilder: GlutenConfig => Validator = conf =>
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension

import org.apache.spark.sql.catalyst.expressions.{Alias, BitwiseAnd, HiveHash, Literal, Pmod}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.{BucketingUtils, WriteFilesExec, WriterBucketSpec}

/**
* Wraps the write input with a bucket value column so that the native writer can derive the bucket
* file name. The native writer removes this column from the final output.
*/
object WriteFilesWithBucketValue extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
plan.transformDown {
case writeFiles: WriteFilesExec if writeFiles.bucketSpec.isDefined =>
val spec = getWriterBucketSpec(writeFiles)
val wrapBucketValue = ProjectExec(
writeFiles.child.output :+ Alias(spec.bucketIdExpression, "__bucket_value__")(),
writeFiles.child)
writeFiles.copy(child = wrapBucketValue)
}
}

private def getWriterBucketSpec(writeFilesExec: WriteFilesExec): WriterBucketSpec = {
val partitionColumns = writeFilesExec.partitionColumns
val outputColumns = writeFilesExec.child.output
val dataColumns = outputColumns.filterNot(partitionColumns.contains)
val bucketSpec = writeFilesExec.bucketSpec.get
val bucketColumns = bucketSpec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
if (
writeFilesExec.options.getOrElse(
BucketingUtils.optionForHiveCompatibleBucketWrite,
"false") == "true"
) {
val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
val bucketIdExpression = Pmod(hashId, Literal(bucketSpec.numBuckets))
// The bucket file name prefix follows the Hive, Presto and Trino convention, so Hive
// bucketed tables written by Spark can be read by those other SQL engines.
//
// Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
// Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
WriterBucketSpec(bucketIdExpression, fileNamePrefix)
} else {
// Spark bucketed table: use `HashPartitioning.partitionIdExpression` as the bucket id
// expression, so that the data distribution is guaranteed to be the same between shuffle
// and the bucketed data source, which lets us shuffle only one side when joining a
// bucketed table with a normal one.
val bucketIdExpression =
HashPartitioning(bucketColumns, bucketSpec.numBuckets).partitionIdExpression
WriterBucketSpec(bucketIdExpression, (_: Int) => "")
}
}
}
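A minimal, self-contained sketch (not part of the commit) of how the two bucket-id expressions chosen above can be evaluated; the key value and bucket count are made up.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, HiveHash, Literal, Pmod}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

object BucketIdExpressionSketch {
  def main(args: Array[String]): Unit = {
    val numBuckets = 4
    val key = Literal(42)

    // Hive-compatible path: pmod(hiveHash(key) & Int.MaxValue, numBuckets),
    // later paired with the f"$bucketId%05d_0_" file-name prefix.
    val hiveBucketId =
      Pmod(BitwiseAnd(HiveHash(Seq(key)), Literal(Int.MaxValue)), Literal(numBuckets))
        .eval(InternalRow.empty)

    // Spark path: the same expression shuffle uses, so bucketed joins can avoid a shuffle.
    val sparkBucketId =
      HashPartitioning(Seq(key), numBuckets).partitionIdExpression.eval(InternalRow.empty)

    println(s"hive bucket id = $hiveBucketId, spark bucket id = $sparkBucketId")
  }
}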
@@ -102,6 +102,12 @@ object CreateFileNameSpec {
}
}

// More details in local_engine::FileNameGenerator in NormalFileWriter.cpp
object FileNamePlaceHolder {
val ID = "{id}"
val BUCKET = "{bucket}"
}

/** [[HadoopMapReduceAdapter]] for [[HadoopMapReduceCommitProtocol]]. */
case class HadoopMapReduceAdapter(sparkCommitter: HadoopMapReduceCommitProtocol) {
private lazy val committer: OutputCommitter = {
@@ -132,12 +138,18 @@ case class HadoopMapReduceAdapter(sparkCommitter: HadoopMapReduceCommitProtocol)
GetFilename.invoke(sparkCommitter, taskContext, spec).asInstanceOf[String]
}

def getTaskAttemptTempPathAndFilename(
def getTaskAttemptTempPathAndFilePattern(
taskContext: TaskAttemptContext,
description: WriteJobDescription): (String, String) = {
val stageDir = newTaskAttemptTempPath(description.path)
val filename = getFilename(taskContext, CreateFileNameSpec(taskContext, description))
(stageDir, filename)
if (description.bucketSpec.isEmpty) {
val filename = getFilename(taskContext, CreateFileNameSpec(taskContext, description))
(stageDir, filename)
} else {
val filePart = getFilename(taskContext, FileNameSpec("", ""))
val fileSuffix = CreateFileNameSpec(taskContext, description).suffix
(stageDir, s"${filePart}_${FileNamePlaceHolder.BUCKET}$fileSuffix")
}
}
}
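A small sketch (not from the commit; the name part and suffix are made-up values) of the pattern string this method now produces when a bucket spec is present.

object BucketFilePatternSketch {
  // Mirrors the s"${filePart}_${FileNamePlaceHolder.BUCKET}$fileSuffix" assembly above;
  // the literal "{bucket}" stands in for FileNamePlaceHolder.BUCKET.
  def filePattern(filePart: String, fileSuffix: String): String =
    s"${filePart}_{bucket}$fileSuffix"

  def main(args: Array[String]): Unit =
    // -> part-00000-f2c0a9d4-attempt_{bucket}.c000.snappy.parquet (illustrative only)
    println(filePattern("part-00000-f2c0a9d4-attempt", ".c000.snappy.parquet"))
}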

@@ -234,10 +246,10 @@ case class HadoopMapReduceCommitProtocolWrite(
* initializing the native plan and collect native write files metrics for each backend.
*/
override def doSetupNativeTask(): Unit = {
val (writePath, writeFileName) =
adapter.getTaskAttemptTempPathAndFilename(taskAttemptContext, description)
logDebug(s"Native staging write path: $writePath and file name: $writeFileName")
BackendsApiManager.getIteratorApiInstance.injectWriteFilesTempPath(writePath, writeFileName)
val (writePath, writeFilePattern) =
adapter.getTaskAttemptTempPathAndFilePattern(taskAttemptContext, description)
logDebug(s"Native staging write path: $writePath and file name: $writeFilePattern")
BackendsApiManager.getIteratorApiInstance.injectWriteFilesTempPath(writePath, writeFilePattern)
}

def doCollectNativeResult(stats: Seq[InternalRow]): Option[WriteTaskResult] = {
@@ -553,7 +553,7 @@ class GlutenClickHouseNativeWriteTableSuite
// spark write does not support bucketed table
// https://issues.apache.org/jira/browse/SPARK-19256
val table_name = table_name_template.format(format)
writeAndCheckRead(origin_table, table_name, fields_.keys.toSeq, isSparkVersionLE("3.3")) {
writeAndCheckRead(origin_table, table_name, fields_.keys.toSeq) {
fields =>
spark
.table("origin_table")
@@ -658,7 +658,7 @@ class GlutenClickHouseNativeWriteTableSuite
nativeWrite {
format =>
val table_name = table_name_template.format(format)
writeAndCheckRead(origin_table, table_name, fields.keys.toSeq, isSparkVersionLE("3.3")) {
writeAndCheckRead(origin_table, table_name, fields.keys.toSeq) {
fields =>
spark
.table("origin_table")
@@ -762,7 +762,7 @@ class GlutenClickHouseNativeWriteTableSuite
format =>
val table_name = table_name_template.format(format)
spark.sql(s"drop table IF EXISTS $table_name")
withNativeWriteCheck(checkNative = isSparkVersionLE("3.3")) {
withNativeWriteCheck(checkNative = true) {
spark
.range(10000000)
.selectExpr("id", "cast('2020-01-01' as date) as p")
@@ -798,7 +798,7 @@ class GlutenClickHouseNativeWriteTableSuite
format =>
val table_name = table_name_template.format(format)
spark.sql(s"drop table IF EXISTS $table_name")
withNativeWriteCheck(checkNative = isSparkVersionLE("3.3")) {
withNativeWriteCheck(checkNative = true) {
spark
.range(30000)
.selectExpr("id", "cast(null as string) as p")
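The suite changes above replace the isSparkVersionLE("3.3") gating with unconditional native-write checks. For context, a hypothetical user-level bucketed write of the kind this commit enables natively (table and column names are made up; Gluten configuration is omitted).

import org.apache.spark.sql.SparkSession

object BucketedWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()

    spark.range(1000)
      .selectExpr("id", "cast(id % 10 as int) as key")
      .write
      .format("parquet")
      .bucketBy(4, "key")        // bucketed output, which previously forced a fallback
      .saveAsTable("bucketed_write_sketch")

    spark.stop()
  }
}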
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp
@@ -52,7 +52,7 @@ DB::ProcessorPtr make_sink(
const DB::Block & input_header,
const DB::Block & output_header,
const std::string & base_path,
const FileNameGenerator & generator,
FileNameGenerator & generator,
const std::string & format_hint,
const std::shared_ptr<WriteStats> & stats)
{
69 changes: 60 additions & 9 deletions cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
@@ -230,15 +230,40 @@ class WriteStats : public WriteStatsBase

struct FileNameGenerator
{
// Align with org.apache.spark.sql.execution.FileNamePlaceHolder
const std::vector<std::string> placeholders = {"{id}", "{bucket}"};
const bool pattern;
const std::string filename_or_pattern;
std::unordered_map<std::string, std::string> args;

std::string generate() const
std::string generate()
{
if (pattern)
return fmt::vformat(filename_or_pattern, fmt::make_format_args(toString(DB::UUIDHelpers::generateV4())));
{
args["{id}"] = toString(DB::UUIDHelpers::generateV4());
return pattern_format(filename_or_pattern);
}
return filename_or_pattern;
}

std::string pattern_format(std::string format_str)
{
for(const std::string& placeholder: placeholders)
{
auto it = args.find(placeholder);
if (it == args.end())
continue;

std::string replacement = it->second;
size_t pos = format_str.find(placeholder);
while (pos != std::string::npos)
{
format_str.replace(pos, placeholder.length(), replacement);
pos = format_str.find(placeholder, pos + placeholder.length());
}
}
return format_str;
}
};
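A Scala stand-in (not part of the commit) for the placeholder substitution that pattern_format performs above; the pattern and values are made up.

object PatternFormatSketch {
  // Same contract as the C++ loop: replace every occurrence of each known
  // placeholder that has a bound value, and leave unknown placeholders untouched.
  private val placeholders = Seq("{id}", "{bucket}")

  def patternFormat(pattern: String, args: Map[String, String]): String =
    placeholders.foldLeft(pattern) { (acc, p) =>
      args.get(p).map(acc.replace(p, _)).getOrElse(acc)
    }

  def main(args: Array[String]): Unit = {
    // -> part-00000-3f2c9a7e-0000-1111-2222-333344445555_00002.c000.snappy.parquet
    println(patternFormat(
      "part-00000-{id}_{bucket}.c000.snappy.parquet",
      Map("{id}" -> "3f2c9a7e-0000-1111-2222-333344445555", "{bucket}" -> "00002")))
  }
}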

class SubstraitFileSink final : public DB::SinkToStorage
@@ -287,7 +312,18 @@ class SubstraitFileSink final : public DB::SinkToStorage
delta_stats_.update(chunk);
if (!output_format_) [[unlikely]]
output_format_ = format_file_->createOutputFormat();
output_format_->output->write(materializeBlock(getHeader().cloneWithColumns(chunk.detachColumns())));

const DB::Block & input_header = getHeader();
if (input_header.getNames().back() == "__bucket_value__")
{
chunk.erase(input_header.columns() - 1);
const DB::ColumnsWithTypeAndName & cols = input_header.getColumnsWithTypeAndName();
DB::ColumnsWithTypeAndName with_bucket_cols(cols.begin(), cols.end() - 1);
DB::Block without_bucket_header = DB::Block(with_bucket_cols);
output_format_->output->write(materializeBlock(without_bucket_header.cloneWithColumns(chunk.detachColumns())));
}
else
output_format_->output->write(materializeBlock(input_header.cloneWithColumns(chunk.detachColumns())));
}
void onFinish() override
{
@@ -309,7 +345,7 @@ class SparkPartitionedBaseSink : public DB::PartitionedSink

public:
/// visible for UTs
static DB::ASTPtr make_partition_expression(const DB::Names & partition_columns)
static DB::ASTPtr make_partition_expression(const DB::Names & partition_columns, const DB::Block & input_header)
{
/// Parse the following expression into ASTs
/// concat('/col_name=', toString(col_name))
@@ -328,6 +364,12 @@ class SparkPartitionedBaseSink : public DB::PartitionedSink
DB::ASTs if_null_args{
makeASTFunction("toString", DB::ASTs{column_ast}), std::make_shared<DB::ASTLiteral>(DEFAULT_PARTITION_NAME)};
arguments.emplace_back(makeASTFunction("ifNull", std::move(if_null_args)));

if (input_header.getNames().back() == "__bucket_value__")
{
DB::ASTs args {std::make_shared<DB::ASTLiteral>("%05d"), std::make_shared<DB::ASTIdentifier>("__bucket_value__")};
arguments.emplace_back(DB::makeASTFunction("printf", std::move(args)));
}
}
return DB::makeASTFunction("concat", std::move(arguments));
}
@@ -343,7 +385,7 @@ class SparkPartitionedBaseSink : public DB::PartitionedSink
const DB::Names & partition_by,
const DB::Block & input_header,
const std::shared_ptr<WriteStatsBase> & stats)
: PartitionedSink(make_partition_expression(partition_by), context, input_header)
: PartitionedSink(make_partition_expression(partition_by, input_header), context, input_header)
, context_(context)
, stats_(stats)
, empty_delta_stats_(DeltaStats::create(input_header, partition_by))
@@ -354,7 +396,8 @@ class SubstraitPartitionedFileSink final : public SparkPartitionedBaseSink
class SubstraitPartitionedFileSink final : public SparkPartitionedBaseSink
{
const std::string base_path_;
const FileNameGenerator generator_;
FileNameGenerator generator_;
const DB::Block input_header_;
const DB::Block sample_block_;
const std::string format_hint_;

@@ -365,25 +408,33 @@ class SubstraitPartitionedFileSink final : public SparkPartitionedBaseSink
const DB::Block & input_header,
const DB::Block & sample_block,
const std::string & base_path,
const FileNameGenerator & generator,
FileNameGenerator & generator,
const std::string & format_hint,
const std::shared_ptr<WriteStatsBase> & stats)
: SparkPartitionedBaseSink(context, partition_by, input_header, stats)
, base_path_(base_path)
, generator_(generator)
, sample_block_(sample_block)
, input_header_(input_header)
, format_hint_(format_hint)
{
}

DB::SinkPtr createSinkForPartition(const String & partition_id) override
{
assert(stats_);
std::string real_partition_id = partition_id;
if (input_header_.getNames().back() == "__bucket_value__")
{
std::string bucket_val = partition_id.substr(partition_id.length() - 5, 5);
real_partition_id = partition_id.substr(0, partition_id.length() - 5);
generator_.args["{bucket}"] = bucket_val;
}
std::string filename = generator_.generate();
const auto partition_path = fmt::format("{}/{}", partition_id, filename);
const auto partition_path = fmt::format("{}/{}", real_partition_id, filename);
validatePartitionKey(partition_path, true);
return std::make_shared<SubstraitFileSink>(
context_, base_path_, partition_id, filename, format_hint_, sample_block_, stats_, empty_delta_stats_);
context_, base_path_, real_partition_id, filename, format_hint_, sample_block_, stats_, empty_delta_stats_);
}
String getName() const override { return "SubstraitPartitionedFileSink"; }
};
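A small sketch (not from the commit; values are made up) of the round trip implemented above: the partition expression appends printf('%05d', __bucket_value__) to the partition key, and createSinkForPartition strips those five characters back off to recover the real partition id and the {bucket} value.

object PartitionBucketKeySketch {
  // What the partition expression produces when bucketing is on.
  def withBucket(partitionKey: String, bucketId: Int): String =
    f"$partitionKey%s$bucketId%05d"

  // What createSinkForPartition does with it: split off the trailing 5-digit bucket id.
  def splitBucket(partitionId: String): (String, String) =
    (partitionId.dropRight(5), partitionId.takeRight(5))

  def main(args: Array[String]): Unit = {
    val id = withBucket("/p=2020-01-01", 2) // "/p=2020-01-0100002"
    println(splitBucket(id))                // (/p=2020-01-01,00002)
  }
}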
1 change: 0 additions & 1 deletion cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp
@@ -61,7 +61,6 @@ Block OutputFormatFile::createHeaderWithPreferredSchema(const Block & header)
ColumnWithTypeAndName column(preferred_column.type->createColumn(), preferred_column.type, preferred_column.name);
columns.emplace_back(std::move(column));
}
assert(preferred_schema.columns() == index);
return {std::move(columns)};
}

6 changes: 3 additions & 3 deletions cpp-ch/local-engine/tests/gtest_write_pipeline.cpp
@@ -206,14 +206,14 @@ TEST(WritePipeline, SubstraitPartitionedFileSink)
TEST(WritePipeline, ComputePartitionedExpression)
{
const auto context = DB::Context::createCopy(QueryContext::globalContext());

auto partition_by = SubstraitPartitionedFileSink::make_partition_expression({"s_nationkey", "name"});

Block sample_block{{STRING(), "name"}, {UINT(), "s_nationkey"}};
auto partition_by = SubstraitPartitionedFileSink::make_partition_expression({"s_nationkey", "name"}, sample_block);
// auto partition_by = printColumn("s_nationkey");

ASTs arguments(1, partition_by);
ASTPtr partition_by_string = makeASTFunction("toString", std::move(arguments));

Block sample_block{{STRING(), "name"}, {UINT(), "s_nationkey"}};

auto syntax_result = TreeRewriter(context).analyze(partition_by_string, sample_block.getNamesAndTypesList());
auto partition_by_expr = ExpressionAnalyzer(partition_by_string, syntax_result, context).getActions(false);