From 0952dea254df6fc4f1f01a9e0e8ac50f97285233 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 18 Jun 2024 14:24:16 +0800 Subject: [PATCH 1/6] Fallback non-UTC TimeZoneAwareExpression with zoneId [databricks] (#10996) * Fallback non-UTC TimeZoneAwareExpression with zoneId instead of timeZone config Signed-off-by: Haoyang Li * clean up Signed-off-by: Haoyang Li --------- Signed-off-by: Haoyang Li --- .../main/scala/com/nvidia/spark/rapids/RapidsMeta.scala | 8 ++++---- .../spark/sql/rapids/utils/RapidsTestSettings.scala | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index a876ea6c9e0..984892cd787 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1123,7 +1123,7 @@ abstract class BaseExprMeta[INPUT <: Expression]( if (!needTimeZoneCheck) return // Level 2 check - if (!isTimeZoneSupported) return checkUTCTimezone(this) + if (!isTimeZoneSupported) return checkUTCTimezone(this, getZoneId()) // Level 3 check val zoneId = getZoneId() @@ -1203,8 +1203,8 @@ abstract class BaseExprMeta[INPUT <: Expression]( * * @param meta to check whether it's UTC */ - def checkUTCTimezone(meta: RapidsMeta[_, _, _]): Unit = { - if (!GpuOverrides.isUTCTimezone()) { + def checkUTCTimezone(meta: RapidsMeta[_, _, _], zoneId: ZoneId): Unit = { + if (!GpuOverrides.isUTCTimezone(zoneId)) { meta.willNotWorkOnGpu( TimeZoneDB.nonUTCTimezoneNotSupportedStr(meta.wrapped.getClass.toString)) } diff --git a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala index 4cf155041d9..63649376829 100644 --- a/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala +++ b/tests/src/test/spark330/scala/org/apache/spark/sql/rapids/utils/RapidsTestSettings.scala @@ -48,11 +48,9 @@ class RapidsTestSettings extends BackendTestSettings { .exclude("from_json - input=empty array, schema=struct, output=single row with null", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849")) .exclude("from_json - input=empty object, schema=struct, output=single row with null", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849")) .exclude("SPARK-20549: from_json bad UTF-8", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849")) - .exclude("from_json with timestamp", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849")) .exclude("to_json - array", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849")) .exclude("to_json - array with single empty row", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849")) .exclude("to_json - empty array", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849")) - .exclude("to_json with timestamp", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849")) .exclude("SPARK-21513: to_json support map[string, struct] to json", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849")) .exclude("SPARK-21513: to_json support map[struct, struct] to json", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849")) .exclude("SPARK-21513: to_json support map[string, integer] to json", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10849")) From 7bac3a6439c10efb1961d3c4ba028128d9dca249 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Wed, 19 Jun 2024 09:44:48 +0800 Subject: [PATCH 2/6] [FEA] Introduce low shuffle merge. (#10979) * feat: Introduce low shuffle merge. Signed-off-by: liurenjie1024 * fix * Test databricks parallel * Test more databricks parallel * Fix comments * Config && scala 2.13 * Revert * Fix comments * scala 2.13 * Revert unnecessary changes * Revert "Revert unnecessary changes" This reverts commit 9fa4cf268cc3fce4d2732e04cb33eb53e4859c99. * restore change --------- Signed-off-by: liurenjie1024 --- aggregator/pom.xml | 4 + .../GpuDeltaParquetFileFormatUtils.scala | 160 +++ .../nvidia/spark/rapids/delta/deltaUDFs.scala | 83 +- .../delta/delta24x/Delta24xProvider.scala | 5 +- .../GpuDelta24xParquetFileFormat.scala | 61 +- .../delta/delta24x/MergeIntoCommandMeta.scala | 58 +- .../delta24x/GpuLowShuffleMergeCommand.scala | 1084 +++++++++++++++++ .../rapids/GpuLowShuffleMergeCommand.scala | 1083 ++++++++++++++++ .../delta/GpuDeltaParquetFileFormat.scala | 63 +- .../shims/MergeIntoCommandMetaShim.scala | 101 +- .../advanced_configs.md | 6 + .../delta_lake_low_shuffle_merge_test.py | 165 +++ .../main/python/delta_lake_merge_common.py | 155 +++ .../src/main/python/delta_lake_merge_test.py | 127 +- pom.xml | 10 + scala2.13/aggregator/pom.xml | 4 + scala2.13/pom.xml | 10 + scala2.13/sql-plugin/pom.xml | 4 + sql-plugin/pom.xml | 4 + .../com/nvidia/spark/rapids/RapidsConf.scala | 28 + 20 files changed, 3061 insertions(+), 154 deletions(-) create mode 100644 delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormatUtils.scala create mode 100644 delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuLowShuffleMergeCommand.scala create mode 100644 delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuLowShuffleMergeCommand.scala create mode 100644 integration_tests/src/main/python/delta_lake_low_shuffle_merge_test.py create mode 100644 integration_tests/src/main/python/delta_lake_merge_common.py diff --git a/aggregator/pom.xml b/aggregator/pom.xml index 22bfe11105e..8cf881419c9 100644 --- a/aggregator/pom.xml +++ b/aggregator/pom.xml @@ -94,6 +94,10 @@ com.google.flatbuffers ${rapids.shade.package}.com.google.flatbuffers + + org.roaringbitmap + ${rapids.shade.package}.org.roaringbitmap + diff --git a/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormatUtils.scala b/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormatUtils.scala new file mode 100644 index 00000000000..101a82da830 --- /dev/null +++ b/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormatUtils.scala @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta + +import ai.rapids.cudf.{ColumnVector => CudfColumnVector, Scalar, Table} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.GpuColumnVector +import org.roaringbitmap.longlong.{PeekableLongIterator, Roaring64Bitmap} + +import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} + + +object GpuDeltaParquetFileFormatUtils { + /** + * Row number of the row in the file. When used with [[FILE_PATH_COL]] together, it can be used + * as unique id of a row in file. Currently to correctly calculate this, the caller needs to + * set both [[isSplitable]] to false, and [[RapidsConf.PARQUET_READER_TYPE]] to "PERFILE". + */ + val METADATA_ROW_IDX_COL: String = "__metadata_row_index" + val METADATA_ROW_IDX_FIELD: StructField = StructField(METADATA_ROW_IDX_COL, LongType, + nullable = false) + + val METADATA_ROW_DEL_COL: String = "__metadata_row_del" + val METADATA_ROW_DEL_FIELD: StructField = StructField(METADATA_ROW_DEL_COL, BooleanType, + nullable = false) + + + /** + * File path of the file that the row came from. + */ + val FILE_PATH_COL: String = "_metadata_file_path" + val FILE_PATH_FIELD: StructField = StructField(FILE_PATH_COL, StringType, nullable = false) + + /** + * Add a metadata column to the iterator. Currently only support [[METADATA_ROW_IDX_COL]]. + */ + def addMetadataColumnToIterator( + schema: StructType, + delVector: Option[Roaring64Bitmap], + input: Iterator[ColumnarBatch], + maxBatchSize: Int): Iterator[ColumnarBatch] = { + val metadataRowIndexCol = schema.fieldNames.indexOf(METADATA_ROW_IDX_COL) + val delRowIdx = schema.fieldNames.indexOf(METADATA_ROW_DEL_COL) + if (metadataRowIndexCol == -1 && delRowIdx == -1) { + return input + } + var rowIndex = 0L + input.map { batch => + withResource(batch) { _ => + val rowIdxCol = if (metadataRowIndexCol == -1) { + None + } else { + Some(metadataRowIndexCol) + } + + val delRowIdx2 = if (delRowIdx == -1) { + None + } else { + Some(delRowIdx) + } + val newBatch = addMetadataColumns(rowIdxCol, delRowIdx2, delVector,maxBatchSize, + rowIndex, batch) + rowIndex += batch.numRows() + newBatch + } + } + } + + private def addMetadataColumns( + rowIdxPos: Option[Int], + delRowIdx: Option[Int], + delVec: Option[Roaring64Bitmap], + maxBatchSize: Int, + rowIdxStart: Long, + batch: ColumnarBatch): ColumnarBatch = { + val rowIdxCol = rowIdxPos.map { _ => + withResource(Scalar.fromLong(rowIdxStart)) { start => + GpuColumnVector.from(CudfColumnVector.sequence(start, batch.numRows()), + METADATA_ROW_IDX_FIELD.dataType) + } + } + + closeOnExcept(rowIdxCol) { rowIdxCol => + + val delVecCol = delVec.map { delVec => + withResource(Scalar.fromBool(false)) { s => + withResource(CudfColumnVector.fromScalar(s, batch.numRows())) { c => + var table = new Table(c) + val posIter = new RoaringBitmapIterator( + delVec.getLongIteratorFrom(rowIdxStart), + rowIdxStart, + rowIdxStart + batch.numRows(), + ).grouped(Math.min(maxBatchSize, batch.numRows())) + + for (posChunk <- posIter) { + withResource(CudfColumnVector.fromLongs(posChunk: _*)) { poses => + withResource(Scalar.fromBool(true)) { s => + table = withResource(table) { _ => + Table.scatter(Array(s), poses, table) + } + } + } + } + + withResource(table) { _ => + GpuColumnVector.from(table.getColumn(0).incRefCount(), + METADATA_ROW_DEL_FIELD.dataType) + } + } + } + } + + closeOnExcept(delVecCol) { delVecCol => + // Replace row_idx column + val columns = new Array[ColumnVector](batch.numCols()) + for (i <- 0 until batch.numCols()) { + if (rowIdxPos.contains(i)) { + columns(i) = rowIdxCol.get + } else if (delRowIdx.contains(i)) { + columns(i) = delVecCol.get + } else { + columns(i) = batch.column(i) match { + case gpuCol: GpuColumnVector => gpuCol.incRefCount() + case col => col + } + } + } + + new ColumnarBatch(columns, batch.numRows()) + } + } + } +} + +class RoaringBitmapIterator(val inner: PeekableLongIterator, val start: Long, val end: Long) + extends Iterator[Long] { + + override def hasNext: Boolean = { + inner.hasNext && inner.peekNext() < end + } + + override def next(): Long = { + inner.next() - start + } +} diff --git a/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/deltaUDFs.scala b/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/deltaUDFs.scala index 6b2c63407d7..9893545a4ad 100644 --- a/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/deltaUDFs.scala +++ b/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/deltaUDFs.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,12 +16,19 @@ package com.nvidia.spark.rapids.delta +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + import ai.rapids.cudf.{ColumnVector, Scalar, Table} import ai.rapids.cudf.Table.DuplicateKeepOption import com.nvidia.spark.RapidsUDF import com.nvidia.spark.rapids.Arm.withResource +import org.roaringbitmap.longlong.Roaring64Bitmap +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.types.{BinaryType, DataType, SQLUserDefinedType, UserDefinedType} import org.apache.spark.util.AccumulatorV2 class GpuDeltaRecordTouchedFileNameUDF(accum: AccumulatorV2[String, java.util.Set[String]]) @@ -73,3 +80,77 @@ class GpuDeltaMetricUpdateUDF(metric: SQLMetric) } } } + +class GpuDeltaNoopUDF extends Function1[Boolean, Boolean] with RapidsUDF with Serializable { + override def apply(v1: Boolean): Boolean = v1 + + override def evaluateColumnar(numRows: Int, args: ColumnVector*): ColumnVector = { + require(args.length == 1) + args(0).incRefCount() + } +} + +@SQLUserDefinedType(udt = classOf[RoaringBitmapUDT]) +case class RoaringBitmapWrapper(inner: Roaring64Bitmap) { + def serializeToBytes(): Array[Byte] = { + withResource(new ByteArrayOutputStream()) { bout => + withResource(new DataOutputStream(bout)) { dao => + inner.serialize(dao) + } + bout.toByteArray + } + } +} + +object RoaringBitmapWrapper { + def deserializeFromBytes(bytes: Array[Byte]): RoaringBitmapWrapper = { + withResource(new ByteArrayInputStream(bytes)) { bin => + withResource(new DataInputStream(bin)) { din => + val ret = RoaringBitmapWrapper(new Roaring64Bitmap) + ret.inner.deserialize(din) + ret + } + } + } +} + +class RoaringBitmapUDT extends UserDefinedType[RoaringBitmapWrapper] { + + override def sqlType: DataType = BinaryType + + override def serialize(obj: RoaringBitmapWrapper): Any = { + obj.serializeToBytes() + } + + override def deserialize(datum: Any): RoaringBitmapWrapper = { + datum match { + case b: Array[Byte] => RoaringBitmapWrapper.deserializeFromBytes(b) + case t => throw new IllegalArgumentException(s"t: ${t.getClass}") + } + } + + override def userClass: Class[RoaringBitmapWrapper] = classOf[RoaringBitmapWrapper] + + override def typeName: String = "RoaringBitmap" +} + +object RoaringBitmapUDAF extends Aggregator[Long, RoaringBitmapWrapper, RoaringBitmapWrapper] { + override def zero: RoaringBitmapWrapper = RoaringBitmapWrapper(new Roaring64Bitmap()) + + override def reduce(b: RoaringBitmapWrapper, a: Long): RoaringBitmapWrapper = { + b.inner.addLong(a) + b + } + + override def merge(b1: RoaringBitmapWrapper, b2: RoaringBitmapWrapper): RoaringBitmapWrapper = { + val ret = b1.inner.clone() + ret.or(b2.inner) + RoaringBitmapWrapper(ret) + } + + override def finish(reduction: RoaringBitmapWrapper): RoaringBitmapWrapper = reduction + + override def bufferEncoder: Encoder[RoaringBitmapWrapper] = ExpressionEncoder() + + override def outputEncoder: Encoder[RoaringBitmapWrapper] = ExpressionEncoder() +} diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala index d3f952b856c..f90f31300e5 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -74,7 +74,8 @@ object Delta24xProvider extends DeltaIOProvider { override def getReadFileFormat(format: FileFormat): FileFormat = { val cpuFormat = format.asInstanceOf[DeltaParquetFileFormat] - GpuDelta24xParquetFileFormat(cpuFormat.metadata, cpuFormat.isSplittable) + GpuDelta24xParquetFileFormat(cpuFormat.metadata, cpuFormat.isSplittable, + cpuFormat.disablePushDowns, cpuFormat.broadcastDvMap) } override def convertToGpu( diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDelta24xParquetFileFormat.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDelta24xParquetFileFormat.scala index 709df7e9416..ef579d78e6f 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDelta24xParquetFileFormat.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDelta24xParquetFileFormat.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,18 +16,32 @@ package com.nvidia.spark.rapids.delta.delta24x -import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormat +import java.net.URI + +import com.nvidia.spark.rapids.{GpuMetric, RapidsConf} +import com.nvidia.spark.rapids.delta.{GpuDeltaParquetFileFormat, RoaringBitmapWrapper} +import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormatUtils.addMetadataColumnToIterator +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.delta.{DeltaColumnMappingMode, IdMapping} +import org.apache.spark.sql.delta.DeltaParquetFileFormat.DeletionVectorDescriptorWithFilterType import org.apache.spark.sql.delta.actions.Metadata +import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch case class GpuDelta24xParquetFileFormat( metadata: Metadata, - isSplittable: Boolean) extends GpuDeltaParquetFileFormat { + isSplittable: Boolean, + disablePushDown: Boolean, + broadcastDvMap: Option[Broadcast[Map[URI, DeletionVectorDescriptorWithFilterType]]]) + extends GpuDeltaParquetFileFormat { override val columnMappingMode: DeltaColumnMappingMode = metadata.columnMappingMode override val referenceSchema: StructType = metadata.schema @@ -46,6 +60,47 @@ case class GpuDelta24xParquetFileFormat( options: Map[String, String], path: Path): Boolean = isSplittable + override def buildReaderWithPartitionValuesAndMetrics( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration, + metrics: Map[String, GpuMetric], + alluxioPathReplacementMap: Option[Map[String, String]]) + : PartitionedFile => Iterator[InternalRow] = { + + + val dataReader = super.buildReaderWithPartitionValuesAndMetrics( + sparkSession, + dataSchema, + partitionSchema, + requiredSchema, + if (disablePushDown) Seq.empty else filters, + options, + hadoopConf, + metrics, + alluxioPathReplacementMap) + + val delVecs = broadcastDvMap + val maxDelVecScatterBatchSize = RapidsConf + .DELTA_LOW_SHUFFLE_MERGE_SCATTER_DEL_VECTOR_BATCH_SIZE + .get(sparkSession.sessionState.conf) + + (file: PartitionedFile) => { + val input = dataReader(file) + val dv = delVecs.flatMap(_.value.get(new URI(file.filePath.toString()))) + .map(dv => RoaringBitmapWrapper.deserializeFromBytes(dv.descriptor.inlineData).inner) + addMetadataColumnToIterator(prepareSchema(requiredSchema), + dv, + input.asInstanceOf[Iterator[ColumnarBatch]], + maxDelVecScatterBatchSize) + .asInstanceOf[Iterator[InternalRow]] + } + } + /** * We sometimes need to replace FileFormat within LogicalPlans, so we have to override * `equals` to ensure file format changes are captured diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala index 4b4dfb624b5..8ce813ef011 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,14 @@ package com.nvidia.spark.rapids.delta.delta24x -import com.nvidia.spark.rapids.{DataFromReplacementRule, RapidsConf, RapidsMeta, RunnableCommandMeta} +import com.nvidia.spark.rapids.{DataFromReplacementRule, RapidsConf, RapidsMeta, RapidsReaderType, RunnableCommandMeta} import com.nvidia.spark.rapids.delta.RapidsDeltaUtils +import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.delta.commands.MergeIntoCommand import org.apache.spark.sql.delta.rapids.GpuDeltaLog -import org.apache.spark.sql.delta.rapids.delta24x.GpuMergeIntoCommand +import org.apache.spark.sql.delta.rapids.delta24x.{GpuLowShuffleMergeCommand, GpuMergeIntoCommand} import org.apache.spark.sql.execution.command.RunnableCommand class MergeIntoCommandMeta( @@ -30,12 +31,12 @@ class MergeIntoCommandMeta( conf: RapidsConf, parent: Option[RapidsMeta[_, _, _]], rule: DataFromReplacementRule) - extends RunnableCommandMeta[MergeIntoCommand](mergeCmd, conf, parent, rule) { + extends RunnableCommandMeta[MergeIntoCommand](mergeCmd, conf, parent, rule) with Logging { override def tagSelfForGpu(): Unit = { if (!conf.isDeltaWriteEnabled) { willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " + - s"${RapidsConf.ENABLE_DELTA_WRITE} to true") + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } if (mergeCmd.notMatchedBySourceClauses.nonEmpty) { // https://github.com/NVIDIA/spark-rapids/issues/8415 @@ -48,14 +49,43 @@ class MergeIntoCommandMeta( } override def convertToGpu(): RunnableCommand = { - GpuMergeIntoCommand( - mergeCmd.source, - mergeCmd.target, - new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), - mergeCmd.condition, - mergeCmd.matchedClauses, - mergeCmd.notMatchedClauses, - mergeCmd.notMatchedBySourceClauses, - mergeCmd.migratedSchema)(conf) + // TODO: Currently we only support low shuffler merge only when parquet per file read is enabled + // due to the limitation of implementing row index metadata column. + if (conf.isDeltaLowShuffleMergeEnabled) { + if (conf.isParquetPerFileReadEnabled) { + GpuLowShuffleMergeCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } else { + logWarning(s"""Low shuffle merge disabled since ${RapidsConf.PARQUET_READER_TYPE} is + not set to ${RapidsReaderType.PERFILE}. Falling back to classic merge.""") + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } + } else { + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } } + } diff --git a/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuLowShuffleMergeCommand.scala b/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuLowShuffleMergeCommand.scala new file mode 100644 index 00000000000..9c27d28ebd3 --- /dev/null +++ b/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuLowShuffleMergeCommand.scala @@ -0,0 +1,1084 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * This file was derived from MergeIntoCommand.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta24x + +import java.net.URI +import java.util.concurrent.TimeUnit + +import scala.collection.mutable + +import com.nvidia.spark.rapids.{GpuOverrides, RapidsConf, SparkPlanMeta} +import com.nvidia.spark.rapids.RapidsConf.DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD +import com.nvidia.spark.rapids.delta._ +import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormatUtils._ +import com.nvidia.spark.rapids.shims.FileSourceScanExecMeta +import org.roaringbitmap.longlong.Roaring64Bitmap + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, CaseWhen, Expression, Literal, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.plans.logical.{DeltaMergeAction, DeltaMergeIntoClause, DeltaMergeIntoMatchedClause, DeltaMergeIntoMatchedDeleteClause, DeltaMergeIntoMatchedUpdateClause, DeltaMergeIntoNotMatchedBySourceClause, DeltaMergeIntoNotMatchedBySourceDeleteClause, DeltaMergeIntoNotMatchedBySourceUpdateClause, DeltaMergeIntoNotMatchedClause, DeltaMergeIntoNotMatchedInsertClause, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOperations, DeltaParquetFileFormat, DeltaTableUtils, DeltaUDF, NoMapping, OptimisticTransaction, RowIndexFilterType} +import org.apache.spark.sql.delta.DeltaOperations.MergePredicate +import org.apache.spark.sql.delta.DeltaParquetFileFormat.DeletionVectorDescriptorWithFilterType +import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, DeletionVectorDescriptor, FileAction} +import org.apache.spark.sql.delta.commands.DeltaCommand +import org.apache.spark.sql.delta.rapids.{GpuDeltaLog, GpuOptimisticTransactionBase} +import org.apache.spark.sql.delta.rapids.delta24x.MergeExecutor.{toDeletionVector, totalBytesAndDistinctPartitionValues, INCR_METRICS_COL, INCR_METRICS_FIELD, ROW_DROPPED_COL, ROW_DROPPED_FIELD, SOURCE_ROW_PRESENT_COL, SOURCE_ROW_PRESENT_FIELD, TARGET_ROW_PRESENT_COL, TARGET_ROW_PRESENT_FIELD} +import org.apache.spark.sql.delta.schema.ImplicitMetadataOperation +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.util.{AnalysisHelper, DeltaFileOperations} +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} + +/** + * GPU version of Delta Lake's low shuffle merge implementation. + * + * Performs a merge of a source query/table into a Delta table. + * + * Issues an error message when the ON search_condition of the MERGE statement can match + * a single row from the target table with multiple rows of the source table-reference. + * Different from the original implementation, it optimized writing touched unmodified target files. + * + * Algorithm: + * + * Phase 1: Find the input files in target that are touched by the rows that satisfy + * the condition and verify that no two source rows match with the same target row. + * This is implemented as an inner-join using the given condition. See [[findTouchedFiles]] + * for more details. + * + * Phase 2: Read the touched files again and write new files with updated and/or inserted rows + * without copying unmodified rows. + * + * Phase 3: Read the touched files again and write new files with unmodified rows in target table, + * trying to keep its original order and avoid shuffle as much as possible. + * + * Phase 4: Use the Delta protocol to atomically remove the touched files and add the new files. + * + * @param source Source data to merge from + * @param target Target table to merge into + * @param gpuDeltaLog Delta log to use + * @param condition Condition for a source row to match with a target row + * @param matchedClauses All info related to matched clauses. + * @param notMatchedClauses All info related to not matched clause. + * @param migratedSchema The final schema of the target - may be changed by schema evolution. + */ +case class GpuLowShuffleMergeCommand( + @transient source: LogicalPlan, + @transient target: LogicalPlan, + @transient gpuDeltaLog: GpuDeltaLog, + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], + notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause], + migratedSchema: Option[StructType])( + @transient val rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaCommand with PredicateHelper with AnalysisHelper with ImplicitMetadataOperation { + + import SQLMetrics._ + + override val otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) + override val canOverwriteSchema: Boolean = false + + override val output: Seq[Attribute] = Seq( + AttributeReference("num_affected_rows", LongType)(), + AttributeReference("num_updated_rows", LongType)(), + AttributeReference("num_deleted_rows", LongType)(), + AttributeReference("num_inserted_rows", LongType)()) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + @transient private[delta] lazy val targetDeltaLog: DeltaLog = gpuDeltaLog.deltaLog + + override lazy val metrics = Map[String, SQLMetric]( + "numSourceRows" -> createMetric(sc, "number of source rows"), + "numSourceRowsInSecondScan" -> + createMetric(sc, "number of source rows (during repeated scan)"), + "numTargetRowsCopied" -> createMetric(sc, "number of target rows rewritten unmodified"), + "numTargetRowsInserted" -> createMetric(sc, "number of inserted rows"), + "numTargetRowsUpdated" -> createMetric(sc, "number of updated rows"), + "numTargetRowsDeleted" -> createMetric(sc, "number of deleted rows"), + "numTargetRowsMatchedUpdated" -> createMetric(sc, "number of target rows updated when matched"), + "numTargetRowsMatchedDeleted" -> createMetric(sc, "number of target rows deleted when matched"), + "numTargetRowsNotMatchedBySourceUpdated" -> createMetric(sc, + "number of target rows updated when not matched by source"), + "numTargetRowsNotMatchedBySourceDeleted" -> createMetric(sc, + "number of target rows deleted when not matched by source"), + "numTargetFilesBeforeSkipping" -> createMetric(sc, "number of target files before skipping"), + "numTargetFilesAfterSkipping" -> createMetric(sc, "number of target files after skipping"), + "numTargetFilesRemoved" -> createMetric(sc, "number of files removed to target"), + "numTargetFilesAdded" -> createMetric(sc, "number of files added to target"), + "numTargetChangeFilesAdded" -> + createMetric(sc, "number of change data capture files generated"), + "numTargetChangeFileBytes" -> + createMetric(sc, "total size of change data capture files generated"), + "numTargetBytesBeforeSkipping" -> createMetric(sc, "number of target bytes before skipping"), + "numTargetBytesAfterSkipping" -> createMetric(sc, "number of target bytes after skipping"), + "numTargetBytesRemoved" -> createMetric(sc, "number of target bytes removed"), + "numTargetBytesAdded" -> createMetric(sc, "number of target bytes added"), + "numTargetPartitionsAfterSkipping" -> + createMetric(sc, "number of target partitions after skipping"), + "numTargetPartitionsRemovedFrom" -> + createMetric(sc, "number of target partitions from which files were removed"), + "numTargetPartitionsAddedTo" -> + createMetric(sc, "number of target partitions to which files were added"), + "executionTimeMs" -> + createMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createMetric(sc, "time taken to rewrite the matched files")) + + /** Whether this merge statement has only a single insert (NOT MATCHED) clause. */ + protected def isSingleInsertOnly: Boolean = matchedClauses.isEmpty && + notMatchedClauses.length == 1 + + override def run(spark: SparkSession): Seq[Row] = { + recordDeltaOperation(targetDeltaLog, "delta.dml.lowshufflemerge") { + val startTime = System.nanoTime() + val result = gpuDeltaLog.withNewTransaction { deltaTxn => + if (target.schema.size != deltaTxn.metadata.schema.size) { + throw DeltaErrors.schemaChangedSinceAnalysis( + atAnalysis = target.schema, latestSchema = deltaTxn.metadata.schema) + } + + if (canMergeSchema) { + updateMetadata( + spark, deltaTxn, migratedSchema.getOrElse(target.schema), + deltaTxn.metadata.partitionColumns, deltaTxn.metadata.configuration, + isOverwriteMode = false, rearrangeOnly = false) + } + + + val (executor, fallback) = { + val context = MergeExecutorContext(this, spark, deltaTxn, rapidsConf) + if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) { + (new InsertOnlyMergeExecutor(context), false) + } else { + val executor = new LowShuffleMergeExecutor(context) + (executor, executor.shouldFallback()) + } + } + + if (fallback) { + None + } else { + Some(runLowShuffleMerge(spark, startTime, deltaTxn, executor)) + } + } + + result match { + case Some(row) => row + case None => + // We should rollback to normal gpu + new GpuMergeIntoCommand(source, target, gpuDeltaLog, condition, matchedClauses, + notMatchedClauses, notMatchedBySourceClauses, migratedSchema)(rapidsConf) + .run(spark) + } + } + } + + + private def runLowShuffleMerge( + spark: SparkSession, + startTime: Long, + deltaTxn: GpuOptimisticTransactionBase, + mergeExecutor: MergeExecutor): Seq[Row] = { + val deltaActions = mergeExecutor.execute() + // Metrics should be recorded before commit (where they are written to delta logs). + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + deltaTxn.registerSQLMetrics(spark, metrics) + + // This is a best-effort sanity check. + if (metrics("numSourceRowsInSecondScan").value >= 0 && + metrics("numSourceRows").value != metrics("numSourceRowsInSecondScan").value) { + log.warn(s"Merge source has ${metrics("numSourceRows").value} rows in initial scan but " + + s"${metrics("numSourceRowsInSecondScan").value} rows in second scan") + if (conf.getConf(DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED)) { + throw DeltaErrors.sourceNotDeterministicInMergeException(spark) + } + } + + deltaTxn.commit( + deltaActions, + DeltaOperations.Merge( + Option(condition), + matchedClauses.map(DeltaOperations.MergePredicate(_)), + notMatchedClauses.map(DeltaOperations.MergePredicate(_)), + // We do not support notMatchedBySourcePredicates yet and fall back to CPU + // See https://github.com/NVIDIA/spark-rapids/issues/8415 + notMatchedBySourcePredicates = Seq.empty[MergePredicate] + )) + + // Record metrics + val stats = GpuMergeStats.fromMergeSQLMetrics( + metrics, + condition, + matchedClauses, + notMatchedClauses, + notMatchedBySourceClauses, + deltaTxn.metadata.partitionColumns.nonEmpty) + recordDeltaEvent(targetDeltaLog, "delta.dml.merge.stats", data = stats) + + + spark.sharedState.cacheManager.recacheByPlan(spark, target) + + // This is needed to make the SQL metrics visible in the Spark UI. Also this needs + // to be outside the recordMergeOperation because this method will update some metric. + val executionId = spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(spark.sparkContext, executionId, metrics.values.toSeq) + Seq(Row(metrics("numTargetRowsUpdated").value + metrics("numTargetRowsDeleted").value + + metrics("numTargetRowsInserted").value, metrics("numTargetRowsUpdated").value, + metrics("numTargetRowsDeleted").value, metrics("numTargetRowsInserted").value)) + } + + /** + * Execute the given `thunk` and return its result while recording the time taken to do it. + * + * @param sqlMetricName name of SQL metric to update with the time taken by the thunk + * @param thunk the code to execute + */ + private[delta] def recordMergeOperation[A](sqlMetricName: String)(thunk: => A): A = { + val startTimeNs = System.nanoTime() + val r = thunk + val timeTakenMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + if (sqlMetricName != null && timeTakenMs > 0) { + metrics(sqlMetricName) += timeTakenMs + } + r + } + + /** Expressions to increment SQL metrics */ + private[delta] def makeMetricUpdateUDF(name: String, deterministic: Boolean = false) + : Expression = { + // only capture the needed metric in a local variable + val metric = metrics(name) + var u = DeltaUDF.boolean(new GpuDeltaMetricUpdateUDF(metric)) + if (!deterministic) { + u = u.asNondeterministic() + } + u.apply().expr + } +} + +/** + * Context merge execution. + */ +case class MergeExecutorContext(cmd: GpuLowShuffleMergeCommand, + spark: SparkSession, + deltaTxn: OptimisticTransaction, + rapidsConf: RapidsConf) + +trait MergeExecutor extends AnalysisHelper with PredicateHelper with Logging { + + val context: MergeExecutorContext + + + /** + * Map to get target output attributes by name. + * The case sensitivity of the map is set accordingly to Spark configuration. + */ + @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = { + val attrMap: Map[String, Attribute] = context.cmd.target + .outputSet.view + .map(attr => attr.name -> attr).toMap + if (context.cmd.conf.caseSensitiveAnalysis) { + attrMap + } else { + CaseInsensitiveMap(attrMap) + } + } + + def execute(): Seq[FileAction] + + protected def targetOutputCols: Seq[NamedExpression] = { + context.deltaTxn.metadata.schema.map { col => + targetOutputAttributesMap + .get(col.name) + .map { a => + AttributeReference(col.name, col.dataType, col.nullable)(a.exprId) + } + .getOrElse(Alias(Literal(null), col.name)()) + } + } + + /** + * Build a DataFrame using the given `files` that has the same output columns (exprIds) + * as the `target` logical plan, so that existing update/insert expressions can be applied + * on this new plan. + */ + protected def buildTargetDFWithFiles(files: Seq[AddFile]): DataFrame = { + val targetOutputColsMap = { + val colsMap: Map[String, NamedExpression] = targetOutputCols.view + .map(col => col.name -> col).toMap + if (context.cmd.conf.caseSensitiveAnalysis) { + colsMap + } else { + CaseInsensitiveMap(colsMap) + } + } + + val plan = { + // We have to do surgery to use the attributes from `targetOutputCols` to scan the table. + // In cases of schema evolution, they may not be the same type as the original attributes. + val original = + context.deltaTxn.deltaLog.createDataFrame(context.deltaTxn.snapshot, files) + .queryExecution + .analyzed + val transformed = original.transform { + case LogicalRelation(base, _, catalogTbl, isStreaming) => + LogicalRelation( + base, + // We can ignore the new columns which aren't yet AttributeReferences. + targetOutputCols.collect { case a: AttributeReference => a }, + catalogTbl, + isStreaming) + } + + // In case of schema evolution & column mapping, we would also need to rebuild the file + // format because under column mapping, the reference schema within DeltaParquetFileFormat + // that is used to populate metadata needs to be updated + if (context.deltaTxn.metadata.columnMappingMode != NoMapping) { + val updatedFileFormat = context.deltaTxn.deltaLog.fileFormat( + context.deltaTxn.deltaLog.unsafeVolatileSnapshot.protocol, context.deltaTxn.metadata) + DeltaTableUtils.replaceFileFormat(transformed, updatedFileFormat) + } else { + transformed + } + } + + // For each plan output column, find the corresponding target output column (by name) and + // create an alias + val aliases = plan.output.map { + case newAttrib: AttributeReference => + val existingTargetAttrib = targetOutputColsMap.getOrElse(newAttrib.name, + throw new AnalysisException( + s"Could not find ${newAttrib.name} among the existing target output " + + targetOutputCols.mkString(","))).asInstanceOf[AttributeReference] + + if (existingTargetAttrib.exprId == newAttrib.exprId) { + // It's not valid to alias an expression to its own exprId (this is considered a + // non-unique exprId by the analyzer), so we just use the attribute directly. + newAttrib + } else { + Alias(newAttrib, existingTargetAttrib.name)(exprId = existingTargetAttrib.exprId) + } + } + + Dataset.ofRows(context.spark, Project(aliases, plan)) + } + + + /** + * Repartitions the output DataFrame by the partition columns if table is partitioned + * and `merge.repartitionBeforeWrite.enabled` is set to true. + */ + protected def repartitionIfNeeded(df: DataFrame): DataFrame = { + val partitionColumns = context.deltaTxn.metadata.partitionColumns + // TODO: We should remove this method and use optimized write instead, see + // https://github.com/NVIDIA/spark-rapids/issues/10417 + if (partitionColumns.nonEmpty && context.spark.conf.get(DeltaSQLConf + .MERGE_REPARTITION_BEFORE_WRITE)) { + df.repartition(partitionColumns.map(col): _*) + } else { + df + } + } + + protected def sourceDF: DataFrame = { + // UDF to increment metrics + val incrSourceRowCountExpr = context.cmd.makeMetricUpdateUDF("numSourceRows") + Dataset.ofRows(context.spark, context.cmd.source) + .filter(new Column(incrSourceRowCountExpr)) + } + + /** Whether this merge statement has no insert (NOT MATCHED) clause. */ + protected def hasNoInserts: Boolean = context.cmd.notMatchedClauses.isEmpty + + +} + +/** + * This is an optimization of the case when there is no update clause for the merge. + * We perform an left anti join on the source data to find the rows to be inserted. + * + * This will currently only optimize for the case when there is a _single_ notMatchedClause. + */ +class InsertOnlyMergeExecutor(override val context: MergeExecutorContext) extends MergeExecutor { + override def execute(): Seq[FileAction] = { + context.cmd.recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + + // UDFs to update metrics + val incrSourceRowCountExpr = context.cmd.makeMetricUpdateUDF("numSourceRows") + val incrInsertedCountExpr = context.cmd.makeMetricUpdateUDF("numTargetRowsInserted") + + val outputColNames = targetOutputCols.map(_.name) + // we use head here since we know there is only a single notMatchedClause + val outputExprs = context.cmd.notMatchedClauses.head.resolvedActions.map(_.expr) + val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) => + new Column(Alias(expr, name)()) + } + + // source DataFrame + val sourceDF = Dataset.ofRows(context.spark, context.cmd.source) + .filter(new Column(incrSourceRowCountExpr)) + .filter(new Column(context.cmd.notMatchedClauses.head.condition + .getOrElse(Literal.TrueLiteral))) + + // Skip data based on the merge condition + val conjunctivePredicates = splitConjunctivePredicates(context.cmd.condition) + val targetOnlyPredicates = + conjunctivePredicates.filter(_.references.subsetOf(context.cmd.target.outputSet)) + val dataSkippedFiles = context.deltaTxn.filterFiles(targetOnlyPredicates) + + // target DataFrame + val targetDF = buildTargetDFWithFiles(dataSkippedFiles) + + val insertDf = sourceDF.join(targetDF, new Column(context.cmd.condition), "leftanti") + .select(outputCols: _*) + .filter(new Column(incrInsertedCountExpr)) + + val newFiles = context.deltaTxn + .writeFiles(repartitionIfNeeded(insertDf, + )) + + // Update metrics + context.cmd.metrics("numTargetFilesBeforeSkipping") += context.deltaTxn.snapshot.numOfFiles + context.cmd.metrics("numTargetBytesBeforeSkipping") += context.deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + context.cmd.metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + context.cmd.metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + context.cmd.metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + context.cmd.metrics("numTargetFilesRemoved") += 0 + context.cmd.metrics("numTargetBytesRemoved") += 0 + context.cmd.metrics("numTargetPartitionsRemovedFrom") += 0 + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + context.cmd.metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + context.cmd.metrics("numTargetBytesAdded") += addedBytes + context.cmd.metrics("numTargetPartitionsAddedTo") += addedPartitions + newFiles + } + } +} + + +/** + * This is an optimized algorithm for merge statement, where we avoid shuffling the unmodified + * target data. + * + * The algorithm is as follows: + * 1. Find touched target files in the target table by joining the source and target data, with + * collecting joined row identifiers as (`__metadata_file_path`, `__metadata_row_idx`) pairs. + * 2. Read the touched files again and write new files with updated and/or inserted rows + * without coping unmodified data from target table, but filtering target table with collected + * rows mentioned above. + * 3. Read the touched files again, filtering unmodified rows with collected row identifiers + * collected in first step, and saving them without shuffle. + */ +class LowShuffleMergeExecutor(override val context: MergeExecutorContext) extends MergeExecutor { + + // We over-count numTargetRowsDeleted when there are multiple matches; + // this is the amount of the overcount, so we can subtract it to get a correct final metric. + private var multipleMatchDeleteOnlyOvercount: Option[Long] = None + + // UDFs to update metrics + private val incrSourceRowCountExpr: Expression = context.cmd. + makeMetricUpdateUDF("numSourceRowsInSecondScan") + private val incrUpdatedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsUpdated") + private val incrUpdatedMatchedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsMatchedUpdated") + private val incrUpdatedNotMatchedBySourceCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceUpdated") + private val incrInsertedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsInserted") + private val incrDeletedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsDeleted") + private val incrDeletedMatchedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsMatchedDeleted") + private val incrDeletedNotMatchedBySourceCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceDeleted") + + private def updateOutput(resolvedActions: Seq[DeltaMergeAction], incrExpr: Expression) + : Seq[Expression] = { + resolvedActions.map(_.expr) :+ + Literal.FalseLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def deleteOutput(incrExpr: Expression): Seq[Expression] = { + targetOutputCols :+ + TrueLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def insertOutput(resolvedActions: Seq[DeltaMergeAction], incrExpr: Expression) + : Seq[Expression] = { + resolvedActions.map(_.expr) :+ + Literal.FalseLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def clauseOutput(clause: DeltaMergeIntoClause): Seq[Expression] = clause match { + case u: DeltaMergeIntoMatchedUpdateClause => + updateOutput(u.resolvedActions, And(incrUpdatedCountExpr, incrUpdatedMatchedCountExpr)) + case _: DeltaMergeIntoMatchedDeleteClause => + deleteOutput(And(incrDeletedCountExpr, incrDeletedMatchedCountExpr)) + case i: DeltaMergeIntoNotMatchedInsertClause => + insertOutput(i.resolvedActions, incrInsertedCountExpr) + case u: DeltaMergeIntoNotMatchedBySourceUpdateClause => + updateOutput(u.resolvedActions, + And(incrUpdatedCountExpr, incrUpdatedNotMatchedBySourceCountExpr)) + case _: DeltaMergeIntoNotMatchedBySourceDeleteClause => + deleteOutput(And(incrDeletedCountExpr, incrDeletedNotMatchedBySourceCountExpr)) + } + + private def clauseCondition(clause: DeltaMergeIntoClause): Expression = { + // if condition is None, then expression always evaluates to true + clause.condition.getOrElse(TrueLiteral) + } + + /** + * Though low shuffle merge algorithm performs better than traditional merge algorithm in some + * cases, there are some case we should fallback to traditional merge executor: + * + * 1. Low shuffle merge algorithm requires generating metadata columns such as + * [[METADATA_ROW_IDX_COL]], [[METADATA_ROW_DEL_COL]], which only implemented on + * [[org.apache.spark.sql.rapids.GpuFileSourceScanExec]]. That means we need to fallback to + * this normal executor when [[org.apache.spark.sql.rapids.GpuFileSourceScanExec]] is disabled + * for some reason. + * 2. Low shuffle merge algorithm currently needs to broadcast deletion vector, which may + * introduce extra overhead. It maybe better to fallback to this algorithm when the changeset + * it too large. + */ + private[delta] def shouldFallback(): Boolean = { + // Trying to detect if we can execute finding touched files. + val touchFilePlanOverrideSucceed = verifyGpuPlan(planForFindingTouchedFiles()) { planMeta => + def check(meta: SparkPlanMeta[SparkPlan]): Boolean = { + meta match { + case scan if scan.isInstanceOf[FileSourceScanExecMeta] => scan + .asInstanceOf[FileSourceScanExecMeta] + .wrapped + .schema + .fieldNames + .contains(METADATA_ROW_IDX_COL) && scan.canThisBeReplaced + case m => m.childPlans.exists(check) + } + } + + check(planMeta) + } + if (!touchFilePlanOverrideSucceed) { + logWarning("Unable to override file scan for low shuffle merge for finding touched files " + + "plan, fallback to tradition merge.") + return true + } + + // Trying to detect if we can execute the merge plan. + val mergePlanOverrideSucceed = verifyGpuPlan(planForMergeExecution(touchedFiles)) { planMeta => + var overrideCount = 0 + def count(meta: SparkPlanMeta[SparkPlan]): Unit = { + meta match { + case scan if scan.isInstanceOf[FileSourceScanExecMeta] => + if (scan.asInstanceOf[FileSourceScanExecMeta] + .wrapped.schema.fieldNames.contains(METADATA_ROW_DEL_COL) && scan.canThisBeReplaced) { + overrideCount += 1 + } + case m => m.childPlans.foreach(count) + } + } + + count(planMeta) + overrideCount == 2 + } + + if (!mergePlanOverrideSucceed) { + logWarning("Unable to override file scan for low shuffle merge for merge plan, fallback to " + + "tradition merge.") + return true + } + + val deletionVectorSize = touchedFiles.values.map(_._1.serializedSizeInBytes()).sum + val maxDelVectorSize = context.rapidsConf + .get(DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD) + if (deletionVectorSize > maxDelVectorSize) { + logWarning( + s"""Low shuffle merge can't be executed because broadcast deletion vector count + |$deletionVectorSize is large than max value $maxDelVectorSize """.stripMargin) + return true + } + + false + } + + private def verifyGpuPlan(input: DataFrame)(checkPlanMeta: SparkPlanMeta[SparkPlan] => Boolean) + : Boolean = { + val overridePlan = GpuOverrides.wrapAndTagPlan(input.queryExecution.sparkPlan, + context.rapidsConf) + checkPlanMeta(overridePlan) + } + + override def execute(): Seq[FileAction] = { + val newFiles = context.cmd.withStatusCode("DELTA", + s"Rewriting ${touchedFiles.size} files and saving modified data") { + val df = planForMergeExecution(touchedFiles) + context.deltaTxn.writeFiles(df) + } + + // Update metrics + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + context.cmd.metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + context.cmd.metrics("numTargetChangeFilesAdded") += newFiles.count(_.isInstanceOf[AddCDCFile]) + context.cmd.metrics("numTargetChangeFileBytes") += newFiles.collect { + case f: AddCDCFile => f.size + } + .sum + context.cmd.metrics("numTargetBytesAdded") += addedBytes + context.cmd.metrics("numTargetPartitionsAddedTo") += addedPartitions + + if (multipleMatchDeleteOnlyOvercount.isDefined) { + // Compensate for counting duplicates during the query. + val actualRowsDeleted = + context.cmd.metrics("numTargetRowsDeleted").value - multipleMatchDeleteOnlyOvercount.get + assert(actualRowsDeleted >= 0) + context.cmd.metrics("numTargetRowsDeleted").set(actualRowsDeleted) + } + + touchedFiles.values.map(_._2).map(_.remove).toSeq ++ newFiles + } + + private lazy val dataSkippedFiles: Seq[AddFile] = { + // Skip data based on the merge condition + val targetOnlyPredicates = splitConjunctivePredicates(context.cmd.condition) + .filter(_.references.subsetOf(context.cmd.target.outputSet)) + context.deltaTxn.filterFiles(targetOnlyPredicates) + } + + private lazy val dataSkippedTargetDF: DataFrame = { + addRowIndexMetaColumn(buildTargetDFWithFiles(dataSkippedFiles)) + } + + private lazy val touchedFiles: Map[String, (Roaring64Bitmap, AddFile)] = this.findTouchedFiles() + + private def planForFindingTouchedFiles(): DataFrame = { + + // Apply inner join to between source and target using the merge condition to find matches + // In addition, we attach two columns + // - METADATA_ROW_IDX column to identify target row in file + // - FILE_PATH_COL the target file name the row is from to later identify the files touched + // by matched rows + val targetDF = dataSkippedTargetDF.withColumn(FILE_PATH_COL, input_file_name()) + + sourceDF.join(targetDF, new Column(context.cmd.condition), "inner") + } + + private def planForMergeExecution(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]) + : DataFrame = { + getModifiedDF(touchedFiles).unionAll(getUnmodifiedDF(touchedFiles)) + } + + /** + * Find the target table files that contain the rows that satisfy the merge condition. This is + * implemented as an inner-join between the source query/table and the target table using + * the merge condition. + */ + private def findTouchedFiles(): Map[String, (Roaring64Bitmap, AddFile)] = + context.cmd.recordMergeOperation(sqlMetricName = "scanTimeMs") { + context.spark.udf.register("row_index_set", udaf(RoaringBitmapUDAF)) + // Process the matches from the inner join to record touched files and find multiple matches + val collectTouchedFiles = planForFindingTouchedFiles() + .select(col(FILE_PATH_COL), col(METADATA_ROW_IDX_COL)) + .groupBy(FILE_PATH_COL) + .agg( + expr(s"row_index_set($METADATA_ROW_IDX_COL) as row_idxes"), + count("*").as("count")) + .collect().map(row => { + val filename = row.getAs[String](FILE_PATH_COL) + val rowIdxSet = row.getAs[RoaringBitmapWrapper]("row_idxes").inner + val count = row.getAs[Long]("count") + (filename, (rowIdxSet, count)) + }) + .toMap + + val duplicateCount = { + val distinctMatchedRowCounts = collectTouchedFiles.values + .map(_._1.getLongCardinality).sum + val allMatchedRowCounts = collectTouchedFiles.values.map(_._2).sum + allMatchedRowCounts - distinctMatchedRowCounts + } + + val hasMultipleMatches = duplicateCount > 0 + + // Throw error if multiple matches are ambiguous or cannot be computed correctly. + val canBeComputedUnambiguously = { + // Multiple matches are not ambiguous when there is only one unconditional delete as + // all the matched row pairs in the 2nd join in `writeAllChanges` will get deleted. + val isUnconditionalDelete = context.cmd.matchedClauses.headOption match { + case Some(DeltaMergeIntoMatchedDeleteClause(None)) => true + case _ => false + } + context.cmd.matchedClauses.size == 1 && isUnconditionalDelete + } + + if (hasMultipleMatches && !canBeComputedUnambiguously) { + throw DeltaErrors.multipleSourceRowMatchingTargetRowInMergeException(context.spark) + } + + if (hasMultipleMatches) { + // This is only allowed for delete-only queries. + // This query will count the duplicates for numTargetRowsDeleted in Job 2, + // because we count matches after the join and not just the target rows. + // We have to compensate for this by subtracting the duplicates later, + // so we need to record them here. + multipleMatchDeleteOnlyOvercount = Some(duplicateCount) + } + + // Get the AddFiles using the touched file names. + val touchedFileNames = collectTouchedFiles.keys.toSeq + + val nameToAddFileMap = context.cmd.generateCandidateFileMap( + context.cmd.targetDeltaLog.dataPath, + dataSkippedFiles) + + val touchedAddFiles = touchedFileNames.map(f => + context.cmd.getTouchedFile(context.cmd.targetDeltaLog.dataPath, f, nameToAddFileMap)) + .map(f => (DeltaFileOperations + .absolutePath(context.cmd.targetDeltaLog.dataPath.toString, f.path) + .toString, f)).toMap + + // When the target table is empty, and the optimizer optimized away the join entirely + // numSourceRows will be incorrectly 0. + // We need to scan the source table once to get the correct + // metric here. + if (context.cmd.metrics("numSourceRows").value == 0 && + (dataSkippedFiles.isEmpty || dataSkippedTargetDF.take(1).isEmpty)) { + val numSourceRows = sourceDF.count() + context.cmd.metrics("numSourceRows").set(numSourceRows) + } + + // Update metrics + context.cmd.metrics("numTargetFilesBeforeSkipping") += context.deltaTxn.snapshot.numOfFiles + context.cmd.metrics("numTargetBytesBeforeSkipping") += context.deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + context.cmd.metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + context.cmd.metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + context.cmd.metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + val (removedBytes, removedPartitions) = + totalBytesAndDistinctPartitionValues(touchedAddFiles.values.toSeq) + context.cmd.metrics("numTargetFilesRemoved") += touchedAddFiles.size + context.cmd.metrics("numTargetBytesRemoved") += removedBytes + context.cmd.metrics("numTargetPartitionsRemovedFrom") += removedPartitions + + collectTouchedFiles.map(kv => (kv._1, (kv._2._1, touchedAddFiles(kv._1)))) + } + + + /** + * Modify original data frame to insert + * [[GpuDeltaParquetFileFormatUtils.METADATA_ROW_IDX_COL]]. + */ + private def addRowIndexMetaColumn(baseDF: DataFrame): DataFrame = { + val rowIdxAttr = AttributeReference( + METADATA_ROW_IDX_COL, + METADATA_ROW_IDX_FIELD.dataType, + METADATA_ROW_IDX_FIELD.nullable)() + + val newPlan = baseDF.queryExecution.analyzed.transformUp { + case r@LogicalRelation(fs: HadoopFsRelation, _, _, _) => + val newSchema = StructType(fs.dataSchema.fields).add(METADATA_ROW_IDX_FIELD) + + // This is required to ensure that row index is correctly calculated. + val newFileFormat = fs.fileFormat.asInstanceOf[DeltaParquetFileFormat] + .copy(isSplittable = false, disablePushDowns = true) + + val newFs = fs.copy(dataSchema = newSchema, fileFormat = newFileFormat)(context.spark) + + val newOutput = r.output :+ rowIdxAttr + r.copy(relation = newFs, output = newOutput) + case p@Project(projectList, _) => + val newProjectList = projectList :+ rowIdxAttr + p.copy(projectList = newProjectList) + } + + Dataset.ofRows(context.spark, newPlan) + } + + /** + * The result is scanning target table with touched files, and added an extra + * [[METADATA_ROW_DEL_COL]] to indicate whether filtered by joining with source table in first + * step. + */ + private def getTouchedTargetDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]) + : DataFrame = { + // Generate a new target dataframe that has same output attributes exprIds as the target plan. + // This allows us to apply the existing resolved update/insert expressions. + val baseTargetDF = buildTargetDFWithFiles(touchedFiles.values.map(_._2).toSeq) + + val newPlan = { + val rowDelAttr = AttributeReference( + METADATA_ROW_DEL_COL, + METADATA_ROW_DEL_FIELD.dataType, + METADATA_ROW_DEL_FIELD.nullable)() + + baseTargetDF.queryExecution.analyzed.transformUp { + case r@LogicalRelation(fs: HadoopFsRelation, _, _, _) => + val newSchema = StructType(fs.dataSchema.fields).add(METADATA_ROW_DEL_FIELD) + + // This is required to ensure that row index is correctly calculated. + val newFileFormat = { + val oldFormat = fs.fileFormat.asInstanceOf[DeltaParquetFileFormat] + val dvs = touchedFiles.map(kv => (new URI(kv._1), + DeletionVectorDescriptorWithFilterType(toDeletionVector(kv._2._1), + RowIndexFilterType.UNKNOWN))) + val broadcastDVs = context.spark.sparkContext.broadcast(dvs) + + oldFormat.copy(isSplittable = false, + broadcastDvMap = Some(broadcastDVs), + disablePushDowns = true) + } + + val newFs = fs.copy(dataSchema = newSchema, fileFormat = newFileFormat)(context.spark) + + val newOutput = r.output :+ rowDelAttr + r.copy(relation = newFs, output = newOutput) + case p@Project(projectList, _) => + val newProjectList = projectList :+ rowDelAttr + p.copy(projectList = newProjectList) + } + } + + val df = Dataset.ofRows(context.spark, newPlan) + .withColumn(TARGET_ROW_PRESENT_COL, lit(true)) + + df + } + + /** + * Generate a plan by calculating modified rows. It's computed by joining source and target + * tables, where target table has been filtered by (`__metadata_file_name`, + * `__metadata_row_idx`) pairs collected in first step. + * + * Schema of `modifiedDF`: + * + * targetSchema + ROW_DROPPED_COL + TARGET_ROW_PRESENT_COL + + * SOURCE_ROW_PRESENT_COL + INCR_METRICS_COL + * INCR_METRICS_COL + * + * It consists of several parts: + * + * 1. Unmatched source rows which are inserted + * 2. Unmatched source rows which are deleted + * 3. Target rows which are updated + * 4. Target rows which are deleted + */ + private def getModifiedDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]): DataFrame = { + val sourceDF = this.sourceDF + .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr)) + + val targetDF = getTouchedTargetDF(touchedFiles) + + val joinedDF = { + val joinType = if (hasNoInserts && + context.spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED)) { + "inner" + } else { + "leftOuter" + } + val matchedTargetDF = targetDF.filter(METADATA_ROW_DEL_COL) + .drop(METADATA_ROW_DEL_COL) + + sourceDF.join(matchedTargetDF, new Column(context.cmd.condition), joinType) + } + + val modifiedRowsSchema = context.deltaTxn.metadata.schema + .add(ROW_DROPPED_FIELD) + .add(TARGET_ROW_PRESENT_FIELD.copy(nullable = true)) + .add(SOURCE_ROW_PRESENT_FIELD.copy(nullable = true)) + .add(INCR_METRICS_FIELD) + + // Here we generate a case when statement to handle all cases: + // CASE + // WHEN + // CASE WHEN + // + // WHEN + // + // ELSE + // + // WHEN + // CASE WHEN + // + // WHEN + // + // ELSE + // + // END + + val notMatchedConditions = context.cmd.notMatchedClauses.map(clauseCondition) + val notMatchedExpr = { + val deletedNotMatchedRow = { + targetOutputCols :+ + Literal.TrueLiteral :+ + Literal.FalseLiteral :+ + Literal(null) :+ + Literal.TrueLiteral + } + if (context.cmd.notMatchedClauses.isEmpty) { + // If there no `WHEN NOT MATCHED` clause, we should just delete not matched row + deletedNotMatchedRow + } else { + val notMatchedOutputs = context.cmd.notMatchedClauses.map(clauseOutput) + modifiedRowsSchema.zipWithIndex.map { + case (_, idx) => + CaseWhen(notMatchedConditions.zip(notMatchedOutputs.map(_(idx))), + deletedNotMatchedRow(idx)) + } + } + } + + val matchedConditions = context.cmd.matchedClauses.map(clauseCondition) + val matchedOutputs = context.cmd.matchedClauses.map(clauseOutput) + val matchedExprs = { + val notMatchedRow = { + targetOutputCols :+ + Literal.FalseLiteral :+ + Literal.TrueLiteral :+ + Literal(null) :+ + Literal.TrueLiteral + } + if (context.cmd.matchedClauses.isEmpty) { + // If there is not matched clause, this is insert only, we should delete this row. + notMatchedRow + } else { + modifiedRowsSchema.zipWithIndex.map { + case (_, idx) => + CaseWhen(matchedConditions.zip(matchedOutputs.map(_(idx))), + notMatchedRow(idx)) + } + } + } + + val sourceRowHasNoMatch = col(TARGET_ROW_PRESENT_COL).isNull.expr + + val modifiedCols = modifiedRowsSchema.zipWithIndex.map { case (col, idx) => + val caseWhen = CaseWhen( + Seq(sourceRowHasNoMatch -> notMatchedExpr(idx)), + matchedExprs(idx)) + Column(Alias(caseWhen, col.name)()) + } + + val modifiedDF = { + + // Make this a udf to avoid catalyst to be too aggressive to even remove the join! + val noopRowDroppedCol = udf(new GpuDeltaNoopUDF()).apply(!col(ROW_DROPPED_COL)) + + val modifiedDF = joinedDF.select(modifiedCols: _*) + // This will not filter anything since they always return true, but we need to avoid + // catalyst from optimizing these udf + .filter(noopRowDroppedCol && col(INCR_METRICS_COL)) + .drop(ROW_DROPPED_COL, INCR_METRICS_COL, TARGET_ROW_PRESENT_COL, SOURCE_ROW_PRESENT_COL) + + repartitionIfNeeded(modifiedDF) + } + + modifiedDF + } + + private def getUnmodifiedDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]): DataFrame = { + getTouchedTargetDF(touchedFiles) + .filter(!col(METADATA_ROW_DEL_COL)) + .drop(TARGET_ROW_PRESENT_COL, METADATA_ROW_DEL_COL) + } +} + + +object MergeExecutor { + + /** + * Spark UI will track all normal accumulators along with Spark tasks to show them on Web UI. + * However, the accumulator used by `MergeIntoCommand` can store a very large value since it + * tracks all files that need to be rewritten. We should ask Spark UI to not remember it, + * otherwise, the UI data may consume lots of memory. Hence, we use the prefix `internal.metrics.` + * to make this accumulator become an internal accumulator, so that it will not be tracked by + * Spark UI. + */ + val TOUCHED_FILES_ACCUM_NAME = "internal.metrics.MergeIntoDelta.touchedFiles" + + val ROW_ID_COL = "_row_id_" + val FILE_PATH_COL: String = GpuDeltaParquetFileFormatUtils.FILE_PATH_COL + val SOURCE_ROW_PRESENT_COL: String = "_source_row_present_" + val SOURCE_ROW_PRESENT_FIELD: StructField = StructField(SOURCE_ROW_PRESENT_COL, BooleanType, + nullable = false) + val TARGET_ROW_PRESENT_COL: String = "_target_row_present_" + val TARGET_ROW_PRESENT_FIELD: StructField = StructField(TARGET_ROW_PRESENT_COL, BooleanType, + nullable = false) + val ROW_DROPPED_COL: String = GpuDeltaMergeConstants.ROW_DROPPED_COL + val ROW_DROPPED_FIELD: StructField = StructField(ROW_DROPPED_COL, BooleanType, nullable = false) + val INCR_METRICS_COL: String = "_incr_metrics_" + val INCR_METRICS_FIELD: StructField = StructField(INCR_METRICS_COL, BooleanType, nullable = false) + val INCR_ROW_COUNT_COL: String = "_incr_row_count_" + + // Some Delta versions use Literal(null) which translates to a literal of NullType instead + // of the Literal(null, StringType) which is needed, so using a fixed version here + // rather than the version from Delta Lake. + val CDC_TYPE_NOT_CDC_LITERAL: Literal = Literal(null, StringType) + + private[delta] def toDeletionVector(bitmap: Roaring64Bitmap): DeletionVectorDescriptor = { + DeletionVectorDescriptor.inlineInLog(RoaringBitmapWrapper(bitmap).serializeToBytes(), + bitmap.getLongCardinality) + } + + /** Count the number of distinct partition values among the AddFiles in the given set. */ + private[delta] def totalBytesAndDistinctPartitionValues(files: Seq[FileAction]): (Long, Int) = { + val distinctValues = new mutable.HashSet[Map[String, String]]() + var bytes = 0L + val iter = files.collect { case a: AddFile => a }.iterator + while (iter.hasNext) { + val file = iter.next() + distinctValues += file.partitionValues + bytes += file.size + } + // If the only distinct value map is an empty map, then it must be an unpartitioned table. + // Return 0 in that case. + val numDistinctValues = + if (distinctValues.size == 1 && distinctValues.head.isEmpty) 0 else distinctValues.size + (bytes, numDistinctValues) + } +} \ No newline at end of file diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuLowShuffleMergeCommand.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuLowShuffleMergeCommand.scala new file mode 100644 index 00000000000..fddebda33bd --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuLowShuffleMergeCommand.scala @@ -0,0 +1,1083 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * This file was derived from MergeIntoCommand.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import java.net.URI +import java.util.concurrent.TimeUnit + +import scala.collection.mutable + +import com.databricks.sql.io.RowIndexFilterType +import com.databricks.sql.transaction.tahoe._ +import com.databricks.sql.transaction.tahoe.DeltaOperations.MergePredicate +import com.databricks.sql.transaction.tahoe.DeltaParquetFileFormat.DeletionVectorDescriptorWithFilterType +import com.databricks.sql.transaction.tahoe.actions.{AddCDCFile, AddFile, DeletionVectorDescriptor, FileAction} +import com.databricks.sql.transaction.tahoe.commands.DeltaCommand +import com.databricks.sql.transaction.tahoe.rapids.MergeExecutor.{toDeletionVector, totalBytesAndDistinctPartitionValues, FILE_PATH_COL, INCR_METRICS_COL, INCR_METRICS_FIELD, ROW_DROPPED_COL, ROW_DROPPED_FIELD, SOURCE_ROW_PRESENT_COL, SOURCE_ROW_PRESENT_FIELD, TARGET_ROW_PRESENT_COL, TARGET_ROW_PRESENT_FIELD} +import com.databricks.sql.transaction.tahoe.schema.ImplicitMetadataOperation +import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf +import com.databricks.sql.transaction.tahoe.util.{AnalysisHelper, DeltaFileOperations} +import com.nvidia.spark.rapids.{GpuOverrides, RapidsConf, SparkPlanMeta} +import com.nvidia.spark.rapids.RapidsConf.DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD +import com.nvidia.spark.rapids.delta._ +import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormatUtils.{METADATA_ROW_DEL_COL, METADATA_ROW_DEL_FIELD, METADATA_ROW_IDX_COL, METADATA_ROW_IDX_FIELD} +import com.nvidia.spark.rapids.shims.FileSourceScanExecMeta +import org.roaringbitmap.longlong.Roaring64Bitmap + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, CaseWhen, Expression, Literal, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.plans.logical.{DeltaMergeAction, DeltaMergeIntoClause, DeltaMergeIntoMatchedClause, DeltaMergeIntoMatchedDeleteClause, DeltaMergeIntoMatchedUpdateClause, DeltaMergeIntoNotMatchedBySourceClause, DeltaMergeIntoNotMatchedBySourceDeleteClause, DeltaMergeIntoNotMatchedBySourceUpdateClause, DeltaMergeIntoNotMatchedClause, DeltaMergeIntoNotMatchedInsertClause, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} + +/** + * GPU version of Delta Lake's low shuffle merge implementation. + * + * Performs a merge of a source query/table into a Delta table. + * + * Issues an error message when the ON search_condition of the MERGE statement can match + * a single row from the target table with multiple rows of the source table-reference. + * Different from the original implementation, it optimized writing touched unmodified target files. + * + * Algorithm: + * + * Phase 1: Find the input files in target that are touched by the rows that satisfy + * the condition and verify that no two source rows match with the same target row. + * This is implemented as an inner-join using the given condition. See [[findTouchedFiles]] + * for more details. + * + * Phase 2: Read the touched files again and write new files with updated and/or inserted rows + * without copying unmodified rows. + * + * Phase 3: Read the touched files again and write new files with unmodified rows in target table, + * trying to keep its original order and avoid shuffle as much as possible. + * + * Phase 4: Use the Delta protocol to atomically remove the touched files and add the new files. + * + * @param source Source data to merge from + * @param target Target table to merge into + * @param gpuDeltaLog Delta log to use + * @param condition Condition for a source row to match with a target row + * @param matchedClauses All info related to matched clauses. + * @param notMatchedClauses All info related to not matched clause. + * @param migratedSchema The final schema of the target - may be changed by schema evolution. + */ +case class GpuLowShuffleMergeCommand( + @transient source: LogicalPlan, + @transient target: LogicalPlan, + @transient gpuDeltaLog: GpuDeltaLog, + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], + notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause], + migratedSchema: Option[StructType])( + @transient val rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaCommand with PredicateHelper with AnalysisHelper with ImplicitMetadataOperation { + + import SQLMetrics._ + + override val otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) + override val canOverwriteSchema: Boolean = false + + override val output: Seq[Attribute] = Seq( + AttributeReference("num_affected_rows", LongType)(), + AttributeReference("num_updated_rows", LongType)(), + AttributeReference("num_deleted_rows", LongType)(), + AttributeReference("num_inserted_rows", LongType)()) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + @transient lazy val targetDeltaLog: DeltaLog = gpuDeltaLog.deltaLog + + override lazy val metrics = Map[String, SQLMetric]( + "numSourceRows" -> createMetric(sc, "number of source rows"), + "numSourceRowsInSecondScan" -> + createMetric(sc, "number of source rows (during repeated scan)"), + "numTargetRowsCopied" -> createMetric(sc, "number of target rows rewritten unmodified"), + "numTargetRowsInserted" -> createMetric(sc, "number of inserted rows"), + "numTargetRowsUpdated" -> createMetric(sc, "number of updated rows"), + "numTargetRowsDeleted" -> createMetric(sc, "number of deleted rows"), + "numTargetRowsMatchedUpdated" -> createMetric(sc, "number of target rows updated when matched"), + "numTargetRowsMatchedDeleted" -> createMetric(sc, "number of target rows deleted when matched"), + "numTargetRowsNotMatchedBySourceUpdated" -> createMetric(sc, + "number of target rows updated when not matched by source"), + "numTargetRowsNotMatchedBySourceDeleted" -> createMetric(sc, + "number of target rows deleted when not matched by source"), + "numTargetFilesBeforeSkipping" -> createMetric(sc, "number of target files before skipping"), + "numTargetFilesAfterSkipping" -> createMetric(sc, "number of target files after skipping"), + "numTargetFilesRemoved" -> createMetric(sc, "number of files removed to target"), + "numTargetFilesAdded" -> createMetric(sc, "number of files added to target"), + "numTargetChangeFilesAdded" -> + createMetric(sc, "number of change data capture files generated"), + "numTargetChangeFileBytes" -> + createMetric(sc, "total size of change data capture files generated"), + "numTargetBytesBeforeSkipping" -> createMetric(sc, "number of target bytes before skipping"), + "numTargetBytesAfterSkipping" -> createMetric(sc, "number of target bytes after skipping"), + "numTargetBytesRemoved" -> createMetric(sc, "number of target bytes removed"), + "numTargetBytesAdded" -> createMetric(sc, "number of target bytes added"), + "numTargetPartitionsAfterSkipping" -> + createMetric(sc, "number of target partitions after skipping"), + "numTargetPartitionsRemovedFrom" -> + createMetric(sc, "number of target partitions from which files were removed"), + "numTargetPartitionsAddedTo" -> + createMetric(sc, "number of target partitions to which files were added"), + "executionTimeMs" -> + createMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createMetric(sc, "time taken to rewrite the matched files")) + + /** Whether this merge statement has only a single insert (NOT MATCHED) clause. */ + protected def isSingleInsertOnly: Boolean = matchedClauses.isEmpty && + notMatchedClauses.length == 1 + + override def run(spark: SparkSession): Seq[Row] = { + recordDeltaOperation(targetDeltaLog, "delta.dml.lowshufflemerge") { + val startTime = System.nanoTime() + val result = gpuDeltaLog.withNewTransaction { deltaTxn => + if (target.schema.size != deltaTxn.metadata.schema.size) { + throw DeltaErrors.schemaChangedSinceAnalysis( + atAnalysis = target.schema, latestSchema = deltaTxn.metadata.schema) + } + + if (canMergeSchema) { + updateMetadata( + spark, deltaTxn, migratedSchema.getOrElse(target.schema), + deltaTxn.metadata.partitionColumns, deltaTxn.metadata.configuration, + isOverwriteMode = false, rearrangeOnly = false) + } + + + val (executor, fallback) = { + val context = MergeExecutorContext(this, spark, deltaTxn, rapidsConf) + if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) { + (new InsertOnlyMergeExecutor(context), false) + } else { + val executor = new LowShuffleMergeExecutor(context) + (executor, executor.shouldFallback()) + } + } + + if (fallback) { + None + } else { + Some(runLowShuffleMerge(spark, startTime, deltaTxn, executor)) + } + } + + result match { + case Some(row) => row + case None => + // We should rollback to normal gpu + new GpuMergeIntoCommand(source, target, gpuDeltaLog, condition, matchedClauses, + notMatchedClauses, notMatchedBySourceClauses, migratedSchema)(rapidsConf) + .run(spark) + } + } + } + + + private def runLowShuffleMerge( + spark: SparkSession, + startTime: Long, + deltaTxn: GpuOptimisticTransactionBase, + mergeExecutor: MergeExecutor): Seq[Row] = { + val deltaActions = mergeExecutor.execute() + // Metrics should be recorded before commit (where they are written to delta logs). + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + deltaTxn.registerSQLMetrics(spark, metrics) + + // This is a best-effort sanity check. + if (metrics("numSourceRowsInSecondScan").value >= 0 && + metrics("numSourceRows").value != metrics("numSourceRowsInSecondScan").value) { + log.warn(s"Merge source has ${metrics("numSourceRows").value} rows in initial scan but " + + s"${metrics("numSourceRowsInSecondScan").value} rows in second scan") + if (conf.getConf(DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED)) { + throw DeltaErrors.sourceNotDeterministicInMergeException(spark) + } + } + + deltaTxn.commit( + deltaActions, + DeltaOperations.Merge( + Option(condition), + matchedClauses.map(DeltaOperations.MergePredicate(_)), + notMatchedClauses.map(DeltaOperations.MergePredicate(_)), + // We do not support notMatchedBySourcePredicates yet and fall back to CPU + // See https://github.com/NVIDIA/spark-rapids/issues/8415 + notMatchedBySourcePredicates = Seq.empty[MergePredicate] + )) + + // Record metrics + val stats = GpuMergeStats.fromMergeSQLMetrics( + metrics, + condition, + matchedClauses, + notMatchedClauses, + deltaTxn.metadata.partitionColumns.nonEmpty) + recordDeltaEvent(targetDeltaLog, "delta.dml.merge.stats", data = stats) + + + spark.sharedState.cacheManager.recacheByPlan(spark, target) + + // This is needed to make the SQL metrics visible in the Spark UI. Also this needs + // to be outside the recordMergeOperation because this method will update some metric. + val executionId = spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(spark.sparkContext, executionId, metrics.values.toSeq) + Seq(Row(metrics("numTargetRowsUpdated").value + metrics("numTargetRowsDeleted").value + + metrics("numTargetRowsInserted").value, metrics("numTargetRowsUpdated").value, + metrics("numTargetRowsDeleted").value, metrics("numTargetRowsInserted").value)) + } + + /** + * Execute the given `thunk` and return its result while recording the time taken to do it. + * + * @param sqlMetricName name of SQL metric to update with the time taken by the thunk + * @param thunk the code to execute + */ + def recordMergeOperation[A](sqlMetricName: String)(thunk: => A): A = { + val startTimeNs = System.nanoTime() + val r = thunk + val timeTakenMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + if (sqlMetricName != null && timeTakenMs > 0) { + metrics(sqlMetricName) += timeTakenMs + } + r + } + + /** Expressions to increment SQL metrics */ + def makeMetricUpdateUDF(name: String, deterministic: Boolean = false) + : Expression = { + // only capture the needed metric in a local variable + val metric = metrics(name) + var u = DeltaUDF.boolean(new GpuDeltaMetricUpdateUDF(metric)) + if (!deterministic) { + u = u.asNondeterministic() + } + u.apply().expr + } +} + +/** + * Context merge execution. + */ +case class MergeExecutorContext(cmd: GpuLowShuffleMergeCommand, + spark: SparkSession, + deltaTxn: OptimisticTransaction, + rapidsConf: RapidsConf) + +trait MergeExecutor extends AnalysisHelper with PredicateHelper with Logging { + + val context: MergeExecutorContext + + + /** + * Map to get target output attributes by name. + * The case sensitivity of the map is set accordingly to Spark configuration. + */ + @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = { + val attrMap: Map[String, Attribute] = context.cmd.target + .outputSet.view + .map(attr => attr.name -> attr).toMap + if (context.cmd.conf.caseSensitiveAnalysis) { + attrMap + } else { + CaseInsensitiveMap(attrMap) + } + } + + def execute(): Seq[FileAction] + + protected def targetOutputCols: Seq[NamedExpression] = { + context.deltaTxn.metadata.schema.map { col => + targetOutputAttributesMap + .get(col.name) + .map { a => + AttributeReference(col.name, col.dataType, col.nullable)(a.exprId) + } + .getOrElse(Alias(Literal(null), col.name)()) + } + } + + /** + * Build a DataFrame using the given `files` that has the same output columns (exprIds) + * as the `target` logical plan, so that existing update/insert expressions can be applied + * on this new plan. + */ + protected def buildTargetDFWithFiles(files: Seq[AddFile]): DataFrame = { + val targetOutputColsMap = { + val colsMap: Map[String, NamedExpression] = targetOutputCols.view + .map(col => col.name -> col).toMap + if (context.cmd.conf.caseSensitiveAnalysis) { + colsMap + } else { + CaseInsensitiveMap(colsMap) + } + } + + val plan = { + // We have to do surgery to use the attributes from `targetOutputCols` to scan the table. + // In cases of schema evolution, they may not be the same type as the original attributes. + val original = + context.deltaTxn.deltaLog.createDataFrame(context.deltaTxn.snapshot, files) + .queryExecution + .analyzed + val transformed = original.transform { + case LogicalRelation(base, _, catalogTbl, isStreaming) => + LogicalRelation( + base, + // We can ignore the new columns which aren't yet AttributeReferences. + targetOutputCols.collect { case a: AttributeReference => a }, + catalogTbl, + isStreaming) + } + + // In case of schema evolution & column mapping, we would also need to rebuild the file + // format because under column mapping, the reference schema within DeltaParquetFileFormat + // that is used to populate metadata needs to be updated + if (context.deltaTxn.metadata.columnMappingMode != NoMapping) { + val updatedFileFormat = context.deltaTxn.deltaLog.fileFormat( + context.deltaTxn.deltaLog.unsafeVolatileSnapshot.protocol, context.deltaTxn.metadata) + DeltaTableUtils.replaceFileFormat(transformed, updatedFileFormat) + } else { + transformed + } + } + + // For each plan output column, find the corresponding target output column (by name) and + // create an alias + val aliases = plan.output.map { + case newAttrib: AttributeReference => + val existingTargetAttrib = targetOutputColsMap.getOrElse(newAttrib.name, + throw new AnalysisException( + s"Could not find ${newAttrib.name} among the existing target output " + + targetOutputCols.mkString(","))).asInstanceOf[AttributeReference] + + if (existingTargetAttrib.exprId == newAttrib.exprId) { + // It's not valid to alias an expression to its own exprId (this is considered a + // non-unique exprId by the analyzer), so we just use the attribute directly. + newAttrib + } else { + Alias(newAttrib, existingTargetAttrib.name)(exprId = existingTargetAttrib.exprId) + } + } + + Dataset.ofRows(context.spark, Project(aliases, plan)) + } + + + /** + * Repartitions the output DataFrame by the partition columns if table is partitioned + * and `merge.repartitionBeforeWrite.enabled` is set to true. + */ + protected def repartitionIfNeeded(df: DataFrame): DataFrame = { + val partitionColumns = context.deltaTxn.metadata.partitionColumns + // TODO: We should remove this method and use optimized write instead, see + // https://github.com/NVIDIA/spark-rapids/issues/10417 + if (partitionColumns.nonEmpty && context.spark.conf.get(DeltaSQLConf + .MERGE_REPARTITION_BEFORE_WRITE)) { + df.repartition(partitionColumns.map(col): _*) + } else { + df + } + } + + protected def sourceDF: DataFrame = { + // UDF to increment metrics + val incrSourceRowCountExpr = context.cmd.makeMetricUpdateUDF("numSourceRows") + Dataset.ofRows(context.spark, context.cmd.source) + .filter(new Column(incrSourceRowCountExpr)) + } + + /** Whether this merge statement has no insert (NOT MATCHED) clause. */ + protected def hasNoInserts: Boolean = context.cmd.notMatchedClauses.isEmpty + + +} + +/** + * This is an optimization of the case when there is no update clause for the merge. + * We perform an left anti join on the source data to find the rows to be inserted. + * + * This will currently only optimize for the case when there is a _single_ notMatchedClause. + */ +class InsertOnlyMergeExecutor(override val context: MergeExecutorContext) extends MergeExecutor { + override def execute(): Seq[FileAction] = { + context.cmd.recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + + // UDFs to update metrics + val incrSourceRowCountExpr = context.cmd.makeMetricUpdateUDF("numSourceRows") + val incrInsertedCountExpr = context.cmd.makeMetricUpdateUDF("numTargetRowsInserted") + + val outputColNames = targetOutputCols.map(_.name) + // we use head here since we know there is only a single notMatchedClause + val outputExprs = context.cmd.notMatchedClauses.head.resolvedActions.map(_.expr) + val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) => + new Column(Alias(expr, name)()) + } + + // source DataFrame + val sourceDF = Dataset.ofRows(context.spark, context.cmd.source) + .filter(new Column(incrSourceRowCountExpr)) + .filter(new Column(context.cmd.notMatchedClauses.head.condition + .getOrElse(Literal.TrueLiteral))) + + // Skip data based on the merge condition + val conjunctivePredicates = splitConjunctivePredicates(context.cmd.condition) + val targetOnlyPredicates = + conjunctivePredicates.filter(_.references.subsetOf(context.cmd.target.outputSet)) + val dataSkippedFiles = context.deltaTxn.filterFiles(targetOnlyPredicates) + + // target DataFrame + val targetDF = buildTargetDFWithFiles(dataSkippedFiles) + + val insertDf = sourceDF.join(targetDF, new Column(context.cmd.condition), "leftanti") + .select(outputCols: _*) + .filter(new Column(incrInsertedCountExpr)) + + val newFiles = context.deltaTxn + .writeFiles(repartitionIfNeeded(insertDf, + )) + + // Update metrics + context.cmd.metrics("numTargetFilesBeforeSkipping") += context.deltaTxn.snapshot.numOfFiles + context.cmd.metrics("numTargetBytesBeforeSkipping") += context.deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + context.cmd.metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + context.cmd.metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + context.cmd.metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + context.cmd.metrics("numTargetFilesRemoved") += 0 + context.cmd.metrics("numTargetBytesRemoved") += 0 + context.cmd.metrics("numTargetPartitionsRemovedFrom") += 0 + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + context.cmd.metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + context.cmd.metrics("numTargetBytesAdded") += addedBytes + context.cmd.metrics("numTargetPartitionsAddedTo") += addedPartitions + newFiles + } + } +} + + +/** + * This is an optimized algorithm for merge statement, where we avoid shuffling the unmodified + * target data. + * + * The algorithm is as follows: + * 1. Find touched target files in the target table by joining the source and target data, with + * collecting joined row identifiers as (`__metadata_file_path`, `__metadata_row_idx`) pairs. + * 2. Read the touched files again and write new files with updated and/or inserted rows + * without coping unmodified data from target table, but filtering target table with collected + * rows mentioned above. + * 3. Read the touched files again, filtering unmodified rows with collected row identifiers + * collected in first step, and saving them without shuffle. + */ +class LowShuffleMergeExecutor(override val context: MergeExecutorContext) extends MergeExecutor { + + // We over-count numTargetRowsDeleted when there are multiple matches; + // this is the amount of the overcount, so we can subtract it to get a correct final metric. + private var multipleMatchDeleteOnlyOvercount: Option[Long] = None + + // UDFs to update metrics + private val incrSourceRowCountExpr: Expression = context.cmd. + makeMetricUpdateUDF("numSourceRowsInSecondScan") + private val incrUpdatedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsUpdated") + private val incrUpdatedMatchedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsMatchedUpdated") + private val incrUpdatedNotMatchedBySourceCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceUpdated") + private val incrInsertedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsInserted") + private val incrDeletedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsDeleted") + private val incrDeletedMatchedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsMatchedDeleted") + private val incrDeletedNotMatchedBySourceCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceDeleted") + + private def updateOutput(resolvedActions: Seq[DeltaMergeAction], incrExpr: Expression) + : Seq[Expression] = { + resolvedActions.map(_.expr) :+ + Literal.FalseLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def deleteOutput(incrExpr: Expression): Seq[Expression] = { + targetOutputCols :+ + TrueLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def insertOutput(resolvedActions: Seq[DeltaMergeAction], incrExpr: Expression) + : Seq[Expression] = { + resolvedActions.map(_.expr) :+ + Literal.FalseLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def clauseOutput(clause: DeltaMergeIntoClause): Seq[Expression] = clause match { + case u: DeltaMergeIntoMatchedUpdateClause => + updateOutput(u.resolvedActions, And(incrUpdatedCountExpr, incrUpdatedMatchedCountExpr)) + case _: DeltaMergeIntoMatchedDeleteClause => + deleteOutput(And(incrDeletedCountExpr, incrDeletedMatchedCountExpr)) + case i: DeltaMergeIntoNotMatchedInsertClause => + insertOutput(i.resolvedActions, incrInsertedCountExpr) + case u: DeltaMergeIntoNotMatchedBySourceUpdateClause => + updateOutput(u.resolvedActions, + And(incrUpdatedCountExpr, incrUpdatedNotMatchedBySourceCountExpr)) + case _: DeltaMergeIntoNotMatchedBySourceDeleteClause => + deleteOutput(And(incrDeletedCountExpr, incrDeletedNotMatchedBySourceCountExpr)) + } + + private def clauseCondition(clause: DeltaMergeIntoClause): Expression = { + // if condition is None, then expression always evaluates to true + clause.condition.getOrElse(TrueLiteral) + } + + /** + * Though low shuffle merge algorithm performs better than traditional merge algorithm in some + * cases, there are some case we should fallback to traditional merge executor: + * + * 1. Low shuffle merge algorithm requires generating metadata columns such as + * [[METADATA_ROW_IDX_COL]], [[METADATA_ROW_DEL_COL]], which only implemented on + * [[org.apache.spark.sql.rapids.GpuFileSourceScanExec]]. That means we need to fallback to + * this normal executor when [[org.apache.spark.sql.rapids.GpuFileSourceScanExec]] is disabled + * for some reason. + * 2. Low shuffle merge algorithm currently needs to broadcast deletion vector, which may + * introduce extra overhead. It maybe better to fallback to this algorithm when the changeset + * it too large. + */ + def shouldFallback(): Boolean = { + // Trying to detect if we can execute finding touched files. + val touchFilePlanOverrideSucceed = verifyGpuPlan(planForFindingTouchedFiles()) { planMeta => + def check(meta: SparkPlanMeta[SparkPlan]): Boolean = { + meta match { + case scan if scan.isInstanceOf[FileSourceScanExecMeta] => scan + .asInstanceOf[FileSourceScanExecMeta] + .wrapped + .schema + .fieldNames + .contains(METADATA_ROW_IDX_COL) && scan.canThisBeReplaced + case m => m.childPlans.exists(check) + } + } + + check(planMeta) + } + if (!touchFilePlanOverrideSucceed) { + logWarning("Unable to override file scan for low shuffle merge for finding touched files " + + "plan, fallback to tradition merge.") + return true + } + + // Trying to detect if we can execute the merge plan. + val mergePlanOverrideSucceed = verifyGpuPlan(planForMergeExecution(touchedFiles)) { planMeta => + var overrideCount = 0 + def count(meta: SparkPlanMeta[SparkPlan]): Unit = { + meta match { + case scan if scan.isInstanceOf[FileSourceScanExecMeta] => + if (scan.asInstanceOf[FileSourceScanExecMeta] + .wrapped.schema.fieldNames.contains(METADATA_ROW_DEL_COL) && scan.canThisBeReplaced) { + overrideCount += 1 + } + case m => m.childPlans.foreach(count) + } + } + + count(planMeta) + overrideCount == 2 + } + + if (!mergePlanOverrideSucceed) { + logWarning("Unable to override file scan for low shuffle merge for merge plan, fallback to " + + "tradition merge.") + return true + } + + val deletionVectorSize = touchedFiles.values.map(_._1.serializedSizeInBytes()).sum + val maxDelVectorSize = context.rapidsConf + .get(DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD) + if (deletionVectorSize > maxDelVectorSize) { + logWarning( + s"""Low shuffle merge can't be executed because broadcast deletion vector count + |$deletionVectorSize is large than max value $maxDelVectorSize """.stripMargin) + return true + } + + false + } + + private def verifyGpuPlan(input: DataFrame)(checkPlanMeta: SparkPlanMeta[SparkPlan] => Boolean) + : Boolean = { + val overridePlan = GpuOverrides.wrapAndTagPlan(input.queryExecution.sparkPlan, + context.rapidsConf) + checkPlanMeta(overridePlan) + } + + override def execute(): Seq[FileAction] = { + val newFiles = context.cmd.withStatusCode("DELTA", + s"Rewriting ${touchedFiles.size} files and saving modified data") { + val df = planForMergeExecution(touchedFiles) + context.deltaTxn.writeFiles(df) + } + + // Update metrics + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + context.cmd.metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + context.cmd.metrics("numTargetChangeFilesAdded") += newFiles.count(_.isInstanceOf[AddCDCFile]) + context.cmd.metrics("numTargetChangeFileBytes") += newFiles.collect { + case f: AddCDCFile => f.size + } + .sum + context.cmd.metrics("numTargetBytesAdded") += addedBytes + context.cmd.metrics("numTargetPartitionsAddedTo") += addedPartitions + + if (multipleMatchDeleteOnlyOvercount.isDefined) { + // Compensate for counting duplicates during the query. + val actualRowsDeleted = + context.cmd.metrics("numTargetRowsDeleted").value - multipleMatchDeleteOnlyOvercount.get + assert(actualRowsDeleted >= 0) + context.cmd.metrics("numTargetRowsDeleted").set(actualRowsDeleted) + } + + touchedFiles.values.map(_._2).map(_.remove).toSeq ++ newFiles + } + + private lazy val dataSkippedFiles: Seq[AddFile] = { + // Skip data based on the merge condition + val targetOnlyPredicates = splitConjunctivePredicates(context.cmd.condition) + .filter(_.references.subsetOf(context.cmd.target.outputSet)) + context.deltaTxn.filterFiles(targetOnlyPredicates) + } + + private lazy val dataSkippedTargetDF: DataFrame = { + addRowIndexMetaColumn(buildTargetDFWithFiles(dataSkippedFiles)) + } + + private lazy val touchedFiles: Map[String, (Roaring64Bitmap, AddFile)] = this.findTouchedFiles() + + private def planForFindingTouchedFiles(): DataFrame = { + + // Apply inner join to between source and target using the merge condition to find matches + // In addition, we attach two columns + // - METADATA_ROW_IDX column to identify target row in file + // - FILE_PATH_COL the target file name the row is from to later identify the files touched + // by matched rows + val targetDF = dataSkippedTargetDF.withColumn(FILE_PATH_COL, input_file_name()) + + sourceDF.join(targetDF, new Column(context.cmd.condition), "inner") + } + + private def planForMergeExecution(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]) + : DataFrame = { + getModifiedDF(touchedFiles).unionAll(getUnmodifiedDF(touchedFiles)) + } + + /** + * Find the target table files that contain the rows that satisfy the merge condition. This is + * implemented as an inner-join between the source query/table and the target table using + * the merge condition. + */ + private def findTouchedFiles(): Map[String, (Roaring64Bitmap, AddFile)] = + context.cmd.recordMergeOperation(sqlMetricName = "scanTimeMs") { + context.spark.udf.register("row_index_set", udaf(RoaringBitmapUDAF)) + // Process the matches from the inner join to record touched files and find multiple matches + val collectTouchedFiles = planForFindingTouchedFiles() + .select(col(FILE_PATH_COL), col(METADATA_ROW_IDX_COL)) + .groupBy(FILE_PATH_COL) + .agg( + expr(s"row_index_set($METADATA_ROW_IDX_COL) as row_idxes"), + count("*").as("count")) + .collect().map(row => { + val filename = row.getAs[String](FILE_PATH_COL) + val rowIdxSet = row.getAs[RoaringBitmapWrapper]("row_idxes").inner + val count = row.getAs[Long]("count") + (filename, (rowIdxSet, count)) + }) + .toMap + + val duplicateCount = { + val distinctMatchedRowCounts = collectTouchedFiles.values + .map(_._1.getLongCardinality).sum + val allMatchedRowCounts = collectTouchedFiles.values.map(_._2).sum + allMatchedRowCounts - distinctMatchedRowCounts + } + + val hasMultipleMatches = duplicateCount > 0 + + // Throw error if multiple matches are ambiguous or cannot be computed correctly. + val canBeComputedUnambiguously = { + // Multiple matches are not ambiguous when there is only one unconditional delete as + // all the matched row pairs in the 2nd join in `writeAllChanges` will get deleted. + val isUnconditionalDelete = context.cmd.matchedClauses.headOption match { + case Some(DeltaMergeIntoMatchedDeleteClause(None)) => true + case _ => false + } + context.cmd.matchedClauses.size == 1 && isUnconditionalDelete + } + + if (hasMultipleMatches && !canBeComputedUnambiguously) { + throw DeltaErrors.multipleSourceRowMatchingTargetRowInMergeException(context.spark) + } + + if (hasMultipleMatches) { + // This is only allowed for delete-only queries. + // This query will count the duplicates for numTargetRowsDeleted in Job 2, + // because we count matches after the join and not just the target rows. + // We have to compensate for this by subtracting the duplicates later, + // so we need to record them here. + multipleMatchDeleteOnlyOvercount = Some(duplicateCount) + } + + // Get the AddFiles using the touched file names. + val touchedFileNames = collectTouchedFiles.keys.toSeq + + val nameToAddFileMap = context.cmd.generateCandidateFileMap( + context.cmd.targetDeltaLog.dataPath, + dataSkippedFiles) + + val touchedAddFiles = touchedFileNames.map(f => + context.cmd.getTouchedFile(context.cmd.targetDeltaLog.dataPath, f, nameToAddFileMap)) + .map(f => (DeltaFileOperations + .absolutePath(context.cmd.targetDeltaLog.dataPath.toString, f.path) + .toString, f)).toMap + + // When the target table is empty, and the optimizer optimized away the join entirely + // numSourceRows will be incorrectly 0. + // We need to scan the source table once to get the correct + // metric here. + if (context.cmd.metrics("numSourceRows").value == 0 && + (dataSkippedFiles.isEmpty || dataSkippedTargetDF.take(1).isEmpty)) { + val numSourceRows = sourceDF.count() + context.cmd.metrics("numSourceRows").set(numSourceRows) + } + + // Update metrics + context.cmd.metrics("numTargetFilesBeforeSkipping") += context.deltaTxn.snapshot.numOfFiles + context.cmd.metrics("numTargetBytesBeforeSkipping") += context.deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + context.cmd.metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + context.cmd.metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + context.cmd.metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + val (removedBytes, removedPartitions) = + totalBytesAndDistinctPartitionValues(touchedAddFiles.values.toSeq) + context.cmd.metrics("numTargetFilesRemoved") += touchedAddFiles.size + context.cmd.metrics("numTargetBytesRemoved") += removedBytes + context.cmd.metrics("numTargetPartitionsRemovedFrom") += removedPartitions + + collectTouchedFiles.map(kv => (kv._1, (kv._2._1, touchedAddFiles(kv._1)))) + } + + + /** + * Modify original data frame to insert + * [[GpuDeltaParquetFileFormatUtils.METADATA_ROW_IDX_COL]]. + */ + private def addRowIndexMetaColumn(baseDF: DataFrame): DataFrame = { + val rowIdxAttr = AttributeReference( + METADATA_ROW_IDX_COL, + METADATA_ROW_IDX_FIELD.dataType, + METADATA_ROW_IDX_FIELD.nullable)() + + val newPlan = baseDF.queryExecution.analyzed.transformUp { + case r@LogicalRelation(fs: HadoopFsRelation, _, _, _) => + val newSchema = StructType(fs.dataSchema.fields).add(METADATA_ROW_IDX_FIELD) + + // This is required to ensure that row index is correctly calculated. + val newFileFormat = fs.fileFormat.asInstanceOf[DeltaParquetFileFormat] + .copy(isSplittable = false, disablePushDowns = true) + + val newFs = fs.copy(dataSchema = newSchema, fileFormat = newFileFormat)(context.spark) + + val newOutput = r.output :+ rowIdxAttr + r.copy(relation = newFs, output = newOutput) + case p@Project(projectList, _) => + val newProjectList = projectList :+ rowIdxAttr + p.copy(projectList = newProjectList) + } + + Dataset.ofRows(context.spark, newPlan) + } + + /** + * The result is scanning target table with touched files, and added an extra + * [[METADATA_ROW_DEL_COL]] to indicate whether filtered by joining with source table in first + * step. + */ + private def getTouchedTargetDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]) + : DataFrame = { + // Generate a new target dataframe that has same output attributes exprIds as the target plan. + // This allows us to apply the existing resolved update/insert expressions. + val baseTargetDF = buildTargetDFWithFiles(touchedFiles.values.map(_._2).toSeq) + + val newPlan = { + val rowDelAttr = AttributeReference( + METADATA_ROW_DEL_COL, + METADATA_ROW_DEL_FIELD.dataType, + METADATA_ROW_DEL_FIELD.nullable)() + + baseTargetDF.queryExecution.analyzed.transformUp { + case r@LogicalRelation(fs: HadoopFsRelation, _, _, _) => + val newSchema = StructType(fs.dataSchema.fields).add(METADATA_ROW_DEL_FIELD) + + // This is required to ensure that row index is correctly calculated. + val newFileFormat = { + val oldFormat = fs.fileFormat.asInstanceOf[DeltaParquetFileFormat] + val dvs = touchedFiles.map(kv => (new URI(kv._1), + DeletionVectorDescriptorWithFilterType(toDeletionVector(kv._2._1), + RowIndexFilterType.UNKNOWN))) + val broadcastDVs = context.spark.sparkContext.broadcast(dvs) + + oldFormat.copy(isSplittable = false, + broadcastDvMap = Some(broadcastDVs), + disablePushDowns = true) + } + + val newFs = fs.copy(dataSchema = newSchema, fileFormat = newFileFormat)(context.spark) + + val newOutput = r.output :+ rowDelAttr + r.copy(relation = newFs, output = newOutput) + case p@Project(projectList, _) => + val newProjectList = projectList :+ rowDelAttr + p.copy(projectList = newProjectList) + } + } + + val df = Dataset.ofRows(context.spark, newPlan) + .withColumn(TARGET_ROW_PRESENT_COL, lit(true)) + + df + } + + /** + * Generate a plan by calculating modified rows. It's computed by joining source and target + * tables, where target table has been filtered by (`__metadata_file_name`, + * `__metadata_row_idx`) pairs collected in first step. + * + * Schema of `modifiedDF`: + * + * targetSchema + ROW_DROPPED_COL + TARGET_ROW_PRESENT_COL + + * SOURCE_ROW_PRESENT_COL + INCR_METRICS_COL + * INCR_METRICS_COL + * + * It consists of several parts: + * + * 1. Unmatched source rows which are inserted + * 2. Unmatched source rows which are deleted + * 3. Target rows which are updated + * 4. Target rows which are deleted + */ + private def getModifiedDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]): DataFrame = { + val sourceDF = this.sourceDF + .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr)) + + val targetDF = getTouchedTargetDF(touchedFiles) + + val joinedDF = { + val joinType = if (hasNoInserts && + context.spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED)) { + "inner" + } else { + "leftOuter" + } + val matchedTargetDF = targetDF.filter(METADATA_ROW_DEL_COL) + .drop(METADATA_ROW_DEL_COL) + + sourceDF.join(matchedTargetDF, new Column(context.cmd.condition), joinType) + } + + val modifiedRowsSchema = context.deltaTxn.metadata.schema + .add(ROW_DROPPED_FIELD) + .add(TARGET_ROW_PRESENT_FIELD.copy(nullable = true)) + .add(SOURCE_ROW_PRESENT_FIELD.copy(nullable = true)) + .add(INCR_METRICS_FIELD) + + // Here we generate a case when statement to handle all cases: + // CASE + // WHEN + // CASE WHEN + // + // WHEN + // + // ELSE + // + // WHEN + // CASE WHEN + // + // WHEN + // + // ELSE + // + // END + + val notMatchedConditions = context.cmd.notMatchedClauses.map(clauseCondition) + val notMatchedExpr = { + val deletedNotMatchedRow = { + targetOutputCols :+ + Literal.TrueLiteral :+ + Literal.FalseLiteral :+ + Literal(null) :+ + Literal.TrueLiteral + } + if (context.cmd.notMatchedClauses.isEmpty) { + // If there no `WHEN NOT MATCHED` clause, we should just delete not matched row + deletedNotMatchedRow + } else { + val notMatchedOutputs = context.cmd.notMatchedClauses.map(clauseOutput) + modifiedRowsSchema.zipWithIndex.map { + case (_, idx) => + CaseWhen(notMatchedConditions.zip(notMatchedOutputs.map(_(idx))), + deletedNotMatchedRow(idx)) + } + } + } + + val matchedConditions = context.cmd.matchedClauses.map(clauseCondition) + val matchedOutputs = context.cmd.matchedClauses.map(clauseOutput) + val matchedExprs = { + val notMatchedRow = { + targetOutputCols :+ + Literal.FalseLiteral :+ + Literal.TrueLiteral :+ + Literal(null) :+ + Literal.TrueLiteral + } + if (context.cmd.matchedClauses.isEmpty) { + // If there is not matched clause, this is insert only, we should delete this row. + notMatchedRow + } else { + modifiedRowsSchema.zipWithIndex.map { + case (_, idx) => + CaseWhen(matchedConditions.zip(matchedOutputs.map(_(idx))), + notMatchedRow(idx)) + } + } + } + + val sourceRowHasNoMatch = col(TARGET_ROW_PRESENT_COL).isNull.expr + + val modifiedCols = modifiedRowsSchema.zipWithIndex.map { case (col, idx) => + val caseWhen = CaseWhen( + Seq(sourceRowHasNoMatch -> notMatchedExpr(idx)), + matchedExprs(idx)) + new Column(Alias(caseWhen, col.name)()) + } + + val modifiedDF = { + + // Make this a udf to avoid catalyst to be too aggressive to even remove the join! + val noopRowDroppedCol = udf(new GpuDeltaNoopUDF()).apply(!col(ROW_DROPPED_COL)) + + val modifiedDF = joinedDF.select(modifiedCols: _*) + // This will not filter anything since they always return true, but we need to avoid + // catalyst from optimizing these udf + .filter(noopRowDroppedCol && col(INCR_METRICS_COL)) + .drop(ROW_DROPPED_COL, INCR_METRICS_COL, TARGET_ROW_PRESENT_COL, SOURCE_ROW_PRESENT_COL) + + repartitionIfNeeded(modifiedDF) + } + + modifiedDF + } + + private def getUnmodifiedDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]): DataFrame = { + getTouchedTargetDF(touchedFiles) + .filter(!col(METADATA_ROW_DEL_COL)) + .drop(TARGET_ROW_PRESENT_COL, METADATA_ROW_DEL_COL) + } +} + + +object MergeExecutor { + + /** + * Spark UI will track all normal accumulators along with Spark tasks to show them on Web UI. + * However, the accumulator used by `MergeIntoCommand` can store a very large value since it + * tracks all files that need to be rewritten. We should ask Spark UI to not remember it, + * otherwise, the UI data may consume lots of memory. Hence, we use the prefix `internal.metrics.` + * to make this accumulator become an internal accumulator, so that it will not be tracked by + * Spark UI. + */ + val TOUCHED_FILES_ACCUM_NAME = "internal.metrics.MergeIntoDelta.touchedFiles" + + val ROW_ID_COL = "_row_id_" + val FILE_PATH_COL: String = GpuDeltaParquetFileFormatUtils.FILE_PATH_COL + val SOURCE_ROW_PRESENT_COL: String = "_source_row_present_" + val SOURCE_ROW_PRESENT_FIELD: StructField = StructField(SOURCE_ROW_PRESENT_COL, BooleanType, + nullable = false) + val TARGET_ROW_PRESENT_COL: String = "_target_row_present_" + val TARGET_ROW_PRESENT_FIELD: StructField = StructField(TARGET_ROW_PRESENT_COL, BooleanType, + nullable = false) + val ROW_DROPPED_COL: String = GpuDeltaMergeConstants.ROW_DROPPED_COL + val ROW_DROPPED_FIELD: StructField = StructField(ROW_DROPPED_COL, BooleanType, nullable = false) + val INCR_METRICS_COL: String = "_incr_metrics_" + val INCR_METRICS_FIELD: StructField = StructField(INCR_METRICS_COL, BooleanType, nullable = false) + val INCR_ROW_COUNT_COL: String = "_incr_row_count_" + + // Some Delta versions use Literal(null) which translates to a literal of NullType instead + // of the Literal(null, StringType) which is needed, so using a fixed version here + // rather than the version from Delta Lake. + val CDC_TYPE_NOT_CDC_LITERAL: Literal = Literal(null, StringType) + + def toDeletionVector(bitmap: Roaring64Bitmap): DeletionVectorDescriptor = { + DeletionVectorDescriptor.inlineInLog(RoaringBitmapWrapper(bitmap).serializeToBytes(), + bitmap.getLongCardinality) + } + + /** Count the number of distinct partition values among the AddFiles in the given set. */ + def totalBytesAndDistinctPartitionValues(files: Seq[FileAction]): (Long, Int) = { + val distinctValues = new mutable.HashSet[Map[String, String]]() + var bytes = 0L + val iter = files.collect { case a: AddFile => a }.iterator + while (iter.hasNext) { + val file = iter.next() + distinctValues += file.partitionValues + bytes += file.size + } + // If the only distinct value map is an empty map, then it must be an unpartitioned table. + // Return 0 in that case. + val numDistinctValues = + if (distinctValues.size == 1 && distinctValues.head.isEmpty) 0 else distinctValues.size + (bytes, numDistinctValues) + } +} \ No newline at end of file diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala index 969d005b573..604ed826397 100644 --- a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,20 +16,32 @@ package com.nvidia.spark.rapids.delta +import java.net.URI + import com.databricks.sql.transaction.tahoe.{DeltaColumnMappingMode, DeltaParquetFileFormat, IdMapping} -import com.databricks.sql.transaction.tahoe.DeltaParquetFileFormat.IS_ROW_DELETED_COLUMN_NAME -import com.nvidia.spark.rapids.SparkPlanMeta +import com.databricks.sql.transaction.tahoe.DeltaParquetFileFormat.{DeletionVectorDescriptorWithFilterType, IS_ROW_DELETED_COLUMN_NAME} +import com.nvidia.spark.rapids.{GpuMetric, RapidsConf, SparkPlanMeta} +import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormatUtils.addMetadataColumnToIterator +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch case class GpuDeltaParquetFileFormat( override val columnMappingMode: DeltaColumnMappingMode, override val referenceSchema: StructType, - isSplittable: Boolean) extends GpuDeltaParquetFileFormatBase { + isSplittable: Boolean, + disablePushDown: Boolean, + broadcastDvMap: Option[Broadcast[Map[URI, DeletionVectorDescriptorWithFilterType]]] +) extends GpuDeltaParquetFileFormatBase { if (columnMappingMode == IdMapping) { val requiredReadConf = SQLConf.PARQUET_FIELD_ID_READ_ENABLED @@ -44,6 +56,46 @@ case class GpuDeltaParquetFileFormat( sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = isSplittable + + override def buildReaderWithPartitionValuesAndMetrics( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration, + metrics: Map[String, GpuMetric], + alluxioPathReplacementMap: Option[Map[String, String]]) + : PartitionedFile => Iterator[InternalRow] = { + + val dataReader = super.buildReaderWithPartitionValuesAndMetrics( + sparkSession, + dataSchema, + partitionSchema, + requiredSchema, + filters, + options, + hadoopConf, + metrics, + alluxioPathReplacementMap) + + val delVecs = broadcastDvMap + val maxDelVecScatterBatchSize = RapidsConf + .DELTA_LOW_SHUFFLE_MERGE_SCATTER_DEL_VECTOR_BATCH_SIZE + .get(sparkSession.sessionState.conf) + + (file: PartitionedFile) => { + val input = dataReader(file) + val dv = delVecs.flatMap(_.value.get(new URI(file.filePath.toString()))) + .map(dv => RoaringBitmapWrapper.deserializeFromBytes(dv.descriptor.inlineData).inner) + addMetadataColumnToIterator(prepareSchema(requiredSchema), + dv, + input.asInstanceOf[Iterator[ColumnarBatch]], + maxDelVecScatterBatchSize + ).asInstanceOf[Iterator[InternalRow]] + } + } } object GpuDeltaParquetFileFormat { @@ -60,6 +112,7 @@ object GpuDeltaParquetFileFormat { } def convertToGpu(fmt: DeltaParquetFileFormat): GpuDeltaParquetFileFormat = { - GpuDeltaParquetFileFormat(fmt.columnMappingMode, fmt.referenceSchema, fmt.isSplittable) + GpuDeltaParquetFileFormat(fmt.columnMappingMode, fmt.referenceSchema, fmt.isSplittable, + fmt.disablePushDowns, fmt.broadcastDvMap) } } diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala index 8e13a9e4b5a..5a2b4e7b52e 100644 --- a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,13 +17,14 @@ package com.nvidia.spark.rapids.delta.shims import com.databricks.sql.transaction.tahoe.commands.{MergeIntoCommand, MergeIntoCommandEdge} -import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaLog, GpuMergeIntoCommand} -import com.nvidia.spark.rapids.RapidsConf +import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaLog, GpuLowShuffleMergeCommand, GpuMergeIntoCommand} +import com.nvidia.spark.rapids.{RapidsConf, RapidsReaderType} import com.nvidia.spark.rapids.delta.{MergeIntoCommandEdgeMeta, MergeIntoCommandMeta} +import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.command.RunnableCommand -object MergeIntoCommandMetaShim { +object MergeIntoCommandMetaShim extends Logging { def tagForGpu(meta: MergeIntoCommandMeta, mergeCmd: MergeIntoCommand): Unit = { // see https://github.com/NVIDIA/spark-rapids/issues/8415 for more information if (mergeCmd.notMatchedBySourceClauses.nonEmpty) { @@ -39,26 +40,82 @@ object MergeIntoCommandMetaShim { } def convertToGpu(mergeCmd: MergeIntoCommand, conf: RapidsConf): RunnableCommand = { - GpuMergeIntoCommand( - mergeCmd.source, - mergeCmd.target, - new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), - mergeCmd.condition, - mergeCmd.matchedClauses, - mergeCmd.notMatchedClauses, - mergeCmd.notMatchedBySourceClauses, - mergeCmd.migratedSchema)(conf) + // TODO: Currently we only support low shuffler merge only when parquet per file read is enabled + // due to the limitation of implementing row index metadata column. + if (conf.isDeltaLowShuffleMergeEnabled) { + if (conf.isParquetPerFileReadEnabled) { + GpuLowShuffleMergeCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } else { + logWarning(s"""Low shuffle merge disabled since ${RapidsConf.PARQUET_READER_TYPE} is + not set to ${RapidsReaderType.PERFILE}. Falling back to classic merge.""") + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } + } else { + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } } def convertToGpu(mergeCmd: MergeIntoCommandEdge, conf: RapidsConf): RunnableCommand = { - GpuMergeIntoCommand( - mergeCmd.source, - mergeCmd.target, - new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), - mergeCmd.condition, - mergeCmd.matchedClauses, - mergeCmd.notMatchedClauses, - mergeCmd.notMatchedBySourceClauses, - mergeCmd.migratedSchema)(conf) + // TODO: Currently we only support low shuffler merge only when parquet per file read is enabled + // due to the limitation of implementing row index metadata column. + if (conf.isDeltaLowShuffleMergeEnabled) { + if (conf.isParquetPerFileReadEnabled) { + GpuLowShuffleMergeCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } else { + logWarning(s"""Low shuffle merge is still disable since ${RapidsConf.PARQUET_READER_TYPE} is + not set to ${RapidsReaderType.PERFILE}. Falling back to classic merge.""") + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } + } else { + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } } } diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index 3231b7b3069..941ab4046e6 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -73,6 +73,12 @@ Name | Description | Default Value | Applicable at spark.rapids.sql.csv.read.double.enabled|CSV reading is not 100% compatible when reading doubles.|true|Runtime spark.rapids.sql.csv.read.float.enabled|CSV reading is not 100% compatible when reading floats.|true|Runtime spark.rapids.sql.decimalOverflowGuarantees|FOR TESTING ONLY. DO NOT USE IN PRODUCTION. Please see the decimal section of the compatibility documents for more information on this config.|true|Runtime +spark.rapids.sql.delta.lowShuffleMerge.deletionVector.broadcast.threshold|Currently we need to broadcast deletion vector to all executors to perform low shuffle merge. When we detect the deletion vector broadcast size is larger than this value, we will fallback to normal shuffle merge.|20971520|Runtime +spark.rapids.sql.delta.lowShuffleMerge.enabled|Option to turn on the low shuffle merge for Delta Lake. Currently there are some limitations for this feature: +1. We only support Databricks Runtime 13.3 and Deltalake 2.4. +2. The file scan mode must be set to PERFILE +3. The deletion vector size must be smaller than spark.rapids.sql.delta.lowShuffleMerge.deletionVector.broadcast.threshold +|false|Runtime spark.rapids.sql.detectDeltaCheckpointQueries|Queries against Delta Lake _delta_log checkpoint Parquet files are not efficient on the GPU. When this option is enabled, the plugin will attempt to detect these queries and fall back to the CPU.|true|Runtime spark.rapids.sql.detectDeltaLogQueries|Queries against Delta Lake _delta_log JSON files are not efficient on the GPU. When this option is enabled, the plugin will attempt to detect these queries and fall back to the CPU.|true|Runtime spark.rapids.sql.fast.sample|Option to turn on fast sample. If enable it is inconsistent with CPU sample because of GPU sample algorithm is inconsistent with CPU.|false|Runtime diff --git a/integration_tests/src/main/python/delta_lake_low_shuffle_merge_test.py b/integration_tests/src/main/python/delta_lake_low_shuffle_merge_test.py new file mode 100644 index 00000000000..6935ee13751 --- /dev/null +++ b/integration_tests/src/main/python/delta_lake_low_shuffle_merge_test.py @@ -0,0 +1,165 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyspark.sql.functions as f +import pytest + +from delta_lake_merge_common import * +from marks import * +from pyspark.sql.types import * +from spark_session import is_databricks133_or_later, spark_version + +delta_merge_enabled_conf = copy_and_update(delta_writes_enabled_conf, + {"spark.rapids.sql.command.MergeIntoCommand": "true", + "spark.rapids.sql.command.MergeIntoCommandEdge": "true", + "spark.rapids.sql.delta.lowShuffleMerge.enabled": "true", + "spark.rapids.sql.format.parquet.reader.type": "PERFILE"}) + +@allow_non_gpu("ColumnarToRowExec", *delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_low_shuffle_merge_when_gpu_file_scan_override_failed(spark_tmp_path, + spark_tmp_table_factory, + use_cdf, num_slices): + # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous + src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) + dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *" + + conf = copy_and_update(delta_merge_enabled_conf, + { + "spark.rapids.sql.exec.FileSourceScanExec": "false", + # Disable auto broadcast join due to this issue: + # https://github.com/NVIDIA/spark-rapids/issues/10973 + "spark.sql.autoBroadcastJoinThreshold": "-1" + }) + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, False, conf=conf) + + + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("table_ranges", [(range(20), range(10)), # partial insert of source + (range(5), range(5)), # no-op insert + (range(10), range(20, 30)) # full insert of source + ], ids=idfn) +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("partition_columns", [None, ["a"], ["b"], ["a", "b"]], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices): + do_test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory, + table_ranges, use_cdf, partition_columns, + num_slices, False, delta_merge_enabled_conf) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("table_ranges", [(range(10), range(20)), # partial delete of target + (range(5), range(5)), # full delete of target + (range(10), range(20, 30)) # no-op delete + ], ids=idfn) +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("partition_columns", [None, ["a"], ["b"], ["a", "b"]], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices): + do_test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices, False, + delta_merge_enabled_conf) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf, num_slices): + do_test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf, + num_slices, False, delta_merge_enabled_conf) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("merge_sql", [ + "MERGE INTO {dest_table} d USING {src_table} s ON d.a == s.a" \ + " WHEN MATCHED AND s.b > 'q' THEN UPDATE SET d.a = s.a / 2, d.b = s.b" \ + " WHEN NOT MATCHED THEN INSERT *", + "MERGE INTO {dest_table} d USING {src_table} s ON d.a == s.a" \ + " WHEN NOT MATCHED AND s.b > 'q' THEN INSERT *", + "MERGE INTO {dest_table} d USING {src_table} s ON d.a == s.a" \ + " WHEN MATCHED AND s.b > 'a' AND s.b < 'g' THEN UPDATE SET d.a = s.a / 2, d.b = s.b" \ + " WHEN MATCHED AND s.b > 'g' AND s.b < 'z' THEN UPDATE SET d.a = s.a / 4, d.b = concat('extra', s.b)" \ + " WHEN NOT MATCHED AND s.b > 'b' AND s.b < 'f' THEN INSERT *" \ + " WHEN NOT MATCHED AND s.b > 'f' AND s.b < 'z' THEN INSERT (b) VALUES ('not here')" ], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, merge_sql, num_slices): + do_test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, + merge_sql, num_slices, False, + delta_merge_enabled_conf) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, num_slices): + do_test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, + spark_tmp_table_factory, + use_cdf, + num_slices, + False, + delta_merge_enabled_conf) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +def test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf): + do_test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf, + delta_merge_enabled_conf) + diff --git a/integration_tests/src/main/python/delta_lake_merge_common.py b/integration_tests/src/main/python/delta_lake_merge_common.py new file mode 100644 index 00000000000..e6e9676625d --- /dev/null +++ b/integration_tests/src/main/python/delta_lake_merge_common.py @@ -0,0 +1,155 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyspark.sql.functions as f +import string +from pyspark.sql.types import * + +from asserts import * +from data_gen import * +from delta_lake_utils import * +from spark_session import is_databricks_runtime + +# Databricks changes the number of files being written, so we cannot compare logs +num_slices_to_test = [10] if is_databricks_runtime() else [1, 10] + + +def make_df(spark, gen, num_slices): + return three_col_df(spark, gen, SetValuesGen(StringType(), string.ascii_lowercase), + SetValuesGen(StringType(), string.ascii_uppercase), num_slices=num_slices) + + +def delta_sql_merge_test(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, check_func, + partition_columns=None): + data_path = spark_tmp_path + "/DELTA_DATA" + src_table = spark_tmp_table_factory.get() + + def setup_tables(spark): + setup_delta_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns) + src_table_func(spark).createOrReplaceTempView(src_table) + + def do_merge(spark, path): + dest_table = spark_tmp_table_factory.get() + read_delta_path(spark, path).createOrReplaceTempView(dest_table) + return spark.sql(merge_sql.format(src_table=src_table, dest_table=dest_table)).collect() + with_cpu_session(setup_tables) + check_func(data_path, do_merge) + + +def assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, + compare_logs, partition_columns=None, conf=None): + assert conf is not None, "conf must be set" + + def read_data(spark, path): + read_func = read_delta_path_with_cdf if use_cdf else read_delta_path + df = read_func(spark, path) + return df.sort(df.columns) + + def checker(data_path, do_merge): + cpu_path = data_path + "/CPU" + gpu_path = data_path + "/GPU" + # compare resulting dataframe from the merge operation (some older Spark versions return empty here) + cpu_result = with_cpu_session(lambda spark: do_merge(spark, cpu_path), conf=conf) + gpu_result = with_gpu_session(lambda spark: do_merge(spark, gpu_path), conf=conf) + assert_equal(cpu_result, gpu_result) + # compare merged table data results, read both via CPU to make sure GPU write can be read by CPU + cpu_result = with_cpu_session(lambda spark: read_data(spark, cpu_path).collect(), conf=conf) + gpu_result = with_cpu_session(lambda spark: read_data(spark, gpu_path).collect(), conf=conf) + assert_equal(cpu_result, gpu_result) + # Using partition columns involves sorting, and there's no guarantees on the task + # partitioning due to random sampling. + if compare_logs and not partition_columns: + with_cpu_session(lambda spark: assert_gpu_and_cpu_delta_logs_equivalent(spark, data_path)) + delta_sql_merge_test(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, checker, partition_columns) + + +def do_test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices, compare_logs, + conf): + src_range, dest_range = table_ranges + src_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), src_range), num_slices) + dest_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), dest_range), num_slices) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ + " WHEN NOT MATCHED THEN INSERT *" + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, compare_logs, + partition_columns, conf=conf) + + +def do_test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices, compare_logs, + conf): + src_range, dest_range = table_ranges + src_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), src_range), num_slices) + dest_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), dest_range), num_slices) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ + " WHEN MATCHED THEN DELETE" + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, compare_logs, + partition_columns, conf=conf) + + +def do_test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf, + num_slices, compare_logs, conf): + # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous + src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) + dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *" + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, compare_logs, + conf=conf) + + +def do_test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, + merge_sql, num_slices, compare_logs, conf): + # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous + src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) + dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices) + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, compare_logs, + conf=conf) + + +def do_test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, + spark_tmp_table_factory, use_cdf, + num_slices, compare_logs, conf): + # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous + src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) + dest_table_func = lambda spark: two_col_df(spark, SetValuesGen(IntegerType(), range(100)), string_gen, seed=1, num_slices=num_slices) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ + " WHEN MATCHED AND {dest_table}.a > 100 THEN UPDATE SET *" + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, compare_logs, + conf=conf) + + +def do_test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf, + conf): + # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous + src_table_func = lambda spark: spark.range(10).withColumn("x", f.col("id") + 1) \ + .select(f.col("id"), (f.col("x") + 1).alias("x")) \ + .drop_duplicates(["id"]) \ + .limit(10) + dest_table_func = lambda spark: spark.range(5).withColumn("x", f.col("id") + 1) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.id == {src_table}.id" \ + " WHEN MATCHED THEN UPDATE SET {dest_table}.x = {src_table}.x + 2" \ + " WHEN NOT MATCHED AND {src_table}.x < 7 THEN INSERT *" + + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, + compare_logs=False, conf=conf) diff --git a/integration_tests/src/main/python/delta_lake_merge_test.py b/integration_tests/src/main/python/delta_lake_merge_test.py index 0880db16434..5c3bb915ddb 100644 --- a/integration_tests/src/main/python/delta_lake_merge_test.py +++ b/integration_tests/src/main/python/delta_lake_merge_test.py @@ -14,66 +14,17 @@ import pyspark.sql.functions as f import pytest -import string -from asserts import * -from data_gen import * -from delta_lake_utils import * +from delta_lake_merge_common import * from marks import * from pyspark.sql.types import * from spark_session import is_before_spark_320, is_databricks_runtime, spark_version -# Databricks changes the number of files being written, so we cannot compare logs -num_slices_to_test = [10] if is_databricks_runtime() else [1, 10] delta_merge_enabled_conf = copy_and_update(delta_writes_enabled_conf, {"spark.rapids.sql.command.MergeIntoCommand": "true", "spark.rapids.sql.command.MergeIntoCommandEdge": "true"}) -def make_df(spark, gen, num_slices): - return three_col_df(spark, gen, SetValuesGen(StringType(), string.ascii_lowercase), - SetValuesGen(StringType(), string.ascii_uppercase), num_slices=num_slices) - -def delta_sql_merge_test(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, check_func, - partition_columns=None): - data_path = spark_tmp_path + "/DELTA_DATA" - src_table = spark_tmp_table_factory.get() - def setup_tables(spark): - setup_delta_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns) - src_table_func(spark).createOrReplaceTempView(src_table) - def do_merge(spark, path): - dest_table = spark_tmp_table_factory.get() - read_delta_path(spark, path).createOrReplaceTempView(dest_table) - return spark.sql(merge_sql.format(src_table=src_table, dest_table=dest_table)).collect() - with_cpu_session(setup_tables) - check_func(data_path, do_merge) - -def assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, - compare_logs, partition_columns=None, - conf=delta_merge_enabled_conf): - def read_data(spark, path): - read_func = read_delta_path_with_cdf if use_cdf else read_delta_path - df = read_func(spark, path) - return df.sort(df.columns) - def checker(data_path, do_merge): - cpu_path = data_path + "/CPU" - gpu_path = data_path + "/GPU" - # compare resulting dataframe from the merge operation (some older Spark versions return empty here) - cpu_result = with_cpu_session(lambda spark: do_merge(spark, cpu_path), conf=conf) - gpu_result = with_gpu_session(lambda spark: do_merge(spark, gpu_path), conf=conf) - assert_equal(cpu_result, gpu_result) - # compare merged table data results, read both via CPU to make sure GPU write can be read by CPU - cpu_result = with_cpu_session(lambda spark: read_data(spark, cpu_path).collect(), conf=conf) - gpu_result = with_cpu_session(lambda spark: read_data(spark, gpu_path).collect(), conf=conf) - assert_equal(cpu_result, gpu_result) - # Using partition columns involves sorting, and there's no guarantees on the task - # partitioning due to random sampling. - if compare_logs and not partition_columns: - with_cpu_session(lambda spark: assert_gpu_and_cpu_delta_logs_equivalent(spark, data_path)) - delta_sql_merge_test(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, checker, partition_columns) @allow_non_gpu(delta_write_fallback_allow, *delta_meta_allow) @delta_lake @@ -162,16 +113,9 @@ def test_delta_merge_partial_fallback_via_conf(spark_tmp_path, spark_tmp_table_f @pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) def test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, use_cdf, partition_columns, num_slices): - src_range, dest_range = table_ranges - src_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), src_range), num_slices) - dest_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), dest_range), num_slices) - merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ - " WHEN NOT MATCHED THEN INSERT *" - # Non-deterministic input for each task means we can only reliably compare record counts when using only one task - compare_logs = num_slices == 1 - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs, - partition_columns) + do_test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory, + table_ranges, use_cdf, partition_columns, + num_slices, num_slices == 1, delta_merge_enabled_conf) @allow_non_gpu(*delta_meta_allow) @delta_lake @@ -186,16 +130,9 @@ def test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_facto @pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) def test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, use_cdf, partition_columns, num_slices): - src_range, dest_range = table_ranges - src_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), src_range), num_slices) - dest_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), dest_range), num_slices) - merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ - " WHEN MATCHED THEN DELETE" - # Non-deterministic input for each task means we can only reliably compare record counts when using only one task - compare_logs = num_slices == 1 - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs, - partition_columns) + do_test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices, num_slices == 1, + delta_merge_enabled_conf) @allow_non_gpu(*delta_meta_allow) @delta_lake @@ -204,15 +141,9 @@ def test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, @pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) @pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) def test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf, num_slices): - # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous - src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) - dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices) - merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ - " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *" - # Non-deterministic input for each task means we can only reliably compare record counts when using only one task - compare_logs = num_slices == 1 - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs) + do_test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf, + num_slices, num_slices == 1, delta_merge_enabled_conf) + @allow_non_gpu(*delta_meta_allow) @delta_lake @@ -232,13 +163,10 @@ def test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, us " WHEN NOT MATCHED AND s.b > 'f' AND s.b < 'z' THEN INSERT (b) VALUES ('not here')" ], ids=idfn) @pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) def test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, merge_sql, num_slices): - # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous - src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) - dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices) - # Non-deterministic input for each task means we can only reliably compare record counts when using only one task - compare_logs = num_slices == 1 - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs) + do_test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, + merge_sql, num_slices, num_slices == 1, + delta_merge_enabled_conf) + @allow_non_gpu(*delta_meta_allow) @delta_lake @@ -247,15 +175,10 @@ def test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_facto @pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) @pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) def test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, num_slices): - # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous - src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) - dest_table_func = lambda spark: two_col_df(spark, SetValuesGen(IntegerType(), range(100)), string_gen, seed=1, num_slices=num_slices) - merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ - " WHEN MATCHED AND {dest_table}.a > 100 THEN UPDATE SET *" - # Non-deterministic input for each task means we can only reliably compare record counts when using only one task - compare_logs = num_slices == 1 - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs) + do_test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, + spark_tmp_table_factory, use_cdf, + num_slices, num_slices == 1, + delta_merge_enabled_conf) @allow_non_gpu(*delta_meta_allow) @delta_lake @@ -263,18 +186,8 @@ def test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, spa @pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") @pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) def test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf): - # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous - src_table_func = lambda spark: spark.range(10).withColumn("x", f.col("id") + 1)\ - .select(f.col("id"), (f.col("x") + 1).alias("x"))\ - .drop_duplicates(["id"])\ - .limit(10) - dest_table_func = lambda spark: spark.range(5).withColumn("x", f.col("id") + 1) - merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.id == {src_table}.id" \ - " WHEN MATCHED THEN UPDATE SET {dest_table}.x = {src_table}.x + 2" \ - " WHEN NOT MATCHED AND {src_table}.x < 7 THEN INSERT *" - - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs=False) + do_test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf, + delta_merge_enabled_conf) @allow_non_gpu(*delta_meta_allow) @delta_lake diff --git a/pom.xml b/pom.xml index 06947857521..3ff87c3cb97 100644 --- a/pom.xml +++ b/pom.xml @@ -733,6 +733,7 @@ --> -Xlint:all,-serial,-path,-try,-processing|-Werror 1.16.0 + 1.0.6 ${ucx.baseVersion} true @@ -1016,6 +1017,15 @@ ${alluxio.client.version} provided + + + org.roaringbitmap + RoaringBitmap + ${roaringbitmap.version} + compile + org.scalatest scalatest_${scala.binary.version} diff --git a/scala2.13/aggregator/pom.xml b/scala2.13/aggregator/pom.xml index 1d70c76f037..198b62d5fa6 100644 --- a/scala2.13/aggregator/pom.xml +++ b/scala2.13/aggregator/pom.xml @@ -94,6 +94,10 @@ com.google.flatbuffers ${rapids.shade.package}.com.google.flatbuffers + + org.roaringbitmap + ${rapids.shade.package}.org.roaringbitmap + diff --git a/scala2.13/pom.xml b/scala2.13/pom.xml index cbc4aecbd26..e32a64f0529 100644 --- a/scala2.13/pom.xml +++ b/scala2.13/pom.xml @@ -733,6 +733,7 @@ --> -Xlint:all,-serial,-path,-try,-processing|-Werror 1.16.0 + 1.0.6 ${ucx.baseVersion} true @@ -1016,6 +1017,15 @@ ${alluxio.client.version} provided + + + org.roaringbitmap + RoaringBitmap + ${roaringbitmap.version} + compile + org.scalatest scalatest_${scala.binary.version} diff --git a/scala2.13/sql-plugin/pom.xml b/scala2.13/sql-plugin/pom.xml index df3532a3592..eb6f240a3f6 100644 --- a/scala2.13/sql-plugin/pom.xml +++ b/scala2.13/sql-plugin/pom.xml @@ -97,6 +97,10 @@ org.alluxio alluxio-shaded-client + + org.roaringbitmap + RoaringBitmap + org.mockito mockito-core diff --git a/sql-plugin/pom.xml b/sql-plugin/pom.xml index 961e6f08372..08657a9d40b 100644 --- a/sql-plugin/pom.xml +++ b/sql-plugin/pom.xml @@ -97,6 +97,10 @@ org.alluxio alluxio-shaded-client + + org.roaringbitmap + RoaringBitmap + org.mockito mockito-core diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 8ea1641fb4a..5203e926efa 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -2274,6 +2274,32 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. .integerConf .createWithDefault(1024) + val DELTA_LOW_SHUFFLE_MERGE_SCATTER_DEL_VECTOR_BATCH_SIZE = + conf("spark.rapids.sql.delta.lowShuffleMerge.deletion.scatter.max.size") + .doc("Option to set max batch size when scattering deletion vector") + .internal() + .integerConf + .createWithDefault(32 * 1024) + + val DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD = + conf("spark.rapids.sql.delta.lowShuffleMerge.deletionVector.broadcast.threshold") + .doc("Currently we need to broadcast deletion vector to all executors to perform low " + + "shuffle merge. When we detect the deletion vector broadcast size is larger than this " + + "value, we will fallback to normal shuffle merge.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(20 * 1024 * 1024) + + val ENABLE_DELTA_LOW_SHUFFLE_MERGE = + conf("spark.rapids.sql.delta.lowShuffleMerge.enabled") + .doc("Option to turn on the low shuffle merge for Delta Lake. Currently there are some " + + "limitations for this feature: \n" + + "1. We only support Databricks Runtime 13.3 and Deltalake 2.4. \n" + + s"2. The file scan mode must be set to ${RapidsReaderType.PERFILE} \n" + + "3. The deletion vector size must be smaller than " + + s"${DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD.key} \n") + .booleanConf + .createWithDefault(false) + private def printSectionHeader(category: String): Unit = println(s"\n### $category") @@ -3083,6 +3109,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val testGetJsonObjectSaveRows: Int = get(TEST_GET_JSON_OBJECT_SAVE_ROWS) + lazy val isDeltaLowShuffleMergeEnabled: Boolean = get(ENABLE_DELTA_LOW_SHUFFLE_MERGE) + private val optimizerDefaults = Map( // this is not accurate because CPU projections do have a cost due to appending values // to each row that is produced, but this needs to be a really small number because From 4b449034f2a0105c687646176590b349f9901ea7 Mon Sep 17 00:00:00 2001 From: Liangcai Li Date: Mon, 24 Jun 2024 09:32:03 +0800 Subject: [PATCH 3/6] Support bucketing write for GPU (#10957) This PR adds the GPU support for the bucketing write. - React the code of the dynamic partition single writer and concurrent writer to try to reuse the code as much as possible, and then add in the bucketing write logic for both of them. - Update the bucket check during the plan overriding for the write commands, including InsertIntoHadoopFsRelationCommand, CreateDataSourceTableAsSelectCommand, InsertIntoHiveTable, CreateHiveTableAsSelectCommand. - From 330, Spark also supports HiveHash to generate the bucket IDs, in addition to Murmur3Hash. So the shim object GpuBucketingUtils is introduced to handle the shim things. - This change also adds two functions (tagForHiveBucketingWrite and tagForBucketing) to do the overriding check for the two hashing functions separately. And the Hive write nodes will fall back to CPU when HiveHash is chosen, because HiveHash is not supported on GPU. --------- Signed-off-by: Firestarman --- integration_tests/src/main/python/asserts.py | 6 +- .../src/main/python/orc_write_test.py | 48 +- .../src/main/python/parquet_write_test.py | 79 +- .../rapids/GpuHashPartitioningBase.scala | 8 +- .../nvidia/spark/rapids/GpuOverrides.scala | 7 +- .../sql/hive/rapids/GpuHiveFileFormat.scala | 6 +- .../sql/rapids/GpuFileFormatDataWriter.scala | 1112 +++++++---------- ...aSourceTableAsSelectCommandMetaShims.scala | 8 +- ...dCreateHiveTableAsSelectCommandShims.scala | 5 +- .../shims/spark311/GpuBucketingUtils.scala | 77 ++ .../GpuCreateHiveTableAsSelectCommand.scala | 9 +- .../rapids/shims/GpuInsertIntoHiveTable.scala | 5 +- .../sql/rapids/GpuFileFormatWriter.scala | 15 +- .../shims/spark330/GpuBucketingUtils.scala | 88 ++ ...aSourceTableAsSelectCommandMetaShims.scala | 12 +- .../rapids/shims/GpuInsertIntoHiveTable.scala | 7 +- ...dCreateHiveTableAsSelectCommandShims.scala | 6 +- .../sql/rapids/GpuFileFormatWriter.scala | 15 +- .../rapids/GpuFileFormatDataWriterSuite.scala | 132 +- 19 files changed, 896 insertions(+), 749 deletions(-) create mode 100644 sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/spark311/GpuBucketingUtils.scala create mode 100644 sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala diff --git a/integration_tests/src/main/python/asserts.py b/integration_tests/src/main/python/asserts.py index 32416612d26..b861e89b726 100644 --- a/integration_tests/src/main/python/asserts.py +++ b/integration_tests/src/main/python/asserts.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -101,6 +101,10 @@ def _assert_equal(cpu, gpu, float_check, path): else: assert False, "Found unexpected type {} at {}".format(t, path) +def assert_equal_with_local_sort(cpu, gpu): + _sort_locally(cpu, gpu) + assert_equal(cpu, gpu) + def assert_equal(cpu, gpu): """Verify that the result from the CPU and the GPU are equal""" try: diff --git a/integration_tests/src/main/python/orc_write_test.py b/integration_tests/src/main/python/orc_write_test.py index 8d3013cbe8b..5b5c7b786b6 100644 --- a/integration_tests/src/main/python/orc_write_test.py +++ b/integration_tests/src/main/python/orc_write_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -209,7 +209,7 @@ def test_write_sql_save_table(spark_tmp_path, orc_gens, ts_type, orc_impl, spark @pytest.mark.parametrize('codec', ['zlib', 'lzo']) def test_orc_write_compression_fallback(spark_tmp_path, codec, spark_tmp_table_factory): gen = TimestampGen() - data_path = spark_tmp_path + '/PARQUET_DATA' + data_path = spark_tmp_path + '/ORC_DATA' all_confs={'spark.sql.orc.compression.codec': codec, 'spark.rapids.sql.format.orc.write.enabled': True} assert_gpu_fallback_write( lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("orc").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), @@ -218,17 +218,45 @@ def test_orc_write_compression_fallback(spark_tmp_path, codec, spark_tmp_table_f 'DataWritingCommandExec', conf=all_confs) -@ignore_order -@allow_non_gpu('DataWritingCommandExec,ExecutedCommandExec,WriteFilesExec') -def test_buckets_write_fallback(spark_tmp_path, spark_tmp_table_factory): +@ignore_order(local=True) +def test_buckets_write_round_trip(spark_tmp_path, spark_tmp_table_factory): data_path = spark_tmp_path + '/ORC_DATA' + gen_list = [["id", int_gen], ["data", long_gen]] + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list).selectExpr("id % 100 as b_id", "data").write + .bucketBy(4, "b_id").format('orc').mode('overwrite').option("path", path) + .saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.orc(path), + data_path, + conf={'spark.rapids.sql.format.orc.write.enabled': True}) + +@ignore_order(local=True) +@allow_non_gpu('DataWritingCommandExec,ExecutedCommandExec,WriteFilesExec, SortExec') +def test_buckets_write_fallback_unsupported_types(spark_tmp_path, spark_tmp_table_factory): + data_path = spark_tmp_path + '/ORC_DATA' + gen_list = [["id", binary_gen], ["data", long_gen]] assert_gpu_fallback_write( - lambda spark, path: spark.range(10e4).write.bucketBy(4, "id").sortBy("id").format('orc').mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), - lambda spark, path: spark.read.orc(path), - data_path, - 'DataWritingCommandExec', - conf = {'spark.rapids.sql.format.orc.write.enabled': True}) + lambda spark, path: gen_df(spark, gen_list).selectExpr("id as b_id", "data").write + .bucketBy(4, "b_id").format('orc').mode('overwrite').option("path", path) + .saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.orc(path), + data_path, + 'DataWritingCommandExec', + conf={'spark.rapids.sql.format.orc.write.enabled': True}) +@ignore_order(local=True) +def test_partitions_and_buckets_write_round_trip(spark_tmp_path, spark_tmp_table_factory): + data_path = spark_tmp_path + '/ORC_DATA' + gen_list = [["id", int_gen], ["data", long_gen]] + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list) + .selectExpr("id % 5 as b_id", "id % 10 as p_id", "data").write + .partitionBy("p_id") + .bucketBy(4, "b_id").format('orc').mode('overwrite').option("path", path) + .saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.orc(path), + data_path, + conf={'spark.rapids.sql.format.orc.write.enabled': True}) @ignore_order @allow_non_gpu('DataWritingCommandExec,ExecutedCommandExec,WriteFilesExec') diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py index 38dab9e84a4..805a0b8137c 100644 --- a/integration_tests/src/main/python/parquet_write_test.py +++ b/integration_tests/src/main/python/parquet_write_test.py @@ -409,16 +409,81 @@ def test_parquet_writeLegacyFormat_fallback(spark_tmp_path, spark_tmp_table_fact 'DataWritingCommandExec', conf=all_confs) -@ignore_order -@allow_non_gpu('DataWritingCommandExec,ExecutedCommandExec,WriteFilesExec') -def test_buckets_write_fallback(spark_tmp_path, spark_tmp_table_factory): +@ignore_order(local=True) +def test_buckets_write_round_trip(spark_tmp_path, spark_tmp_table_factory): + data_path = spark_tmp_path + '/PARQUET_DATA' + gen_list = [["id", int_gen], ["data", long_gen]] + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list).selectExpr("id % 100 as b_id", "data").write + .bucketBy(4, "b_id").format('parquet').mode('overwrite').option("path", path) + .saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.parquet(path), + data_path, + conf=writer_confs) + + +def test_buckets_write_correctness(spark_tmp_path, spark_tmp_table_factory): + cpu_path = spark_tmp_path + '/PARQUET_DATA/CPU' + gpu_path = spark_tmp_path + '/PARQUET_DATA/GPU' + gen_list = [["id", int_gen], ["data", long_gen]] + num_buckets = 4 + + def do_bucketing_write(spark, path): + df = gen_df(spark, gen_list).selectExpr("id % 100 as b_id", "data") + df.write.bucketBy(num_buckets, "b_id").format('parquet').mode('overwrite') \ + .option("path", path).saveAsTable(spark_tmp_table_factory.get()) + + def read_single_bucket(path, bucket_id): + # Bucket Id string format: f"_$id%05d" + ".c$fileCounter%03d". + # fileCounter is always 0 in this test. For example '_00002.c000' is for + # bucket id being 2. + # We leverage this bucket segment in the file path to filter rows belong + # to a bucket. + bucket_segment = '_' + "{}".format(bucket_id).rjust(5, '0') + '.c000' + return with_cpu_session( + lambda spark: spark.read.parquet(path) + .withColumn('file_name', f.input_file_name()) + .filter(f.col('file_name').contains(bucket_segment)) + .selectExpr('b_id', 'data') # need to drop the file_name column for comparison. + .collect()) + + with_cpu_session(lambda spark: do_bucketing_write(spark, cpu_path), writer_confs) + with_gpu_session(lambda spark: do_bucketing_write(spark, gpu_path), writer_confs) + cur_bucket_id = 0 + while cur_bucket_id < num_buckets: + # Verify the result bucket by bucket + ret_cpu = read_single_bucket(cpu_path, cur_bucket_id) + ret_gpu = read_single_bucket(gpu_path, cur_bucket_id) + assert_equal_with_local_sort(ret_cpu, ret_gpu) + cur_bucket_id += 1 + +@ignore_order(local=True) +@allow_non_gpu('DataWritingCommandExec,ExecutedCommandExec,WriteFilesExec, SortExec') +def test_buckets_write_fallback_unsupported_types(spark_tmp_path, spark_tmp_table_factory): data_path = spark_tmp_path + '/PARQUET_DATA' + gen_list = [["id", binary_gen], ["data", long_gen]] assert_gpu_fallback_write( - lambda spark, path: spark.range(10e4).write.bucketBy(4, "id").sortBy("id").format('parquet').mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), - lambda spark, path: spark.read.parquet(path), - data_path, - 'DataWritingCommandExec') + lambda spark, path: gen_df(spark, gen_list).selectExpr("id as b_id", "data").write + .bucketBy(4, "b_id").format('parquet').mode('overwrite').option("path", path) + .saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.parquet(path), + data_path, + 'DataWritingCommandExec', + conf=writer_confs) +@ignore_order(local=True) +def test_partitions_and_buckets_write_round_trip(spark_tmp_path, spark_tmp_table_factory): + data_path = spark_tmp_path + '/PARQUET_DATA' + gen_list = [["id", int_gen], ["data", long_gen]] + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list) + .selectExpr("id % 5 as b_id", "id % 10 as p_id", "data").write + .partitionBy("p_id") + .bucketBy(4, "b_id").format('parquet').mode('overwrite').option("path", path) + .saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.parquet(path), + data_path, + conf=writer_confs) @ignore_order @allow_non_gpu('DataWritingCommandExec,ExecutedCommandExec,WriteFilesExec') diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioningBase.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioningBase.scala index b17b2782e90..baa009d0669 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioningBase.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioningBase.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.rapids.GpuMurmur3Hash +import org.apache.spark.sql.rapids.{GpuMurmur3Hash, GpuPmod} import org.apache.spark.sql.types.{DataType, IntegerType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -59,6 +59,10 @@ abstract class GpuHashPartitioningBase(expressions: Seq[Expression], numPartitio sliceInternalGpuOrCpuAndClose(numRows, partitionIndexes, partitionColumns) } } + + def partitionIdExpression: GpuExpression = GpuPmod( + GpuMurmur3Hash(expressions, GpuHashPartitioningBase.DEFAULT_HASH_SEED), + GpuLiteral(numPartitions)) } object GpuHashPartitioningBase { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 295480d24cc..9e26cf751f4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -322,10 +322,11 @@ final class InsertIntoHadoopFsRelationCommandMeta( private var fileFormat: Option[ColumnarFileFormat] = None override def tagSelfForGpuInternal(): Unit = { - if (cmd.bucketSpec.isDefined) { - willNotWorkOnGpu("bucketing is not supported") + if (GpuBucketingUtils.isHiveHashBucketing(cmd.options)) { + GpuBucketingUtils.tagForHiveBucketingWrite(this, cmd.bucketSpec, cmd.outputColumns, false) + } else { + BucketIdMetaUtils.tagForBucketingWrite(this, cmd.bucketSpec, cmd.outputColumns) } - val spark = SparkSession.active val formatCls = cmd.fileFormat.getClass fileFormat = if (formatCls == classOf[CSVFileFormat]) { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala index 21437a64481..69189b2600c 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala @@ -24,6 +24,7 @@ import com.google.common.base.Charsets import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.jni.CastStrings +import com.nvidia.spark.rapids.shims.GpuBucketingUtils import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.internal.Logging @@ -43,9 +44,8 @@ object GpuHiveFileFormat extends Logging { def tagGpuSupport(meta: GpuInsertIntoHiveTableMeta): Option[ColumnarFileFormat] = { val insertCmd = meta.wrapped // Bucketing write - if (insertCmd.table.bucketSpec.isDefined) { - meta.willNotWorkOnGpu("bucketed tables are not supported yet") - } + GpuBucketingUtils.tagForHiveBucketingWrite(meta, insertCmd.table.bucketSpec, + insertCmd.outputColumns, false) // Infer the file format from the serde string, similar as what Spark does in // RelationConversions for Hive. diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala index 4ceac365314..939a421e0b9 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,30 +17,30 @@ package org.apache.spark.sql.rapids import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.ListBuffer +import scala.util.hashing.{MurmurHash3 => ScalaMurmur3Hash} -import ai.rapids.cudf.{ColumnVector, OrderByArg, Table} +import ai.rapids.cudf.{OrderByArg, Table} import com.nvidia.spark.TimingUtils import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit -import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.shims.GpuFileFormatDataWriterShim import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Cast, Concat, Expression, Literal, NullsFirst, ScalaUDF, SortOrder, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Cast, Concat, Expression, Literal, Murmur3Hash, NullsFirst, ScalaUDF, UnsafeProjection} import org.apache.spark.sql.connector.write.DataWriter import org.apache.spark.sql.execution.datasources.{BucketingUtils, PartitioningUtils, WriteTaskResult} -import org.apache.spark.sql.rapids.GpuFileFormatDataWriter.{shouldSplitToFitMaxRecordsPerFile, splitToFitMaxRecordsAndClose} +import org.apache.spark.sql.rapids.GpuFileFormatDataWriter._ import org.apache.spark.sql.rapids.GpuFileFormatWriter.GpuConcurrentOutputWriterSpec -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration @@ -50,7 +50,7 @@ object GpuFileFormatDataWriter { } def shouldSplitToFitMaxRecordsPerFile( - maxRecordsPerFile: Long, recordsInFile: Long, numRowsInBatch: Long) = { + maxRecordsPerFile: Long, recordsInFile: Long, numRowsInBatch: Long): Boolean = { maxRecordsPerFile > 0 && (recordsInFile + numRowsInBatch) > maxRecordsPerFile } @@ -88,13 +88,8 @@ object GpuFileFormatDataWriter { maxRecordsPerFile: Long, recordsInFile: Long): Array[SpillableColumnarBatch] = { val (types, splitIndexes) = closeOnExcept(batch) { _ => - val types = GpuColumnVector.extractTypes(batch) - val splitIndexes = - getSplitIndexes( - maxRecordsPerFile, - recordsInFile, - batch.numRows()) - (types, splitIndexes) + val splitIndexes = getSplitIndexes(maxRecordsPerFile, recordsInFile, batch.numRows()) + (GpuColumnVector.extractTypes(batch), splitIndexes) } if (splitIndexes.isEmpty) { // this should never happen, as `splitToFitMaxRecordsAndClose` is called when @@ -124,6 +119,31 @@ abstract class GpuFileFormatDataWriter( description: GpuWriteJobDescription, taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol) extends DataWriter[ColumnarBatch] { + + protected class WriterAndStatus { + var writer: ColumnarOutputWriter = _ + + /** Number of records in current file. */ + var recordsInFile: Long = 0 + + /** + * File counter for writing current partition or bucket. For same partition or bucket, + * we may have more than one file, due to number of records limit per file. + */ + var fileCounter: Int = 0 + + final def release(): Unit = { + if (writer != null) { + try { + writer.close() + statsTrackers.foreach(_.closeFile(writer.path())) + } finally { + writer = null + } + } + } + } + /** * Max number of files a single task writes out due to file size. In most cases the number of * files written should be very small. This is just a safe guard to protect some really bad @@ -131,28 +151,26 @@ abstract class GpuFileFormatDataWriter( */ protected val MAX_FILE_COUNTER: Int = 1000 * 1000 protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]() - protected var currentWriter: ColumnarOutputWriter = _ + protected var currentWriterStatus: WriterAndStatus = new WriterAndStatus() /** Trackers for computing various statistics on the data as it's being written out. */ protected val statsTrackers: Seq[ColumnarWriteTaskStatsTracker] = description.statsTrackers.map(_.newTaskInstance()) - /** Release resources of `currentWriter`. */ - protected def releaseCurrentWriter(): Unit = { - if (currentWriter != null) { - try { - currentWriter.close() - statsTrackers.foreach(_.closeFile(currentWriter.path())) - } finally { - currentWriter = null - } - } + /** Release resources of a WriterStatus. */ + protected final def releaseOutWriter(status: WriterAndStatus): Unit = { + status.release() + } + + protected final def writeUpdateMetricsAndClose(scb: SpillableColumnarBatch, + writerStatus: WriterAndStatus): Unit = { + writerStatus.recordsInFile += writerStatus.writer.writeSpillableAndClose(scb, statsTrackers) } /** Release all resources. Public for testing */ def releaseResources(): Unit = { - // Call `releaseCurrentWriter()` by default, as this is the only resource to be released. - releaseCurrentWriter() + // Release current writer by default, as this is the only resource to be released. + releaseOutWriter(currentWriterStatus) } /** Write an iterator of column batch. */ @@ -211,8 +229,6 @@ class GpuSingleDirectoryDataWriter( taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol) extends GpuFileFormatDataWriter(description, taskAttemptContext, committer) { - private var fileCounter: Int = _ - private var recordsInFile: Long = _ // Initialize currentWriter and statsTrackers newOutputWriter() @@ -220,7 +236,8 @@ class GpuSingleDirectoryDataWriter( "msg=method newTaskTempFile in class FileCommitProtocol is deprecated" ) private def newOutputWriter(): Unit = { - recordsInFile = 0 + currentWriterStatus.recordsInFile = 0 + val fileCounter = currentWriterStatus.fileCounter releaseResources() val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) @@ -229,7 +246,7 @@ class GpuSingleDirectoryDataWriter( None, f"-c$fileCounter%03d" + ext) - currentWriter = description.outputWriterFactory.newInstance( + currentWriterStatus.writer = description.outputWriterFactory.newInstance( path = currentPath, dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) @@ -237,32 +254,30 @@ class GpuSingleDirectoryDataWriter( statsTrackers.foreach(_.newFile(currentPath)) } - private def writeUpdateMetricsAndClose(scb: SpillableColumnarBatch): Unit = { - recordsInFile += currentWriter.writeSpillableAndClose(scb, statsTrackers) - } - override def write(batch: ColumnarBatch): Unit = { val maxRecordsPerFile = description.maxRecordsPerFile + val recordsInFile = currentWriterStatus.recordsInFile if (!shouldSplitToFitMaxRecordsPerFile( maxRecordsPerFile, recordsInFile, batch.numRows())) { writeUpdateMetricsAndClose( - SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)) + SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY), + currentWriterStatus) } else { val partBatches = splitToFitMaxRecordsAndClose( batch, maxRecordsPerFile, recordsInFile) - var needNewWriter = recordsInFile >= maxRecordsPerFile + val needNewWriterForFirstPart = recordsInFile >= maxRecordsPerFile closeOnExcept(partBatches) { _ => partBatches.zipWithIndex.foreach { case (partBatch, partIx) => - if (needNewWriter) { - fileCounter += 1 + if (partIx > 0 || needNewWriterForFirstPart) { + currentWriterStatus.fileCounter += 1 + val fileCounter = currentWriterStatus.fileCounter assert(fileCounter <= MAX_FILE_COUNTER, s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") newOutputWriter() } // null out the entry so that we don't double close partBatches(partIx) = null - writeUpdateMetricsAndClose(partBatch) - needNewWriter = true + writeUpdateMetricsAndClose(partBatch, currentWriterStatus) } } } @@ -280,35 +295,44 @@ class GpuDynamicPartitionDataSingleWriter( taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol) extends GpuFileFormatDataWriter(description, taskAttemptContext, committer) { + /** Wrapper class to index a unique concurrent output writer. */ + protected class WriterIndex( + var partitionPath: Option[String], + var bucketId: Option[Int]) extends Product2[Option[String], Option[Int]] { - /** Wrapper class for status of a unique single output writer. */ - protected class WriterStatus( - // output writer - var outputWriter: ColumnarOutputWriter, + override def hashCode(): Int = ScalaMurmur3Hash.productHash(this) - /** Number of records in current file. */ - var recordsInFile: Long = 0, + override def equals(obj: Any): Boolean = { + if (obj.isInstanceOf[WriterIndex]) { + val otherWI = obj.asInstanceOf[WriterIndex] + partitionPath == otherWI.partitionPath && bucketId == otherWI.bucketId + } else { + false + } + } - /** - * File counter for writing current partition or bucket. For same partition or bucket, - * we may have more than one file, due to number of records limit per file. - */ - var fileCounter: Int = 0 - ) + override def _1: Option[String] = partitionPath + override def _2: Option[Int] = bucketId + override def canEqual(that: Any): Boolean = that.isInstanceOf[WriterIndex] + } - /** Wrapper class for status and caches of a unique concurrent output writer. - * Used by `GpuDynamicPartitionDataConcurrentWriter` + /** + * A case class to hold the batch, the optional partition path and the optional bucket + * ID for a split group. All the rows in the batch belong to the group defined by the + * partition path and the bucket ID. */ - class WriterStatusWithCaches( - // writer status - var writerStatus: WriterStatus, - - // caches for this partition or writer - val tableCaches: ListBuffer[SpillableColumnarBatch] = ListBuffer(), - - // current device bytes for the above caches - var deviceBytes: Long = 0 - ) + private case class SplitPack(split: SpillableColumnarBatch, path: Option[String], + bucketId: Option[Int]) extends AutoCloseable { + override def close(): Unit = { + split.safeClose() + } + } + /** + * The index for current writer. Intentionally make the index mutable and reusable. + * Avoid JVM GC issue when many short-living `WriterIndex` objects are created + * if switching between concurrent writers frequently. + */ + private val currentWriterId: WriterIndex = new WriterIndex(None, None) /** Flag saying whether or not the data to be written out is partitioned. */ protected val isPartitioned: Boolean = description.partitionColumns.nonEmpty @@ -316,25 +340,17 @@ class GpuDynamicPartitionDataSingleWriter( /** Flag saying whether or not the data to be written out is bucketed. */ protected val isBucketed: Boolean = description.bucketSpec.isDefined - private var currentPartPath: String = "" - - private var currentWriterStatus: WriterStatus = _ - - // All data is sorted ascending with default null ordering - private val nullsSmallest = Ascending.defaultNullOrdering == NullsFirst - - if (isBucketed) { - throw new UnsupportedOperationException("Bucketing is not supported on the GPU yet.") - } - assert(isPartitioned || isBucketed, s"""GpuDynamicPartitionWriteTask should be used for writing out data that's either |partitioned or bucketed. In this case neither is true. |GpuWriteJobDescription: $description """.stripMargin) + // All data is sorted ascending with default null ordering + private val nullsSmallest = Ascending.defaultNullOrdering == NullsFirst + /** Extracts the partition values out of an input batch. */ - protected lazy val getPartitionColumnsAsBatch: ColumnarBatch => ColumnarBatch = { + private lazy val getPartitionColumnsAsBatch: ColumnarBatch => ColumnarBatch = { val expressions = GpuBindReferences.bindGpuReferences( description.partitionColumns, description.allColumns) @@ -343,20 +359,9 @@ class GpuDynamicPartitionDataSingleWriter( } } - /** Extracts the output values of an input batch. */ - private lazy val getOutputColumnsAsBatch: ColumnarBatch => ColumnarBatch= { + private lazy val getBucketIdColumnAsBatch: ColumnarBatch => ColumnarBatch = { val expressions = GpuBindReferences.bindGpuReferences( - description.dataColumns, - description.allColumns) - cb => { - GpuProjectExec.project(cb, expressions) - } - } - - /** Extracts the output values of an input batch. */ - protected lazy val getOutputCb: ColumnarBatch => ColumnarBatch = { - val expressions = GpuBindReferences.bindGpuReferences( - description.dataColumns, + Seq(description.bucketSpec.get.bucketIdExpression), description.allColumns) cb => { GpuProjectExec.project(cb, expressions) @@ -379,62 +384,58 @@ class GpuDynamicPartitionDataSingleWriter( /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns * the partition string. */ - protected lazy val getPartitionPath: InternalRow => String = { + private lazy val getPartitionPath: InternalRow => String = { val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns) row => proj(row).getString(0) } - /** Release resources of writer. */ - private def releaseWriter(writer: ColumnarOutputWriter): Unit = { - if (writer != null) { - val path = writer.path() - writer.close() - statsTrackers.foreach(_.closeFile(path)) + /** Extracts the output values of an input batch. */ + protected lazy val getDataColumnsAsBatch: ColumnarBatch => ColumnarBatch = { + val expressions = GpuBindReferences.bindGpuReferences( + description.dataColumns, + description.allColumns) + cb => { + GpuProjectExec.project(cb, expressions) } } - /** - * Opens a new OutputWriter given a partition key and/or a bucket id. - * If bucket id is specified, we will append it to the end of the file name, but before the - * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet - * - * @param partDir the partition directory - * @param bucketId the bucket which all tuples being written by this OutputWriter belong to, - * currently does not support `bucketId`, it's always None - * @param fileCounter integer indicating the number of files to be written to `partDir` - */ - @scala.annotation.nowarn( - "msg=method newTaskTempFile.* in class FileCommitProtocol is deprecated" - ) - def newWriter( - partDir: String, - bucketId: Option[Int], // Currently it's always None - fileCounter: Int - ): ColumnarOutputWriter = { - updatedPartitions.add(partDir) - // Currently will be empty - val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - - // This must be in a form that matches our bucketing format. See BucketingUtils. - val ext = f"$bucketIdStr.c$fileCounter%03d" + - description.outputWriterFactory.getFileExtension(taskAttemptContext) - - val customPath = description.customPartitionLocations - .get(PartitioningUtils.parsePathFragment(partDir)) + protected def getKeysBatch(cb: ColumnarBatch): ColumnarBatch = { + val keysBatch = withResource(getPartitionColumnsAsBatch(cb)) { partCb => + if (isBucketed) { + withResource(getBucketIdColumnAsBatch(cb)) { bucketIdCb => + GpuColumnVector.combineColumns(partCb, bucketIdCb) + } + } else { + GpuColumnVector.incRefCounts(partCb) + } + } + require(keysBatch.numCols() > 0, "No sort key is specified") + keysBatch + } - val currentPath = if (customPath.isDefined) { - committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + protected def genGetBucketIdFunc(keyHostCb: ColumnarBatch): Int => Option[Int] = { + if (isBucketed) { + // The last column is the bucket id column + val bucketIdCol = keyHostCb.column(keyHostCb.numCols() - 1) + i => Some(bucketIdCol.getInt(i)) } else { - committer.newTaskTempFile(taskAttemptContext, Option(partDir), ext) + _ => None } + } - val newWriter = description.outputWriterFactory.newInstance( - path = currentPath, - dataSchema = description.dataColumns.toStructType, - context = taskAttemptContext) - - statsTrackers.foreach(_.newFile(currentPath)) - newWriter + protected def genGetPartitionPathFunc(keyHostCb: ColumnarBatch): Int => Option[String] = { + if (isPartitioned) { + // Use the existing code to convert each row into a path. It would be nice to do this + // on the GPU, but the data should be small and there are things we cannot easily + // support on the GPU right now + import scala.collection.JavaConverters._ + val partCols = description.partitionColumns.indices.map(keyHostCb.column) + val iter = new ColumnarBatch(partCols.toArray, keyHostCb.numRows()).rowIterator() + .asScala.map(getPartitionPath) + _ => Some(iter.next) + } else { + _ => None + } } // distinct value sorted the same way the input data is sorted. @@ -461,282 +462,195 @@ class GpuDynamicPartitionDataSingleWriter( } } - override def write(batch: ColumnarBatch): Unit = { - // this single writer always passes `cachesMap` as None - write(batch, cachesMap = None) - } - - private case class SplitAndPath(var split: SpillableColumnarBatch, path: String) - extends AutoCloseable { - override def close(): Unit = { - split.safeClose() - split = null - } - } - /** - * Split a batch according to the sorted keys (partitions). Returns a tuple with an - * array of the splits as `ContiguousTable`'s, and an array of paths to use to - * write each partition. + * Split a batch according to the sorted keys (partitions + bucket ids). + * Returns a tuple with an array of the splits as `ContiguousTable`'s, an array of + * paths and bucket ids to use to write each partition and(or) bucket file. */ - private def splitBatchByKeyAndClose( - batch: ColumnarBatch, - partDataTypes: Array[DataType]): Array[SplitAndPath] = { - val (outputColumnsBatch, partitionColumnsBatch) = withResource(batch) { _ => - closeOnExcept(getOutputColumnsAsBatch(batch)) { outputColumnsBatch => - closeOnExcept(getPartitionColumnsAsBatch(batch)) { partitionColumnsBatch => - (outputColumnsBatch, partitionColumnsBatch) - } + private def splitBatchByKeyAndClose(batch: ColumnarBatch): Array[SplitPack] = { + val (keysCb, dataCb) = withResource(batch) { _ => + closeOnExcept(getDataColumnsAsBatch(batch)) { data => + (getKeysBatch(batch), data) } } - val (cbKeys, partitionIndexes) = closeOnExcept(outputColumnsBatch) { _ => - val partitionColumnsTbl = withResource(partitionColumnsBatch) { _ => - GpuColumnVector.from(partitionColumnsBatch) - } - withResource(partitionColumnsTbl) { _ => - withResource(distinctAndSort(partitionColumnsTbl)) { distinctKeysTbl => - val partitionIndexes = splitIndexes(partitionColumnsTbl, distinctKeysTbl) - val cbKeys = copyToHostAsBatch(distinctKeysTbl, partDataTypes) - (cbKeys, partitionIndexes) + val (keyHostCb, splitIds) = closeOnExcept(dataCb) { _ => + val (splitIds, distinctKeysTbl, keysCbTypes) = withResource(keysCb) { _ => + val keysCbTypes = GpuColumnVector.extractTypes(keysCb) + withResource(GpuColumnVector.from(keysCb)) { keysTable => + closeOnExcept(distinctAndSort(keysTable)) { distinctKeysTbl => + (splitIndexes(keysTable, distinctKeysTbl), distinctKeysTbl, keysCbTypes) + } } } + withResource(distinctKeysTbl) { _ => + (copyToHostAsBatch(distinctKeysTbl, keysCbTypes), splitIds) + } } - - val splits = closeOnExcept(cbKeys) { _ => - val spillableOutputColumnsBatch = - SpillableColumnarBatch(outputColumnsBatch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY) - withRetryNoSplit(spillableOutputColumnsBatch) { spillable => - withResource(spillable.getColumnarBatch()) { outCb => + val splits = closeOnExcept(keyHostCb) { _ => + val scbOutput = closeOnExcept(dataCb)( _ => + SpillableColumnarBatch(dataCb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)) + withRetryNoSplit(scbOutput) { scb => + withResource(scb.getColumnarBatch()) { outCb => withResource(GpuColumnVector.from(outCb)) { outputColumnsTbl => withResource(outputColumnsTbl) { _ => - outputColumnsTbl.contiguousSplit(partitionIndexes: _*) + outputColumnsTbl.contiguousSplit(splitIds: _*) } } } } } - - val paths = closeOnExcept(splits) { _ => - withResource(cbKeys) { _ => - // Use the existing code to convert each row into a path. It would be nice to do this - // on the GPU, but the data should be small and there are things we cannot easily - // support on the GPU right now - import scala.collection.JavaConverters._ - // paths - cbKeys.rowIterator().asScala.map(getPartitionPath).toArray - } - } + // Build the split result withResource(splits) { _ => - // NOTE: the `zip` here has the effect that will remove an extra `ContiguousTable` - // added at the end of `splits` because we use `upperBound` to find the split points, - // and the last split point is the number of rows. - val outDataTypes = description.dataColumns.map(_.dataType).toArray - splits.zip(paths).zipWithIndex.map { case ((split, path), ix) => - splits(ix) = null - withResource(split) { _ => - SplitAndPath( - SpillableColumnarBatch( - split, outDataTypes, SpillPriorities.ACTIVE_BATCHING_PRIORITY), - path) - } + withResource(keyHostCb) { _ => + val getBucketId = genGetBucketIdFunc(keyHostCb) + val getNextPartPath = genGetPartitionPathFunc(keyHostCb) + val outDataTypes = description.dataColumns.map(_.dataType).toArray + (0 until keyHostCb.numRows()).safeMap { idx => + val split = splits(idx) + splits(idx) = null + closeOnExcept(split) { _ => + SplitPack( + SpillableColumnarBatch(split, outDataTypes, + SpillPriorities.ACTIVE_BATCHING_PRIORITY), + getNextPartPath(idx), getBucketId(idx)) + } + }.toArray } } } - private def getBatchToWrite( - partBatch: SpillableColumnarBatch, - savedStatus: Option[WriterStatusWithCaches]): SpillableColumnarBatch = { - val outDataTypes = description.dataColumns.map(_.dataType).toArray - if (savedStatus.isDefined && savedStatus.get.tableCaches.nonEmpty) { - // In the case where the concurrent partition writers fall back, we need to - // incorporate into the current part any pieces that are already cached - // in the `savedStatus`. Adding `partBatch` to what was saved could make a - // concatenated batch with number of rows larger than `maxRecordsPerFile`, - // so this concatenated result could be split later, which is not efficient. However, - // the concurrent writers are default off in Spark, so it is not clear if this - // code path is worth optimizing. - val concat: Table = - withResource(savedStatus.get.tableCaches) { subSpillableBatches => - val toConcat = subSpillableBatches :+ partBatch - - // clear the caches - savedStatus.get.tableCaches.clear() - - withRetryNoSplit(toConcat.toSeq) { spillables => - withResource(spillables.safeMap(_.getColumnarBatch())) { batches => - withResource(batches.map(GpuColumnVector.from)) { subTables => - Table.concatenate(subTables: _*) - } - } - } - } - withResource(concat) { _ => - SpillableColumnarBatch( - GpuColumnVector.from(concat, outDataTypes), - SpillPriorities.ACTIVE_ON_DECK_PRIORITY) - } - } else { - partBatch + /** + * Create a new writer according to the given writer id, and update the given + * writer status. It also closes the old writer in the writer status by default. + */ + protected final def renewOutWriter(newWriterId: WriterIndex, curWriterStatus: WriterAndStatus, + closeOldWriter: Boolean = true): Unit = { + if (closeOldWriter) { + releaseOutWriter(curWriterStatus) } + curWriterStatus.recordsInFile = 0 + curWriterStatus.writer = newWriter(newWriterId.partitionPath, newWriterId.bucketId, + curWriterStatus.fileCounter) + } + + /** + * Set up a writer to the given writer status for the given writer id. + * It will create a new one if needed. This is used when seeing a new partition + * and(or) a new bucket id. + */ + protected def setupCurrentWriter(newWriterId: WriterIndex, curWriterStatus: WriterAndStatus, + closeOldWriter: Boolean = true): Unit = { + renewOutWriter(newWriterId, curWriterStatus, closeOldWriter) } /** - * Write columnar batch. - * If the `cachesMap` is not empty, this single writer should restore the writers and caches in - * the `cachesMap`, this single writer should first combine the caches and current split data - * for a specific partition before write. + * Opens a new OutputWriter given a partition key and/or a bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet * - * @param cb the column batch - * @param cachesMap used by `GpuDynamicPartitionDataConcurrentWriter` when fall back to single - * writer, single writer should handle the stored writers and the pending caches + * @param partDir the partition directory + * @param bucketId the bucket which all tuples being written by this OutputWriter belong to, + * currently does not support `bucketId`, it's always None + * @param fileCounter integer indicating the number of files to be written to `partDir` */ - protected def write( - batch: ColumnarBatch, - cachesMap: Option[mutable.HashMap[String, WriterStatusWithCaches]]): Unit = { - assert(isPartitioned) - assert(!isBucketed) + @scala.annotation.nowarn( + "msg=method newTaskTempFile.* in class FileCommitProtocol is deprecated" + ) + def newWriter(partDir: Option[String], bucketId: Option[Int], + fileCounter: Int): ColumnarOutputWriter = { + partDir.foreach(updatedPartitions.add) + // Currently will be empty + val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - val maxRecordsPerFile = description.maxRecordsPerFile - val partDataTypes = description.partitionColumns.map(_.dataType).toArray - - // We have an entire batch that is sorted, so we need to split it up by key - // to get a batch per path - withResource(splitBatchByKeyAndClose(batch, partDataTypes)) { splitsAndPaths => - splitsAndPaths.zipWithIndex.foreach { case (SplitAndPath(partBatch, partPath), ix) => - // If we fall back from `GpuDynamicPartitionDataConcurrentWriter`, we should get the - // saved status - val savedStatus = updateCurrentWriterIfNeeded(partPath, cachesMap) - - // combine `partBatch` with any remnants for this partition for the concurrent - // writer fallback case in `savedStatus` - splitsAndPaths(ix) = null - val batchToWrite = getBatchToWrite(partBatch, savedStatus) - - // if the batch fits, write it as is, else split and write it. - if (!shouldSplitToFitMaxRecordsPerFile(maxRecordsPerFile, - currentWriterStatus.recordsInFile, batchToWrite.numRows())) { - writeUpdateMetricsAndClose(currentWriterStatus, batchToWrite) - } else { - // materialize an actual batch since we are going to split it - // on the GPU - val batchToSplit = withRetryNoSplit(batchToWrite) { _ => - batchToWrite.getColumnarBatch() - } - val maxRecordsPerFileSplits = splitToFitMaxRecordsAndClose( - batchToSplit, - maxRecordsPerFile, - currentWriterStatus.recordsInFile) - writeSplitBatchesAndClose(maxRecordsPerFileSplits, maxRecordsPerFile, partPath) - } - } + // This must be in a form that matches our bucketing format. See BucketingUtils. + val ext = f"$bucketIdStr.c$fileCounter%03d" + + description.outputWriterFactory.getFileExtension(taskAttemptContext) + + val customPath = partDir.flatMap { dir => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) } + + val currentPath = if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, ext) + } + + val outWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + statsTrackers.foreach(_.newFile(currentPath)) + outWriter } - private def updateCurrentWriterIfNeeded( - partPath: String, - cachesMap: Option[mutable.HashMap[String, WriterStatusWithCaches]]): - Option[WriterStatusWithCaches] = { - var savedStatus: Option[WriterStatusWithCaches] = None - if (currentPartPath != partPath) { - val previousPartPath = currentPartPath - currentPartPath = partPath - - // see a new partition, close the old writer - val previousWriterStatus = currentWriterStatus - if (previousWriterStatus != null) { - releaseWriter(previousWriterStatus.outputWriter) - } + protected final def writeBatchPerMaxRecordsAndClose(scb: SpillableColumnarBatch, + writerId: WriterIndex, writerStatus: WriterAndStatus): Unit = { + val maxRecordsPerFile = description.maxRecordsPerFile + val recordsInFile = writerStatus.recordsInFile - if (cachesMap.isDefined) { - savedStatus = cachesMap.get.get(currentPartPath) - if (savedStatus.isDefined) { - // first try to restore the saved writer status, - // `GpuDynamicPartitionDataConcurrentWriter` may already opened the writer, and may - // have pending caches - currentWriterStatus = savedStatus.get.writerStatus - // entire batch that is sorted, see a new partition, the old write status is useless - cachesMap.get.remove(previousPartPath) - } else { - // create a new one - val writer = newWriter(partPath, None, 0) - currentWriterStatus = new WriterStatus(writer) - statsTrackers.foreach(_.newPartition()) + if (!shouldSplitToFitMaxRecordsPerFile(maxRecordsPerFile, recordsInFile, scb.numRows())) { + writeUpdateMetricsAndClose(scb, writerStatus) + } else { + val batch = withRetryNoSplit(scb) { scb => + scb.getColumnarBatch() + } + val splits = splitToFitMaxRecordsAndClose(batch, maxRecordsPerFile, recordsInFile) + withResource(splits) { _ => + val needNewWriterForFirstPart = recordsInFile >= maxRecordsPerFile + splits.zipWithIndex.foreach { case (part, partIx) => + if (partIx > 0 || needNewWriterForFirstPart) { + writerStatus.fileCounter += 1 + assert(writerStatus.fileCounter <= MAX_FILE_COUNTER, + s"File counter ${writerStatus.fileCounter} is beyond max value $MAX_FILE_COUNTER") + // will create a new file, so close the old writer + renewOutWriter(writerId, writerStatus) + } + splits(partIx) = null + writeUpdateMetricsAndClose(part, writerStatus) } - } else { - // create a new one - val writer = newWriter(partPath, None, 0) - currentWriterStatus = new WriterStatus(writer) - statsTrackers.foreach(_.newPartition()) } } - savedStatus } /** - * Write an array of spillable batches. + * Called just before updating the current writer status when seeing a new partition + * or a bucket. * - * Note: `spillableBatches` will be closed in this function. - * - * @param batches the SpillableColumnarBatch splits to be written - * @param maxRecordsPerFile the max number of rows per file - * @param partPath the partition directory + * @param curWriterId the current writer index */ - private def writeSplitBatchesAndClose( - spillableBatches: Array[SpillableColumnarBatch], - maxRecordsPerFile: Long, - partPath: String): Unit = { - var needNewWriter = currentWriterStatus.recordsInFile >= maxRecordsPerFile - withResource(spillableBatches) { _ => - spillableBatches.zipWithIndex.foreach { case (part, partIx) => - if (needNewWriter) { - currentWriterStatus.fileCounter += 1 - assert(currentWriterStatus.fileCounter <= MAX_FILE_COUNTER, - s"File counter ${currentWriterStatus.fileCounter} " + - s"is beyond max value $MAX_FILE_COUNTER") - - // will create a new file, close the old writer - if (currentWriterStatus != null) { - releaseWriter(currentWriterStatus.outputWriter) - } + protected def preUpdateCurrentWriterStatus(curWriterId: WriterIndex): Unit ={} - // create a new writer and update the writer in the status - currentWriterStatus.outputWriter = - newWriter(partPath, None, currentWriterStatus.fileCounter) - currentWriterStatus.recordsInFile = 0 + override def write(batch: ColumnarBatch): Unit = { + // The input batch that is entirely sorted, so split it up by partitions and (or) + // bucket ids, and write the split batches one by one. + withResource(splitBatchByKeyAndClose(batch)) { splitPacks => + splitPacks.zipWithIndex.foreach { case (SplitPack(sp, partPath, bucketId), i) => + val hasDiffPart = partPath != currentWriterId.partitionPath + val hasDiffBucket = bucketId != currentWriterId.bucketId + if (hasDiffPart || hasDiffBucket) { + preUpdateCurrentWriterStatus(currentWriterId) + if (hasDiffPart) { + currentWriterId.partitionPath = partPath + statsTrackers.foreach(_.newPartition()) + } + if (hasDiffBucket) { + currentWriterId.bucketId = bucketId + } + currentWriterStatus.fileCounter = 0 + setupCurrentWriter(currentWriterId, currentWriterStatus) } - spillableBatches(partIx) = null - writeUpdateMetricsAndClose(currentWriterStatus, part) - needNewWriter = true - } - } - } - - protected def writeUpdateMetricsAndClose( - writerStatus: WriterStatus, - spillableBatch: SpillableColumnarBatch): Unit = { - writerStatus.recordsInFile += - writerStatus.outputWriter.writeSpillableAndClose(spillableBatch, statsTrackers) - } - - /** Release all resources. */ - override def releaseResources(): Unit = { - // does not use `currentWriter`, single writer use `currentWriterStatus` - assert(currentWriter == null) - - if (currentWriterStatus != null) { - try { - currentWriterStatus.outputWriter.close() - statsTrackers.foreach(_.closeFile(currentWriterStatus.outputWriter.path())) - } finally { - currentWriterStatus = null + splitPacks(i) = null + writeBatchPerMaxRecordsAndClose(sp, currentWriterId, currentWriterStatus) } } } } /** - * Dynamic partition writer with concurrent writers, meaning multiple concurrent writers are opened - * for writing. + * Dynamic partition writer with concurrent writers, meaning multiple concurrent + * writers are opened for writing. * * The process has the following steps: * - Step 1: Maintain a map of output writers per each partition columns. Keep all @@ -754,18 +668,29 @@ class GpuDynamicPartitionDataConcurrentWriter( description: GpuWriteJobDescription, taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol, - spec: GpuConcurrentOutputWriterSpec, - taskContext: TaskContext) - extends GpuDynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) { + spec: GpuConcurrentOutputWriterSpec) + extends GpuDynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) + with Logging { - // Keep all the unclosed writers, key is partition directory string. - // Note: if fall back to sort-based mode, also use the opened writers in the map. - private val concurrentWriters = mutable.HashMap[String, WriterStatusWithCaches]() + /** Wrapper class for status and caches of a unique concurrent output writer. */ + private class WriterStatusWithBatches extends WriterAndStatus with AutoCloseable { + // caches for this partition or writer + val tableCaches: ListBuffer[SpillableColumnarBatch] = ListBuffer() - // guarantee to close the caches and writers when task is finished - onTaskCompletion(taskContext)(closeCachesAndWriters()) + // current device bytes for the above caches + var deviceBytes: Long = 0 - private val outDataTypes = description.dataColumns.map(_.dataType).toArray + override def close(): Unit = try { + releaseOutWriter(this) + } finally { + tableCaches.safeClose() + tableCaches.clear() + } + } + + // Keep all the unclosed writers, key is a partition path and(or) bucket id. + // Note: if fall back to sort-based mode, also use the opened writers in the map. + private val concurrentWriters = mutable.HashMap[WriterIndex, WriterStatusWithBatches]() private val partitionFlushSize = if (description.concurrentWriterPartitionFlushSize <= 0) { @@ -777,324 +702,196 @@ class GpuDynamicPartitionDataConcurrentWriter( description.concurrentWriterPartitionFlushSize } - // refer to current batch if should fall back to `single writer` - private var currentFallbackColumnarBatch: ColumnarBatch = _ + // Pending split batches that are not cached for the concurrent write because + // there are too many open writers, and it is going to fall back to the sorted + // sequential write. + private val pendingBatches: mutable.Queue[SpillableColumnarBatch] = mutable.Queue.empty - override def abort(): Unit = { - try { - closeCachesAndWriters() - } finally { - committer.abortTask(taskAttemptContext) + override def writeWithIterator(iterator: Iterator[ColumnarBatch]): Unit = { + // 1: try concurrent writer + while (iterator.hasNext && pendingBatches.isEmpty) { + // concurrent write and update the `concurrentWriters` map. + this.write(iterator.next()) } - } - /** - * State to indicate if we are falling back to sort-based writer. - * Because we first try to use concurrent writers, its initial value is false. - */ - private var fallBackToSortBased: Boolean = false + // 2: fall back to single write if the input is not all consumed. + if (pendingBatches.nonEmpty || iterator.hasNext) { + // sort the all the pending batches and ones in `iterator` + val pendingCbsIter = new Iterator[ColumnarBatch] { + override def hasNext: Boolean = pendingBatches.nonEmpty - private def writeWithSingleWriter(cb: ColumnarBatch): Unit = { - // invoke `GpuDynamicPartitionDataSingleWriter`.write, - // single writer will take care of the unclosed writers and the pending caches - // in `concurrentWriters` - super.write(cb, Some(concurrentWriters)) + override def next(): ColumnarBatch = { + if (!hasNext) { + throw new NoSuchElementException() + } + withResource(pendingBatches.dequeue())(_.getColumnarBatch()) + } + } + val sortIter = GpuOutOfCoreSortIterator(pendingCbsIter ++ iterator, + new GpuSorter(spec.sortOrder, spec.output), GpuSortExec.targetSize(spec.batchSize), + NoopMetric, NoopMetric, NoopMetric, NoopMetric) + while (sortIter.hasNext) { + // write with sort-based sequential writer + super.write(sortIter.next()) + } + } } - private def writeWithConcurrentWriter(cb: ColumnarBatch): Unit = { - this.write(cb) + /** This is for the fallback case, used to clean the writers map. */ + override def preUpdateCurrentWriterStatus(curWriterId: WriterIndex): Unit = { + concurrentWriters.remove(curWriterId) } - /** - * Write an iterator of column batch. - * - * @param iterator the iterator of column batch - */ - override def writeWithIterator(iterator: Iterator[ColumnarBatch]): Unit = { - // 1: try concurrent writer - while (iterator.hasNext && !fallBackToSortBased) { - // concurrently write and update the `concurrentWriters` map - // the `` will be updated - writeWithConcurrentWriter(iterator.next()) + /** This is for the fallback case, try to find the writer from cache first. */ + override def setupCurrentWriter(newWriterId: WriterIndex, writerStatus: WriterAndStatus, + closeOldWriter: Boolean): Unit = { + if (closeOldWriter) { + releaseOutWriter(writerStatus) } - - // 2: fall back to single writer - // Note single writer should restore writer status and handle the pending caches - if (fallBackToSortBased) { - // concat the put back batch and un-coming batches - val newIterator = Iterator.single(currentFallbackColumnarBatch) ++ iterator - // sort the all the batches in `iterator` - - val sortIterator: GpuOutOfCoreSortIterator = getSorted(newIterator) - while (sortIterator.hasNext) { - // write with sort-based single writer - writeWithSingleWriter(sortIterator.next()) - } + val oOpenStatus = concurrentWriters.get(newWriterId) + if (oOpenStatus.isDefined) { + val openStatus = oOpenStatus.get + writerStatus.writer = openStatus.writer + writerStatus.recordsInFile = openStatus.recordsInFile + writerStatus.fileCounter = openStatus.fileCounter + } else { + super.setupCurrentWriter(newWriterId, writerStatus, closeOldWriter = false) } } /** - * Sort the input iterator by out of core sort - * - * @param iterator the input iterator - * @return sorted iterator - */ - private def getSorted(iterator: Iterator[ColumnarBatch]): GpuOutOfCoreSortIterator = { - val gpuSortOrder: Seq[SortOrder] = spec.sortOrder - val output: Seq[Attribute] = spec.output - val sorter = new GpuSorter(gpuSortOrder, output) - - // use noop metrics below - val sortTime = NoopMetric - val opTime = NoopMetric - val outputBatch = NoopMetric - val outputRows = NoopMetric - - val targetSize = GpuSortExec.targetSize(spec.batchSize) - // out of core sort the entire iterator - GpuOutOfCoreSortIterator(iterator, sorter, targetSize, - opTime, sortTime, outputBatch, outputRows) - } - - /** - * concurrent write the columnar batch - * Note: if new partitions number in `cb` plus existing partitions number is greater than - * `maxWriters` limit, will put back the whole `cb` to 'single writer` + * The write path of concurrent writers * - * @param cb the columnar batch + * @param cb the columnar batch to be written */ override def write(cb: ColumnarBatch): Unit = { - assert(isPartitioned) - assert(!isBucketed) - if (cb.numRows() == 0) { // TODO https://github.com/NVIDIA/spark-rapids/issues/6453 // To solve above issue, I assume that an empty batch will be wrote for saving metadata. // If the assumption it's true, this concurrent writer should write the metadata here, // and should not run into below splitting and caching logic + cb.close() return } - // 1. combine partition columns and `cb` columns into a column array - val columnsWithPartition = ArrayBuffer[ColumnVector]() - - // this withResource is here to decrement the refcount of the partition columns - // that are projected out of `cb` - withResource(getPartitionColumnsAsBatch(cb)) { partitionColumnsBatch => - columnsWithPartition.appendAll(GpuColumnVector.extractBases(partitionColumnsBatch)) - } - - val cols = GpuColumnVector.extractBases(cb) - columnsWithPartition ++= cols - - // 2. group by the partition columns - // get sub-groups for each partition and get unique keys for each partition - val groupsAndKeys = withResource( - new Table(columnsWithPartition.toSeq: _*)) { colsWithPartitionTbl => - // [0, partition columns number - 1] - val partitionIndices = description.partitionColumns.indices - - // group by partition columns - val op = colsWithPartitionTbl.groupBy(partitionIndices: _*) - // return groups and uniq keys table - // Each row in uniq keys table is corresponding to a group - op.contiguousSplitGroupsAndGenUniqKeys() - } - - withResource(groupsAndKeys) { _ => - // groups number should equal to uniq keys number - assert(groupsAndKeys.getGroups.length == groupsAndKeys.getUniqKeyTable.getRowCount) - - val (groups, keys) = (groupsAndKeys.getGroups, groupsAndKeys.getUniqKeyTable) - - // 3. generate partition strings for all sub-groups in advance - val partDataTypes = description.partitionColumns.map(_.dataType).toArray - val dataTypes = GpuColumnVector.extractTypes(cb) - // generate partition string list for all groups - val partitionStrList = getPartitionStrList(keys, partDataTypes) - // key table is useless now - groupsAndKeys.closeUniqKeyTable() - - // 4. cache each group according to each partitionStr - withResource(groups) { _ => - - // first update fallBackToSortBased - withResource(cb) { _ => - var newPartitionNum = 0 - var groupIndex = 0 - while (!fallBackToSortBased && groupIndex < groups.length) { - // get the partition string - val partitionStr = partitionStrList(groupIndex) - groupIndex += 1 - if (!concurrentWriters.contains(partitionStr)) { - newPartitionNum += 1 - if (newPartitionNum + concurrentWriters.size >= spec.maxWriters) { - fallBackToSortBased = true - currentFallbackColumnarBatch = cb - // `cb` should be put back to single writer - GpuColumnVector.incRefCounts(cb) - } - } - } - } - - if (!fallBackToSortBased) { - // not fall, collect all caches - var groupIndex = 0 - while (groupIndex < groups.length) { - // get the partition string and group pair - val (partitionStr, group) = (partitionStrList(groupIndex), groups(groupIndex)) - val groupTable = group.getTable - groupIndex += 1 - - // create writer if encounter a new partition and put into `concurrentWriters` map - if (!concurrentWriters.contains(partitionStr)) { - val w = newWriter(partitionStr, None, 0) - val ws = new WriterStatus(w) - concurrentWriters.put(partitionStr, new WriterStatusWithCaches(ws)) - statsTrackers.foreach(_.newPartition()) - } - - // get data columns, tail part is data columns - val dataColumns = ArrayBuffer[ColumnVector]() - for (i <- description.partitionColumns.length until groupTable.getNumberOfColumns) { - dataColumns += groupTable.getColumn(i) - } - withResource(new Table(dataColumns.toSeq: _*)) { dataTable => - withResource(GpuColumnVector.from(dataTable, dataTypes)) { cb => - val outputCb = getOutputCb(cb) - // convert to spillable cache and add to the pending cache - val currWriterStatus = concurrentWriters(partitionStr) - // create SpillableColumnarBatch to take the owner of `outputCb` - currWriterStatus.tableCaches += SpillableColumnarBatch( - outputCb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY) - currWriterStatus.deviceBytes += GpuColumnVector.getTotalDeviceMemoryUsed(outputCb) - } - } + // Split the batch and cache the result, along with opening the writers. + splitBatchToCacheAndClose(cb) + // Write the cached batches + val writeFunc: (WriterIndex, WriterStatusWithBatches) => Unit = + if (pendingBatches.nonEmpty) { + // Flush all the caches before going into sorted sequential write + writeOneCacheAndClose + } else { + // Still the concurrent write, so write out only partitions that size > threshold. + (wi, ws) => + if (ws.deviceBytes > partitionFlushSize) { + writeOneCacheAndClose(wi, ws) } - } } - } - - // 5. find all big enough partitions and write - if(!fallBackToSortBased) { - for ((partitionDir, ws) <- findBigPartitions(partitionFlushSize)) { - writeAndCloseCache(partitionDir, ws) - } - } - } - - private def getPartitionStrList( - uniqKeysTable: Table, partDataTypes: Array[DataType]): Array[String] = { - withResource(copyToHostAsBatch(uniqKeysTable, partDataTypes)) { oneRowCb => - import scala.collection.JavaConverters._ - oneRowCb.rowIterator().asScala.map(getPartitionPath).toArray + concurrentWriters.foreach { case (writerIdx, writerStatus) => + writeFunc(writerIdx, writerStatus) } } - private def writeAndCloseCache(partitionDir: String, status: WriterStatusWithCaches): Unit = { + private def writeOneCacheAndClose(writerId: WriterIndex, + status: WriterStatusWithBatches): Unit = { assert(status.tableCaches.nonEmpty) + // Concat tables if needed + val scbToWrite = GpuBatchUtils.concatSpillBatchesAndClose(status.tableCaches.toSeq).get + status.tableCaches.clear() + status.deviceBytes = 0 + writeBatchPerMaxRecordsAndClose(scbToWrite, writerId, status) + } - // get concat table or the single table - val spillableToWrite = if (status.tableCaches.length >= 2) { - // concat the sub batches to write in once. - val concatted = withRetryNoSplit(status.tableCaches.toSeq) { spillableSubBatches => - withResource(spillableSubBatches.safeMap(_.getColumnarBatch())) { subBatches => - withResource(subBatches.map(GpuColumnVector.from)) { subTables => - Table.concatenate(subTables: _*) - } - } + private def splitBatchToCacheAndClose(batch: ColumnarBatch): Unit = { + // Split batch to groups by sort columns, [partition and(or) bucket id column]. + val (keysAndGroups, keyTypes) = withResource(batch) { _ => + val (opBatch, keyTypes) = withResource(getKeysBatch(batch)) { keysBatch => + val combinedCb = GpuColumnVector.combineColumns(keysBatch, batch) + (combinedCb, GpuColumnVector.extractTypes(keysBatch)) } - withResource(concatted) { _ => - SpillableColumnarBatch( - GpuColumnVector.from(concatted, outDataTypes), - SpillPriorities.ACTIVE_ON_DECK_PRIORITY) + withResource(opBatch) { _ => + withResource(GpuColumnVector.from(opBatch)) { opTable => + (opTable.groupBy(keyTypes.indices: _*).contiguousSplitGroupsAndGenUniqKeys(), + keyTypes) + } } - } else { - // only one single table - status.tableCaches.head } - - status.tableCaches.clear() - - val maxRecordsPerFile = description.maxRecordsPerFile - if (!shouldSplitToFitMaxRecordsPerFile( - maxRecordsPerFile, status.writerStatus.recordsInFile, spillableToWrite.numRows())) { - writeUpdateMetricsAndClose(status.writerStatus, spillableToWrite) - } else { - val batchToSplit = withRetryNoSplit(spillableToWrite) { _ => - spillableToWrite.getColumnarBatch() - } - val splits = splitToFitMaxRecordsAndClose( - batchToSplit, - maxRecordsPerFile, - status.writerStatus.recordsInFile) - var needNewWriter = status.writerStatus.recordsInFile >= maxRecordsPerFile - withResource(splits) { _ => - splits.zipWithIndex.foreach { case (split, partIndex) => - if (needNewWriter) { - status.writerStatus.fileCounter += 1 - assert(status.writerStatus.fileCounter <= MAX_FILE_COUNTER, - s"File counter ${status.writerStatus.fileCounter} " + - s"is beyond max value $MAX_FILE_COUNTER") - status.writerStatus.outputWriter.close() - // start a new writer - val w = newWriter(partitionDir, None, status.writerStatus.fileCounter) - status.writerStatus.outputWriter = w - status.writerStatus.recordsInFile = 0L + // Copy keys table to host and make group batches spillable + val (keyHostCb, groups) = withResource(keysAndGroups) { _ => + // groups number should equal to uniq keys number + assert(keysAndGroups.getGroups.length == keysAndGroups.getUniqKeyTable.getRowCount) + closeOnExcept(copyToHostAsBatch(keysAndGroups.getUniqKeyTable, keyTypes)) { keyHostCb => + keysAndGroups.closeUniqKeyTable() + val allTypes = description.allColumns.map(_.dataType).toArray + val allColsIds = allTypes.indices.map(_ + keyTypes.length) + val gps = keysAndGroups.getGroups.safeMap { gp => + withResource(gp.getTable) { gpTable => + withResource(new Table(allColsIds.map(gpTable.getColumn): _*)) { allTable => + SpillableColumnarBatch(GpuColumnVector.from(allTable, allTypes), + SpillPriorities.ACTIVE_BATCHING_PRIORITY) + } } - splits(partIndex) = null - writeUpdateMetricsAndClose(status.writerStatus, split) - needNewWriter = true } + (keyHostCb, gps) } } - status.tableCaches.clear() - status.deviceBytes = 0 - } - - def closeCachesAndWriters(): Unit = { - // collect all caches and writers - val allResources = ArrayBuffer[AutoCloseable]() - allResources ++= concurrentWriters.values.flatMap(ws => ws.tableCaches) - allResources ++= concurrentWriters.values.map { ws => - new AutoCloseable() { - override def close(): Unit = { - ws.writerStatus.outputWriter.close() - statsTrackers.foreach(_.closeFile(ws.writerStatus.outputWriter.path())) + // Cache the result to either the map or the pending queue. + withResource(groups) { _ => + withResource(keyHostCb) { _ => + val getBucketId = genGetBucketIdFunc(keyHostCb) + val getNextPartPath = genGetPartitionPathFunc(keyHostCb) + var idx = 0 + while (idx < groups.length && concurrentWriters.size < spec.maxWriters) { + val writerId = new WriterIndex(getNextPartPath(idx), getBucketId(idx)) + val writerStatus = + concurrentWriters.getOrElseUpdate(writerId, new WriterStatusWithBatches) + if (writerStatus.writer == null) { + // a new partition or bucket, so create a writer + renewOutWriter(writerId, writerStatus, closeOldWriter = false) + } + withResource(groups(idx)) { gp => + groups(idx) = null + withResource(gp.getColumnarBatch()) { cb => + val dataScb = SpillableColumnarBatch(getDataColumnsAsBatch(cb), + SpillPriorities.ACTIVE_BATCHING_PRIORITY) + writerStatus.tableCaches.append(dataScb) + writerStatus.deviceBytes += dataScb.sizeInBytes + } + } + idx += 1 + } + if (idx < groups.length) { + // The open writers number reaches the limit, and still some partitions are + // not cached. Append to the queue for the coming fallback to the sorted + // sequential write. + groups.drop(idx).foreach(g => pendingBatches.enqueue(g)) + // Set to null to avoid double close + (idx until groups.length).foreach(groups(_) = null) + logInfo(s"Number of concurrent writers ${concurrentWriters.size} reaches " + + "the threshold. Fall back from concurrent writers to sort-based sequential" + + " writer.") } } } - - // safe close all the caches and writers - allResources.safeClose() - - // clear `concurrentWriters` map - concurrentWriters.values.foreach(ws => ws.tableCaches.clear()) - concurrentWriters.clear() } /** Release all resources. */ override def releaseResources(): Unit = { - // does not use `currentWriter`, only use the writers in the concurrent writer map - assert(currentWriter == null) - - if (fallBackToSortBased) { - // Note: we should close the last partition writer in the single writer. - super.releaseResources() - } + pendingBatches.safeClose() + pendingBatches.clear() // write all caches - concurrentWriters.filter(pair => pair._2.tableCaches.nonEmpty) - .foreach(pair => writeAndCloseCache(pair._1, pair._2)) + concurrentWriters.foreach { case (wi, ws) => + if (ws.tableCaches.nonEmpty) { + writeOneCacheAndClose(wi, ws) + } + } // close all resources - closeCachesAndWriters() - } - - private def findBigPartitions( - sizeThreshold: Long): mutable.Map[String, WriterStatusWithCaches] = { - concurrentWriters.filter(pair => pair._2.deviceBytes >= sizeThreshold) + concurrentWriters.values.toSeq.safeClose() + concurrentWriters.clear() + super.releaseResources() } } @@ -1105,7 +902,7 @@ class GpuDynamicPartitionDataConcurrentWriter( * @param bucketFileNamePrefix Prefix of output file name based on bucket id. */ case class GpuWriterBucketSpec( - bucketIdExpression: Expression, + bucketIdExpression: GpuExpression, bucketFileNamePrefix: Int => String) /** @@ -1134,4 +931,23 @@ class GpuWriteJobDescription( |Partition columns: ${partitionColumns.mkString(", ")} |Data columns: ${dataColumns.mkString(", ")} """.stripMargin) -} \ No newline at end of file +} + +object BucketIdMetaUtils { + // Tag for the bucketing write using Spark Murmur3Hash + def tagForBucketingWrite(meta: RapidsMeta[_, _, _], bucketSpec: Option[BucketSpec], + outputColumns: Seq[Attribute]): Unit = { + bucketSpec.foreach { bSpec => + // Create a Murmur3Hash expression to leverage the overriding types check. + val expr = Murmur3Hash( + bSpec.bucketColumnNames.map(n => outputColumns.find(_.name == n).get), + GpuHashPartitioningBase.DEFAULT_HASH_SEED) + val hashMeta = GpuOverrides.wrapExpr(expr, meta.conf, None) + hashMeta.tagForGpu() + if(!hashMeta.canThisBeReplaced) { + meta.willNotWorkOnGpu(s"Hashing for generating bucket IDs can not run" + + s" on GPU. Details: ${hashMeta.explain(all=false)}") + } + } + } +} diff --git a/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/CreateDataSourceTableAsSelectCommandMetaShims.scala b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/CreateDataSourceTableAsSelectCommandMetaShims.scala index de066a5486d..d1a26dc80fc 100644 --- a/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/CreateDataSourceTableAsSelectCommandMetaShims.scala +++ b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/CreateDataSourceTableAsSelectCommandMetaShims.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.rapids.{GpuDataSourceBase, GpuOrcFileFormat} +import org.apache.spark.sql.rapids.{BucketIdMetaUtils, GpuDataSourceBase, GpuOrcFileFormat} import org.apache.spark.sql.rapids.shims.GpuCreateDataSourceTableAsSelectCommand @@ -56,9 +56,7 @@ final class CreateDataSourceTableAsSelectCommandMeta( private var gpuProvider: Option[ColumnarFileFormat] = None override def tagSelfForGpuInternal(): Unit = { - if (cmd.table.bucketSpec.isDefined) { - willNotWorkOnGpu("bucketing is not supported") - } + BucketIdMetaUtils.tagForBucketingWrite(this, cmd.table.bucketSpec, cmd.outputColumns) if (cmd.table.provider.isEmpty) { willNotWorkOnGpu("provider must be defined") } @@ -94,4 +92,4 @@ final class CreateDataSourceTableAsSelectCommandMeta( conf.stableSort, conf.concurrentWriterPartitionFlushSize) } -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/GpuOptimizedCreateHiveTableAsSelectCommandShims.scala b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/GpuOptimizedCreateHiveTableAsSelectCommandShims.scala index 5e2601a0467..55d9bc53704 100644 --- a/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/GpuOptimizedCreateHiveTableAsSelectCommandShims.scala +++ b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/GpuOptimizedCreateHiveTableAsSelectCommandShims.scala @@ -184,9 +184,8 @@ final class OptimizedCreateHiveTableAsSelectCommandMeta( willNotWorkOnGpu("partitioned writes are not supported") } - if (tableDesc.bucketSpec.isDefined) { - willNotWorkOnGpu("bucketing is not supported") - } + GpuBucketingUtils.tagForHiveBucketingWrite(this, tableDesc.bucketSpec, + cmd.outputColumns, false) val serde = tableDesc.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) if (serde.contains("parquet")) { diff --git a/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/spark311/GpuBucketingUtils.scala b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/spark311/GpuBucketingUtils.scala new file mode 100644 index 00000000000..a604267d1d9 --- /dev/null +++ b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/spark311/GpuBucketingUtils.scala @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/*** spark-rapids-shim-json-lines +{"spark": "311"} +{"spark": "312"} +{"spark": "313"} +{"spark": "320"} +{"spark": "321"} +{"spark": "321cdh"} +{"spark": "322"} +{"spark": "323"} +{"spark": "324"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import com.nvidia.spark.rapids.RapidsMeta + +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.rapids.{BucketIdMetaUtils, GpuWriterBucketSpec} + +object GpuBucketingUtils { + + def getWriterBucketSpec( + bucketSpec: Option[BucketSpec], + dataColumns: Seq[Attribute], + options: Map[String, String], + forceHiveHash: Boolean): Option[GpuWriterBucketSpec] = { + bucketSpec.map { spec => + val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + if (forceHiveHash) { + // Forcely use HiveHash for Hive write commands for some customized Spark binaries. + // TODO: Cannot support this until we support Hive hash partitioning on the GPU + throw new UnsupportedOperationException("Hive hash partitioning is not supported" + + " on GPU") + } else { + // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id + // expression, so that we can guarantee the data distribution is same between shuffle and + // bucketed data source, which enables us to only shuffle one side when join a bucketed + // table and a normal one. + val bucketIdExpression = GpuHashPartitioning(bucketColumns, spec.numBuckets) + .partitionIdExpression + GpuWriterBucketSpec(bucketIdExpression, (_: Int) => "") + } + } + } + + def isHiveHashBucketing(options: Map[String, String]): Boolean = false + + def getOptionsWithHiveBucketWrite(bucketSpec: Option[BucketSpec]): Map[String, String] = { + Map.empty + } + + def tagForHiveBucketingWrite(meta: RapidsMeta[_, _, _], bucketSpec: Option[BucketSpec], + outColumns: Seq[Attribute], forceHiveHash: Boolean): Unit = { + if (forceHiveHash) { + bucketSpec.foreach(_ => + meta.willNotWorkOnGpu("Hive Hashing for generating bucket IDs is not supported yet") + ) + } else { + BucketIdMetaUtils.tagForBucketingWrite(meta, bucketSpec, outColumns) + } + } +} diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/hive/rapids/shims/GpuCreateHiveTableAsSelectCommand.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/hive/rapids/shims/GpuCreateHiveTableAsSelectCommand.scala index 034567d60e5..acdd53b74ab 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/hive/rapids/shims/GpuCreateHiveTableAsSelectCommand.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/hive/rapids/shims/GpuCreateHiveTableAsSelectCommand.scala @@ -36,7 +36,7 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.hive.rapids.shims import com.nvidia.spark.rapids.{DataFromReplacementRule, DataWritingCommandMeta, GpuDataWritingCommand, GpuOverrides, RapidsConf, RapidsMeta} -import com.nvidia.spark.rapids.shims.GpuCreateHiveTableAsSelectBase +import com.nvidia.spark.rapids.shims.{GpuBucketingUtils, GpuCreateHiveTableAsSelectBase} import org.apache.spark.sql.{SaveMode, SparkSession} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog} @@ -61,9 +61,8 @@ final class GpuCreateHiveTableAsSelectCommandMeta(cmd: CreateHiveTableAsSelectCo willNotWorkOnGpu("partitioned writes are not supported") } - if (tableDesc.bucketSpec.isDefined) { - willNotWorkOnGpu("bucketing is not supported") - } + GpuBucketingUtils.tagForHiveBucketingWrite(this, tableDesc.bucketSpec, + cmd.outputColumns, false) val catalog = spark.sessionState.catalog val tableExists = catalog.tableExists(tableDesc.identifier) @@ -137,4 +136,4 @@ case class GpuCreateHiveTableAsSelectCommand( // Do not support partitioned or bucketed writes override def requireSingleBatch: Boolean = false -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/hive/rapids/shims/GpuInsertIntoHiveTable.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/hive/rapids/shims/GpuInsertIntoHiveTable.scala index 2ea0301fa2c..3f59d6565a5 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/hive/rapids/shims/GpuInsertIntoHiveTable.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/hive/rapids/shims/GpuInsertIntoHiveTable.scala @@ -38,6 +38,7 @@ package org.apache.spark.sql.hive.rapids.shims import java.util.Locale import com.nvidia.spark.rapids.{ColumnarFileFormat, DataFromReplacementRule, DataWritingCommandMeta, GpuDataWritingCommand, RapidsConf, RapidsMeta} +import com.nvidia.spark.rapids.shims.GpuBucketingUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf @@ -216,7 +217,9 @@ case class GpuInsertIntoHiveTable( hadoopConf = hadoopConf, fileFormat = fileFormat, outputLocation = tmpLocation.toString, - partitionAttributes = partitionAttributes) + partitionAttributes = partitionAttributes, + bucketSpec = table.bucketSpec, + options = GpuBucketingUtils.getOptionsWithHiveBucketWrite(table.bucketSpec)) if (partition.nonEmpty) { if (numDynamicPartitions > 0) { diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala index f788971a85f..4adbd7b2ef5 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala @@ -39,7 +39,7 @@ import java.util.{Date, UUID} import com.nvidia.spark.TimingUtils import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.RapidsFileSourceMetaUtils +import com.nvidia.spark.rapids.shims.{GpuBucketingUtils, RapidsFileSourceMetaUtils} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ @@ -136,13 +136,8 @@ object GpuFileFormatWriter extends Logging { if (projectList.nonEmpty) GpuProjectExec(projectList, plan)() else plan } - val writerBucketSpec: Option[GpuWriterBucketSpec] = bucketSpec.map { spec => - // TODO: Cannot support this until we: - // support Hive hash partitioning on the GPU - throw new UnsupportedOperationException("GPU hash partitioning for bucketed data is not " - + "compatible with the CPU version") - } - + val writerBucketSpec = GpuBucketingUtils.getWriterBucketSpec(bucketSpec, dataColumns, + options, false) val sortColumns = bucketSpec.toSeq.flatMap { spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) } @@ -328,8 +323,8 @@ object GpuFileFormatWriter extends Logging { } else { concurrentOutputWriterSpec match { case Some(spec) => - new GpuDynamicPartitionDataConcurrentWriter( - description, taskAttemptContext, committer, spec, TaskContext.get()) + new GpuDynamicPartitionDataConcurrentWriter(description, taskAttemptContext, + committer, spec) case _ => new GpuDynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) } diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala new file mode 100644 index 00000000000..feb562fa9b8 --- /dev/null +++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +{"spark": "330cdh"} +{"spark": "330db"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332cdh"} +{"spark": "332db"} +{"spark": "333"} +{"spark": "334"} +{"spark": "340"} +{"spark": "341"} +{"spark": "341db"} +{"spark": "342"} +{"spark": "343"} +{"spark": "350"} +{"spark": "351"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import com.nvidia.spark.rapids.RapidsMeta + +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.datasources.BucketingUtils +import org.apache.spark.sql.rapids.GpuWriterBucketSpec + +object GpuBucketingUtils { + + def getWriterBucketSpec( + bucketSpec: Option[BucketSpec], + dataColumns: Seq[Attribute], + options: Map[String, String], + forceHiveHash: Boolean): Option[GpuWriterBucketSpec] = { + bucketSpec.map { spec => + val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + val shouldHiveCompatibleWrite = options.getOrElse( + BucketingUtils.optionForHiveCompatibleBucketWrite, "false").toBoolean + if (shouldHiveCompatibleWrite) { + // TODO: Cannot support this until we support Hive hash partitioning on the GPU + throw new UnsupportedOperationException("Hive hash partitioning is not supported" + + " on GPU") + } else { + // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id + // expression, so that we can guarantee the data distribution is same between shuffle and + // bucketed data source, which enables us to only shuffle one side when join a bucketed + // table and a normal one. + val bucketIdExpression = GpuHashPartitioning(bucketColumns, spec.numBuckets) + .partitionIdExpression + GpuWriterBucketSpec(bucketIdExpression, (_: Int) => "") + } + } + } + + def isHiveHashBucketing(options: Map[String, String]): Boolean = { + options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false").toBoolean + } + + def getOptionsWithHiveBucketWrite(bucketSpec: Option[BucketSpec]): Map[String, String] = { + bucketSpec + .map(_ => Map(BucketingUtils.optionForHiveCompatibleBucketWrite -> "true")) + .getOrElse(Map.empty) + } + + def tagForHiveBucketingWrite(meta: RapidsMeta[_, _, _], bucketSpec: Option[BucketSpec], + outColumns: Seq[Attribute], forceHiveHash: Boolean): Unit = { + bucketSpec.foreach(_ => + // From Spark330, Hive write always uses HiveHash to generate bucket IDs. + meta.willNotWorkOnGpu("Hive Hashing for generating bucket IDs is not supported yet") + ) + } +} diff --git a/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/CreateDataSourceTableAsSelectCommandMetaShims.scala b/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/CreateDataSourceTableAsSelectCommandMetaShims.scala index faa550c0cb6..f51bd984bdc 100644 --- a/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/CreateDataSourceTableAsSelectCommandMetaShims.scala +++ b/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/CreateDataSourceTableAsSelectCommandMetaShims.scala @@ -30,10 +30,10 @@ package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.rapids.{GpuDataSourceBase, GpuOrcFileFormat} +import org.apache.spark.sql.rapids.{BucketIdMetaUtils, GpuDataSourceBase, GpuOrcFileFormat} import org.apache.spark.sql.rapids.shims.GpuCreateDataSourceTableAsSelectCommand final class CreateDataSourceTableAsSelectCommandMeta( @@ -46,9 +46,9 @@ final class CreateDataSourceTableAsSelectCommandMeta( private var origProvider: Class[_] = _ override def tagSelfForGpu(): Unit = { - if (cmd.table.bucketSpec.isDefined) { - willNotWorkOnGpu("bucketing is not supported") - } + val outputColumns = + DataWritingCommand.logicalPlanOutputWithNames(cmd.query, cmd.outputColumnNames) + BucketIdMetaUtils.tagForBucketingWrite(this, cmd.table.bucketSpec, outputColumns) if (cmd.table.provider.isEmpty) { willNotWorkOnGpu("provider must be defined") } @@ -76,4 +76,4 @@ final class CreateDataSourceTableAsSelectCommandMeta( cmd.outputColumnNames, origProvider) } -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/GpuInsertIntoHiveTable.scala b/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/GpuInsertIntoHiveTable.scala index 42fd5941025..b3103c3c76e 100644 --- a/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/GpuInsertIntoHiveTable.scala +++ b/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/GpuInsertIntoHiveTable.scala @@ -30,6 +30,7 @@ package org.apache.spark.sql.hive.rapids.shims import java.util.Locale import com.nvidia.spark.rapids.{ColumnarFileFormat, DataFromReplacementRule, DataWritingCommandMeta, GpuDataWritingCommand, RapidsConf, RapidsMeta} +import com.nvidia.spark.rapids.shims.GpuBucketingUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf @@ -205,7 +206,9 @@ case class GpuInsertIntoHiveTable( hadoopConf = hadoopConf, fileFormat = fileFormat, outputLocation = tmpLocation.toString, - partitionAttributes = partitionAttributes) + partitionAttributes = partitionAttributes, + bucketSpec = table.bucketSpec, + options = GpuBucketingUtils.getOptionsWithHiveBucketWrite(table.bucketSpec)) if (partition.nonEmpty) { if (numDynamicPartitions > 0) { @@ -349,4 +352,4 @@ case class GpuInsertIntoHiveTable( } override def requireSingleBatch: Boolean = false // TODO: Re-evaluate. If partitioned or bucketed? -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/GpuOptimizedCreateHiveTableAsSelectCommandShims.scala b/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/GpuOptimizedCreateHiveTableAsSelectCommandShims.scala index 53c17d2f946..e74bf979af9 100644 --- a/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/GpuOptimizedCreateHiveTableAsSelectCommandShims.scala +++ b/sql-plugin/src/main/spark332db/scala/com/nvidia/spark/rapids/shims/GpuOptimizedCreateHiveTableAsSelectCommandShims.scala @@ -197,9 +197,9 @@ final class OptimizedCreateHiveTableAsSelectCommandMeta( willNotWorkOnGpu("partitioned writes are not supported") } - if (tableDesc.bucketSpec.isDefined) { - willNotWorkOnGpu("bucketing is not supported") - } + val outputColumns = + DataWritingCommand.logicalPlanOutputWithNames(cmd.query, cmd.outputColumnNames) + GpuBucketingUtils.tagForHiveBucketingWrite(this, tableDesc.bucketSpec, outputColumns, false) val serde = tableDesc.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) if (serde.contains("parquet")) { diff --git a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala index e7b3561f5fd..874d89353aa 100644 --- a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala +++ b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala @@ -31,7 +31,7 @@ import java.util.{Date, UUID} import com.nvidia.spark.TimingUtils import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.RapidsFileSourceMetaUtils +import com.nvidia.spark.rapids.shims.{GpuBucketingUtils, RapidsFileSourceMetaUtils} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ @@ -119,13 +119,8 @@ object GpuFileFormatWriter extends Logging { .map(RapidsFileSourceMetaUtils.cleanupFileSourceMetadataInformation)) val dataColumns = finalOutputSpec.outputColumns.filterNot(partitionSet.contains) - val writerBucketSpec: Option[GpuWriterBucketSpec] = bucketSpec.map { spec => - // TODO: Cannot support this until we: - // support Hive hash partitioning on the GPU - throw new UnsupportedOperationException("GPU hash partitioning for bucketed data is not " - + "compatible with the CPU version") - } - + val writerBucketSpec = GpuBucketingUtils.getWriterBucketSpec(bucketSpec, dataColumns, + options, false) val sortColumns = bucketSpec.toSeq.flatMap { spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) } @@ -419,8 +414,8 @@ object GpuFileFormatWriter extends Logging { } else { concurrentOutputWriterSpec match { case Some(spec) => - new GpuDynamicPartitionDataConcurrentWriter( - description, taskAttemptContext, committer, spec, TaskContext.get()) + new GpuDynamicPartitionDataConcurrentWriter(description, taskAttemptContext, + committer, spec) case _ => new GpuDynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) } diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala index 5aaeae2c7b9..d52c8b47ae7 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.TableWriter -import com.nvidia.spark.rapids.{ColumnarOutputWriter, ColumnarOutputWriterFactory, GpuBoundReference, GpuColumnVector, RapidsBufferCatalog, RapidsDeviceMemoryStore, ScalableTaskCompletion} +import com.nvidia.spark.rapids.{ColumnarOutputWriter, ColumnarOutputWriterFactory, GpuColumnVector, GpuLiteral, RapidsBufferCatalog, RapidsDeviceMemoryStore, ScalableTaskCompletion} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.jni.{GpuRetryOOM, GpuSplitAndRetryOOM} import org.apache.hadoop.conf.Configuration @@ -28,7 +28,6 @@ import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar.mock -import org.apache.spark.TaskContext import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, ExprId, SortOrder} @@ -39,7 +38,6 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { private var mockJobDescription: GpuWriteJobDescription = _ - private var mockTaskContext: TaskContext = _ private var mockTaskAttemptContext: TaskAttemptContext = _ private var mockCommitter: FileCommitProtocol = _ private var mockOutputWriterFactory: ColumnarOutputWriterFactory = _ @@ -48,6 +46,7 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { private var allCols: Seq[AttributeReference] = _ private var partSpec: Seq[AttributeReference] = _ private var dataSpec: Seq[AttributeReference] = _ + private var bucketSpec: Option[GpuWriterBucketSpec] = None private var includeRetry: Boolean = false class NoTransformColumnarOutputWriter( @@ -102,9 +101,9 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { allCols = null partSpec = null dataSpec = null + bucketSpec = None mockJobDescription = mock[GpuWriteJobDescription] when(mockJobDescription.statsTrackers).thenReturn(Seq.empty) - mockTaskContext = mock[TaskContext] mockTaskAttemptContext = mock[TaskAttemptContext] mockCommitter = mock[FileCommitProtocol] mockOutputWriterFactory = mock[ColumnarOutputWriterFactory] @@ -130,8 +129,12 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { * It is used to setup certain mocks before `body` is executed. After execution, the * columns in the batches are checked for `refCount==0` (e.g. that they were closed). * @note it is assumed that the schema of each batch is identical. + * numBuckets > 0: Bucketing only + * numBuckets == 0: Partition only + * numBuckets < 0: Both partition and bucketing */ - def withColumnarBatchesVerifyClosed[V](cbs: Seq[ColumnarBatch])(body: => V): Unit = { + def withColumnarBatchesVerifyClosed[V]( + cbs: Seq[ColumnarBatch], numBuckets: Int = 0)(body: => V): Unit = { val allTypes = cbs.map(GpuColumnVector.extractTypes) allCols = Seq.empty dataSpec = Seq.empty @@ -140,8 +143,17 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { allCols = allTypes.head.zipWithIndex.map { case (dataType, colIx) => AttributeReference(s"col_$colIx", dataType, nullable = false)(ExprId(colIx)) } - partSpec = Seq(allCols.head) - dataSpec = allCols.tail + if (numBuckets <= 0) { + partSpec = Seq(allCols.head) + dataSpec = allCols.tail + } else { + dataSpec = allCols + } + if (numBuckets != 0) { + bucketSpec = Some(GpuWriterBucketSpec( + GpuPmod(GpuMurmur3Hash(Seq(allCols.last), 42), GpuLiteral(Math.abs(numBuckets))), + _ => "")) + } } val fields = new Array[StructField](allCols.size) allCols.zipWithIndex.foreach { case (col, ix) => @@ -153,6 +165,7 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { } when(mockJobDescription.dataColumns).thenReturn(dataSpec) when(mockJobDescription.partitionColumns).thenReturn(partSpec) + when(mockJobDescription.bucketSpec).thenReturn(bucketSpec) when(mockJobDescription.allColumns).thenReturn(allCols) try { body @@ -187,6 +200,20 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { new ColumnarBatch(cols, rowCount) } + def buildBatchWithPartitionedAndBucketCols( + partInts: Seq[Int], bucketInts: Seq[Int]): ColumnarBatch = { + assert(partInts.length == bucketInts.length) + val rowCount = partInts.size + val cols: Array[ColumnVector] = new Array[ColumnVector](3) + val partCol = ai.rapids.cudf.ColumnVector.fromInts(partInts: _*) + val dataCol = ai.rapids.cudf.ColumnVector.fromStrings(partInts.map(_.toString): _*) + val bucketCol = ai.rapids.cudf.ColumnVector.fromInts(bucketInts: _*) + cols(0) = GpuColumnVector.from(partCol, IntegerType) + cols(1) = GpuColumnVector.from(dataCol, StringType) + cols(2) = GpuColumnVector.from(bucketCol, IntegerType) + new ColumnarBatch(cols, rowCount) + } + def verifyClosed(cbs: Seq[ColumnarBatch]): Unit = { cbs.foreach { cb => val cols = GpuColumnVector.extractBases(cb) @@ -198,7 +225,6 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { def prepareDynamicPartitionSingleWriter(): GpuDynamicPartitionDataSingleWriter = { - when(mockJobDescription.bucketSpec).thenReturn(None) when(mockJobDescription.customPartitionLocations) .thenReturn(Map.empty[TablePartitionSpec, String]) @@ -212,13 +238,10 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { GpuDynamicPartitionDataConcurrentWriter = { val mockConfig = new Configuration() when(mockTaskAttemptContext.getConfiguration).thenReturn(mockConfig) - when(mockJobDescription.bucketSpec).thenReturn(None) when(mockJobDescription.customPartitionLocations) .thenReturn(Map.empty[TablePartitionSpec, String]) - // assume the first column is the partition-by column - val sortExpr = - GpuBoundReference(0, partSpec.head.dataType, nullable = false)(ExprId(0), "") - val sortSpec = Seq(SortOrder(sortExpr, Ascending)) + val sortSpec = (partSpec ++ bucketSpec.map(_.bucketIdExpression)) + .map(SortOrder(_, Ascending)) val concurrentSpec = GpuConcurrentOutputWriterSpec( maxWriters, allCols, batchSize, sortSpec) @@ -226,8 +249,7 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { mockJobDescription, mockTaskAttemptContext, mockCommitter, - concurrentSpec, - mockTaskContext)) + concurrentSpec)) } test("empty directory data writer") { @@ -317,18 +339,6 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { } } - test("dynamic partition data writer doesn't support bucketing") { - resetMocksWithAndWithoutRetry { - withColumnarBatchesVerifyClosed(Seq.empty) { - when(mockJobDescription.bucketSpec).thenReturn(Some(GpuWriterBucketSpec(null, null))) - assertThrows[UnsupportedOperationException] { - new GpuDynamicPartitionDataSingleWriter( - mockJobDescription, mockTaskAttemptContext, mockCommitter) - } - } - } - } - test("dynamic partition data writer without splits") { resetMocksWithAndWithoutRetry { // 4 partitions @@ -353,6 +363,35 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { } } + test("dynamic partition data writer bucketing write without splits") { + Seq(5, -5).foreach { numBuckets => + val (numWrites, numNewWriters) = if (numBuckets > 0) { // Bucket only + (6, 6) // 3 buckets + 3 buckets + } else { // partition and bucket + (10, 10) // 5 pairs + 5 pairs + } + resetMocksWithAndWithoutRetry { + val cb = buildBatchWithPartitionedAndBucketCols( + IndexedSeq(1, 1, 2, 2, 3, 3, 4, 4), + IndexedSeq(1, 1, 1, 1, 2, 2, 2, 3)) + val cb2 = buildBatchWithPartitionedAndBucketCols( + IndexedSeq(1, 2, 3, 4, 5), + IndexedSeq(1, 1, 2, 2, 3)) + val cbs = Seq(spy(cb), spy(cb2)) + withColumnarBatchesVerifyClosed(cbs, numBuckets) { + // setting to 9 then the writer won't split as no group has more than 9 rows + when(mockJobDescription.maxRecordsPerFile).thenReturn(9) + val dynamicSingleWriter = prepareDynamicPartitionSingleWriter() + dynamicSingleWriter.writeWithIterator(cbs.iterator) + dynamicSingleWriter.commit() + verify(mockOutputWriter, times(numWrites)).writeSpillableAndClose(any(), any()) + verify(dynamicSingleWriter, times(numNewWriters)).newWriter(any(), any(), any()) + verify(mockOutputWriter, times(numNewWriters)).close() + } + } + } + } + test("dynamic partition data writer with splits") { resetMocksWithAndWithoutRetry { val cb = buildBatchWithPartitionedCol(1, 1, 2, 2, 3, 3, 4, 4) @@ -399,6 +438,38 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { } } + test("dynamic partition concurrent data writer bucketing write without splits") { + Seq(5, -5).foreach { numBuckets => + val (numWrites, numNewWriters) = if (numBuckets > 0) { // Bucket only + (3, 3) // 3 distinct buckets in total + } else { // partition and bucket + (6, 6) // 6 distinct pairs in total + } + resetMocksWithAndWithoutRetry { + val cb = buildBatchWithPartitionedAndBucketCols( + IndexedSeq(1, 1, 2, 2, 3, 3, 4, 4), + IndexedSeq(1, 1, 1, 1, 2, 2, 2, 3)) + val cb2 = buildBatchWithPartitionedAndBucketCols( + IndexedSeq(1, 2, 3, 4, 5), + IndexedSeq(1, 1, 2, 2, 3)) + val cbs = Seq(spy(cb), spy(cb2)) + withColumnarBatchesVerifyClosed(cbs, numBuckets) { + // setting to 9 then the writer won't split as no group has more than 9 rows + when(mockJobDescription.maxRecordsPerFile).thenReturn(9) + // I would like to not flush on the first iteration of the `write` method + when(mockJobDescription.concurrentWriterPartitionFlushSize).thenReturn(1000) + val dynamicConcurrentWriter = + prepareDynamicPartitionConcurrentWriter(maxWriters = 20, batchSize = 100) + dynamicConcurrentWriter.writeWithIterator(cbs.iterator) + dynamicConcurrentWriter.commit() + verify(mockOutputWriter, times(numWrites)).writeSpillableAndClose(any(), any()) + verify(dynamicConcurrentWriter, times(numNewWriters)).newWriter(any(), any(), any()) + verify(mockOutputWriter, times(numNewWriters)).close() + } + } + } + } + test("dynamic partition concurrent data writer with splits and flush") { resetMocksWithAndWithoutRetry { val cb = buildBatchWithPartitionedCol(1, 1, 2, 2, 3, 3, 4, 4) @@ -438,8 +509,9 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { prepareDynamicPartitionConcurrentWriter(maxWriters = 1, batchSize = 1) dynamicConcurrentWriter.writeWithIterator(cbs.iterator) dynamicConcurrentWriter.commit() - // 5 batches written, one per partition (no splitting) - verify(mockOutputWriter, times(5)) + // 6 batches written, one per partition (no splitting) plus one written by + // the concurrent writer. + verify(mockOutputWriter, times(6)) .writeSpillableAndClose(any(), any()) verify(dynamicConcurrentWriter, times(5)).newWriter(any(), any(), any()) // 5 files written because this is the single writer mode From 18ec4b2530f68ad2703e661cfb5a06aaaa2b2dea Mon Sep 17 00:00:00 2001 From: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com> Date: Tue, 25 Jun 2024 09:04:37 +0800 Subject: [PATCH 4/6] upgrade actions version (#11086) Signed-off-by: YanxuanLiu --- .github/workflows/blossom-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 4b8071303c1..447f3d5049b 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -90,7 +90,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: ${{ fromJson(needs.Authorization.outputs.args).repo }} ref: ${{ fromJson(needs.Authorization.outputs.args).ref }} @@ -98,7 +98,7 @@ jobs: # repo specific steps - name: Setup java - uses: actions/setup-java@v3 + uses: actions/setup-java@v4 with: distribution: adopt java-version: 8 From 86a905aac1544fef0554bad188c150b8e9720f91 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Mon, 24 Jun 2024 22:54:55 -0700 Subject: [PATCH 5/6] Fixed Failing tests in arithmetic_ops_tests for Spark 4.0.0 [databricks] (#11044) * Fixed arithmetic_ops_tests * Signing off Signed-off-by: Raza Jafri * Added a mechanism to add ansi mode per test * Reverted unnecessary change to spark_init_internal.py * Corrected the year in the licence * Only set ansi conf to false when ansi_mode_disabled is set * Addressed review comments * Fixed the method name * Update integration_tests/src/main/python/conftest.py This handles cases like `cache_test.py` which should run with the default conf for `spark.sql.ansi.enabled`. --------- Signed-off-by: Raza Jafri Co-authored-by: MithunR --- .../src/main/python/arithmetic_ops_test.py | 77 ++++++++++++++----- integration_tests/src/main/python/conftest.py | 10 +++ integration_tests/src/main/python/marks.py | 3 +- .../src/main/python/spark_session.py | 9 ++- 4 files changed, 76 insertions(+), 23 deletions(-) diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index b75872ed8b2..d7fd941b97b 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql from data_gen import * -from marks import ignore_order, incompat, approximate_float, allow_non_gpu, datagen_overrides +from marks import ignore_order, incompat, approximate_float, allow_non_gpu, datagen_overrides, disable_ansi_mode from pyspark.sql.types import * from pyspark.sql.types import IntegralType from spark_session import * @@ -25,6 +25,10 @@ import pyspark.sql.utils from datetime import timedelta +_arithmetic_exception_string = 'java.lang.ArithmeticException' if is_before_spark_330() else \ + 'org.apache.spark.SparkArithmeticException' if is_before_spark_400() else \ + 'pyspark.errors.exceptions.captured.ArithmeticException' + # No overflow gens here because we just focus on verifying the fallback to CPU when # enabling ANSI mode. But overflows will fail the tests because CPU runs raise # exceptions. @@ -95,6 +99,7 @@ def _get_overflow_df(spark, data, data_type, expr): ).selectExpr(expr) @pytest.mark.parametrize('data_gen', _arith_data_gens, ids=idfn) +@disable_ansi_mode def test_addition(data_gen): data_type = data_gen.data_type assert_gpu_and_cpu_are_equal_collect( @@ -119,6 +124,7 @@ def test_addition_ansi_no_overflow(data_gen): conf=ansi_enabled_conf) @pytest.mark.parametrize('data_gen', _arith_data_gens, ids=idfn) +@disable_ansi_mode def test_subtraction(data_gen): data_type = data_gen.data_type assert_gpu_and_cpu_are_equal_collect( @@ -136,6 +142,7 @@ def test_subtraction(data_gen): DecimalGen(10, -2), DecimalGen(15, 3), DecimalGen(30, 12), DecimalGen(3, -3), DecimalGen(27, 7), DecimalGen(20, -3)], ids=idfn) @pytest.mark.parametrize('addOrSub', ['+', '-']) +@disable_ansi_mode def test_addition_subtraction_mixed(lhs, rhs, addOrSub): assert_gpu_and_cpu_are_equal_collect( lambda spark : two_col_df(spark, lhs, rhs).selectExpr(f"a {addOrSub} b") @@ -160,6 +167,7 @@ def test_subtraction_ansi_no_overflow(data_gen): _decimal_gen_38_10, _decimal_gen_38_neg10 ], ids=idfn) +@disable_ansi_mode def test_multiplication(data_gen): data_type = data_gen.data_type assert_gpu_and_cpu_are_equal_collect( @@ -203,6 +211,7 @@ def test_multiplication_ansi_overflow(): @pytest.mark.parametrize('rhs', [byte_gen, short_gen, int_gen, long_gen, DecimalGen(6, 3), DecimalGen(10, -2), DecimalGen(15, 3), DecimalGen(30, 12), DecimalGen(3, -3), DecimalGen(27, 7), DecimalGen(20, -3)], ids=idfn) +@disable_ansi_mode def test_multiplication_mixed(lhs, rhs): assert_gpu_and_cpu_are_equal_collect( lambda spark : two_col_df(spark, lhs, rhs).select( @@ -220,6 +229,7 @@ def test_float_multiplication_mixed(lhs, rhs): @pytest.mark.parametrize('data_gen', [double_gen, decimal_gen_32bit_neg_scale, DecimalGen(6, 3), DecimalGen(5, 5), DecimalGen(6, 0), DecimalGen(7, 4), DecimalGen(15, 0), DecimalGen(18, 0), DecimalGen(17, 2), DecimalGen(16, 4), DecimalGen(38, 21), DecimalGen(21, 17), DecimalGen(3, -2)], ids=idfn) +@disable_ansi_mode def test_division(data_gen): data_type = data_gen.data_type assert_gpu_and_cpu_are_equal_collect( @@ -232,6 +242,7 @@ def test_division(data_gen): @pytest.mark.parametrize('rhs', [byte_gen, short_gen, int_gen, long_gen, DecimalGen(4, 1), DecimalGen(5, 0), DecimalGen(5, 1), DecimalGen(10, 5)], ids=idfn) @pytest.mark.parametrize('lhs', [byte_gen, short_gen, int_gen, long_gen, DecimalGen(5, 3), DecimalGen(4, 2), DecimalGen(1, -2), DecimalGen(16, 1)], ids=idfn) +@disable_ansi_mode def test_division_mixed(lhs, rhs): assert_gpu_and_cpu_are_equal_collect( lambda spark : two_col_df(spark, lhs, rhs).select( @@ -242,12 +253,14 @@ def test_division_mixed(lhs, rhs): # instead of increasing the precision. So we have a second test that deals with a few of these use cases @pytest.mark.parametrize('rhs', [DecimalGen(30, 10), DecimalGen(28, 18)], ids=idfn) @pytest.mark.parametrize('lhs', [DecimalGen(27, 7), DecimalGen(20, -3)], ids=idfn) +@disable_ansi_mode def test_division_mixed_larger_dec(lhs, rhs): assert_gpu_and_cpu_are_equal_collect( lambda spark : two_col_df(spark, lhs, rhs).select( f.col('a'), f.col('b'), f.col('a') / f.col('b'))) +@disable_ansi_mode def test_special_decimal_division(): for precision in range(1, 39): for scale in range(-3, precision + 1): @@ -260,6 +273,7 @@ def test_special_decimal_division(): @approximate_float # we should get the perfectly correct answer for floats except when casting a decimal to a float in some corner cases. @pytest.mark.parametrize('rhs', [float_gen, double_gen], ids=idfn) @pytest.mark.parametrize('lhs', [DecimalGen(5, 3), DecimalGen(4, 2), DecimalGen(1, -2), DecimalGen(16, 1)], ids=idfn) +@disable_ansi_mode def test_float_division_mixed(lhs, rhs): assert_gpu_and_cpu_are_equal_collect( lambda spark : two_col_df(spark, lhs, rhs).select( @@ -269,6 +283,7 @@ def test_float_division_mixed(lhs, rhs): @pytest.mark.parametrize('data_gen', integral_gens + [ decimal_gen_32bit, decimal_gen_64bit, _decimal_gen_7_7, _decimal_gen_18_3, _decimal_gen_30_2, _decimal_gen_36_5, _decimal_gen_38_0], ids=idfn) +@disable_ansi_mode def test_int_division(data_gen): string_type = to_cast_string(data_gen.data_type) assert_gpu_and_cpu_are_equal_collect( @@ -282,12 +297,14 @@ def test_int_division(data_gen): @pytest.mark.parametrize('lhs', [DecimalGen(6, 5), DecimalGen(5, 4), DecimalGen(3, -2), _decimal_gen_30_2], ids=idfn) @pytest.mark.parametrize('rhs', [DecimalGen(13, 2), DecimalGen(6, 3), _decimal_gen_38_0, pytest.param(_decimal_gen_36_neg5, marks=pytest.mark.skipif(not is_before_spark_340() or is_databricks113_or_later(), reason='SPARK-41207'))], ids=idfn) +@disable_ansi_mode def test_int_division_mixed(lhs, rhs): assert_gpu_and_cpu_are_equal_collect( lambda spark : two_col_df(spark, lhs, rhs).selectExpr( 'a DIV b')) @pytest.mark.parametrize('data_gen', _arith_data_gens, ids=idfn) +@disable_ansi_mode def test_mod(data_gen): data_type = data_gen.data_type assert_gpu_and_cpu_are_equal_collect( @@ -308,6 +325,7 @@ def test_mod(data_gen): _decimal_gen_7_7] @pytest.mark.parametrize('data_gen', _pmod_gens, ids=idfn) +@disable_ansi_mode def test_pmod(data_gen): string_type = to_cast_string(data_gen.data_type) assert_gpu_and_cpu_are_equal_collect( @@ -321,6 +339,7 @@ def test_pmod(data_gen): @allow_non_gpu("ProjectExec", "Pmod") @pytest.mark.parametrize('data_gen', test_pmod_fallback_decimal_gens + [_decimal_gen_38_0, _decimal_gen_38_10], ids=idfn) +@disable_ansi_mode def test_pmod_fallback(data_gen): string_type = to_cast_string(data_gen.data_type) assert_gpu_fallback_collect( @@ -372,8 +391,10 @@ def test_cast_neg_to_decimal_err(): data_gen = _decimal_gen_7_7 if is_before_spark_322(): exception_content = "Decimal(compact,-120000000,20,0}) cannot be represented as Decimal(7, 7)" - elif is_databricks113_or_later() or not is_before_spark_340(): + elif is_databricks113_or_later() or not is_before_spark_340() and is_before_spark_400(): exception_content = "[NUMERIC_VALUE_OUT_OF_RANGE] -12 cannot be represented as Decimal(7, 7)" + elif not is_before_spark_400(): + exception_content = "[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] -12 cannot be represented as Decimal(7, 7)" else: exception_content = "Decimal(compact, -120000000, 20, 0) cannot be represented as Decimal(7, 7)" @@ -410,6 +431,7 @@ def test_mod_pmod_by_zero_not_ansi(data_gen): @pytest.mark.parametrize('rhs', [byte_gen, short_gen, int_gen, long_gen, DecimalGen(6, 3), DecimalGen(10, -2), DecimalGen(15, 3), DecimalGen(30, 12), DecimalGen(3, -3), DecimalGen(27, 7), DecimalGen(20, -3)], ids=idfn) +@disable_ansi_mode def test_mod_mixed(lhs, rhs): assert_gpu_and_cpu_are_equal_collect( lambda spark : two_col_df(spark, lhs, rhs).selectExpr(f"a % b")) @@ -417,6 +439,7 @@ def test_mod_mixed(lhs, rhs): # @pytest.mark.skipif(not is_databricks113_or_later() and not is_spark_340_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/8330") @pytest.mark.parametrize('lhs', [DecimalGen(38,0), DecimalGen(37,2), DecimalGen(38,5), DecimalGen(38,-10), DecimalGen(38,7)], ids=idfn) @pytest.mark.parametrize('rhs', [DecimalGen(27,7), DecimalGen(30,10), DecimalGen(38,1), DecimalGen(36,0), DecimalGen(28,-7)], ids=idfn) +@disable_ansi_mode def test_mod_mixed_decimal128(lhs, rhs): assert_gpu_and_cpu_are_equal_collect( lambda spark : two_col_df(spark, lhs, rhs).selectExpr("a", "b", f"a % b")) @@ -424,6 +447,7 @@ def test_mod_mixed_decimal128(lhs, rhs): # Split into 4 tests to permute https://github.com/NVIDIA/spark-rapids/issues/7553 failures @pytest.mark.parametrize('lhs', [byte_gen, short_gen, int_gen, long_gen], ids=idfn) @pytest.mark.parametrize('rhs', [byte_gen, short_gen, int_gen, long_gen], ids=idfn) +@disable_ansi_mode def test_pmod_mixed_numeric(lhs, rhs): assert_gpu_and_cpu_are_equal_collect( lambda spark : two_col_df(spark, lhs, rhs).selectExpr(f"pmod(a, b)")) @@ -433,6 +457,7 @@ def test_pmod_mixed_numeric(lhs, rhs): DecimalGen(4, 2), DecimalGen(3, -2), DecimalGen(16, 7), DecimalGen(19, 0), DecimalGen(30, 10) ], ids=idfn) @pytest.mark.parametrize('rhs', [byte_gen, short_gen, int_gen, long_gen], ids=idfn) +@disable_ansi_mode def test_pmod_mixed_decimal_lhs(lhs, rhs): assert_gpu_fallback_collect( lambda spark : two_col_df(spark, lhs, rhs).selectExpr(f"pmod(a, b)"), @@ -443,6 +468,7 @@ def test_pmod_mixed_decimal_lhs(lhs, rhs): @pytest.mark.parametrize('rhs', [DecimalGen(6, 3), DecimalGen(10, -2), DecimalGen(15, 3), DecimalGen(30, 12), DecimalGen(3, -3), DecimalGen(27, 7), DecimalGen(20, -3) ], ids=idfn) +@disable_ansi_mode def test_pmod_mixed_decimal_rhs(lhs, rhs): assert_gpu_fallback_collect( lambda spark : two_col_df(spark, lhs, rhs).selectExpr(f"pmod(a, b)"), @@ -455,6 +481,7 @@ def test_pmod_mixed_decimal_rhs(lhs, rhs): @pytest.mark.parametrize('rhs', [DecimalGen(6, 3), DecimalGen(10, -2), DecimalGen(15, 3), DecimalGen(30, 12), DecimalGen(3, -3), DecimalGen(27, 7), DecimalGen(20, -3) ], ids=idfn) +@disable_ansi_mode def test_pmod_mixed_decimal(lhs, rhs): assert_gpu_fallback_collect( lambda spark : two_col_df(spark, lhs, rhs).selectExpr(f"pmod(a, b)"), @@ -466,6 +493,7 @@ def test_signum(data_gen): lambda spark : unary_op_df(spark, data_gen).selectExpr('signum(a)')) @pytest.mark.parametrize('data_gen', numeric_gens + _arith_decimal_gens_low_precision, ids=idfn) +@disable_ansi_mode def test_unary_minus(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, data_gen).selectExpr('-a')) @@ -504,8 +532,7 @@ def test_unary_minus_ansi_overflow(data_type, value): assert_gpu_and_cpu_error( df_fun=lambda spark: _get_overflow_df(spark, [value], data_type, '-a').collect(), conf=ansi_enabled_conf, - error_message='java.lang.ArithmeticException' if is_before_spark_330() else \ - 'org.apache.spark.SparkArithmeticException') + error_message=_arithmetic_exception_string) # This just ends up being a pass through. There is no good way to force # a unary positive into a plan, because it gets optimized out, but this @@ -516,6 +543,7 @@ def test_unary_positive(data_gen): lambda spark : unary_op_df(spark, data_gen).selectExpr('+a')) @pytest.mark.parametrize('data_gen', numeric_gens + _arith_decimal_gens_low_precision, ids=idfn) +@disable_ansi_mode def test_abs(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, data_gen).selectExpr('abs(a)')) @@ -556,10 +584,9 @@ def test_abs_ansi_overflow(data_type, value): GPU: One or more rows overflow for abs operation. """ assert_gpu_and_cpu_error( - df_fun=lambda spark: _get_overflow_df(spark, [value], data_type, 'abs(a)').collect(), - conf=ansi_enabled_conf, - error_message='java.lang.ArithmeticException' if is_before_spark_330() else \ - 'org.apache.spark.SparkArithmeticException') + df_fun=lambda spark: _get_overflow_df(spark, [value], data_type, 'abs(a)').collect(), + conf=ansi_enabled_conf, + error_message=_arithmetic_exception_string) @approximate_float @pytest.mark.parametrize('data_gen', double_gens, ids=idfn) @@ -613,7 +640,8 @@ def test_ceil_scale_zero(data_gen): @pytest.mark.parametrize('data_gen', [_decimal_gen_36_neg5, _decimal_gen_38_neg10], ids=idfn) def test_floor_ceil_overflow(data_gen): exception_type = "java.lang.ArithmeticException" if is_before_spark_330() and not is_databricks104_or_later() \ - else "SparkArithmeticException" + else "SparkArithmeticException" if is_before_spark_400() else \ + "pyspark.errors.exceptions.captured.ArithmeticException: [NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]" assert_gpu_and_cpu_error( lambda spark: unary_op_df(spark, data_gen).selectExpr('floor(a)').collect(), conf={}, @@ -678,6 +706,7 @@ def test_shift_right_unsigned(data_gen): @approximate_float @datagen_overrides(seed=0, reason="https://github.com/NVIDIA/spark-rapids/issues/9350") @pytest.mark.parametrize('data_gen', _arith_data_gens_for_round, ids=idfn) +@disable_ansi_mode def test_decimal_bround(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, data_gen).selectExpr( @@ -692,6 +721,7 @@ def test_decimal_bround(data_gen): @approximate_float @datagen_overrides(seed=0, reason="https://github.com/NVIDIA/spark-rapids/issues/9847") @pytest.mark.parametrize('data_gen', _arith_data_gens_for_round, ids=idfn) +@disable_ansi_mode def test_decimal_round(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, data_gen).selectExpr( @@ -726,6 +756,7 @@ def doit(spark): @incompat @approximate_float +@disable_ansi_mode def test_non_decimal_round_overflow(): gen = StructGen([('byte_c', byte_gen), ('short_c', short_gen), ('int_c', int_gen), ('long_c', long_gen), @@ -1057,7 +1088,8 @@ def _div_overflow_exception_when(expr, ansi_enabled, is_lit=False): ansi_conf = {'spark.sql.ansi.enabled': ansi_enabled} err_exp = 'java.lang.ArithmeticException' if is_before_spark_330() else \ 'org.apache.spark.SparkArithmeticException' \ - if not is_lit or not is_spark_340_or_later() else "pyspark.errors.exceptions.captured.ArithmeticException" + if (not is_lit or not is_spark_340_or_later()) and is_before_spark_400() else \ + "pyspark.errors.exceptions.captured.ArithmeticException" err_mess = ': Overflow in integral divide' \ if is_before_spark_340() and not is_databricks113_or_later() else \ ': [ARITHMETIC_OVERFLOW] Overflow in integral divide' @@ -1123,7 +1155,7 @@ def test_add_overflow_with_ansi_enabled(data, tp, expr): assert_gpu_and_cpu_error( lambda spark: _get_overflow_df(spark, data, tp, expr).collect(), conf=ansi_enabled_conf, - error_message='java.lang.ArithmeticException' if is_before_spark_330() else 'SparkArithmeticException') + error_message=_arithmetic_exception_string) elif isinstance(tp, DecimalType): assert_gpu_and_cpu_error( lambda spark: _get_overflow_df(spark, data, tp, expr).collect(), @@ -1152,7 +1184,8 @@ def test_subtraction_overflow_with_ansi_enabled(data, tp, expr): assert_gpu_and_cpu_error( lambda spark: _get_overflow_df(spark, data, tp, expr).collect(), conf=ansi_enabled_conf, - error_message='java.lang.ArithmeticException' if is_before_spark_330() else 'SparkArithmeticException') + error_message='java.lang.ArithmeticException' if is_before_spark_330() else 'SparkArithmeticException' \ + if is_before_spark_400() else "pyspark.errors.exceptions.captured.ArithmeticException:") elif isinstance(tp, DecimalType): assert_gpu_and_cpu_error( lambda spark: _get_overflow_df(spark, data, tp, expr).collect(), @@ -1183,7 +1216,7 @@ def test_unary_minus_ansi_overflow_day_time_interval(ansi_enabled): assert_gpu_and_cpu_error( df_fun=lambda spark: _get_overflow_df(spark, [timedelta(microseconds=LONG_MIN)], DayTimeIntervalType(), '-a').collect(), conf={'spark.sql.ansi.enabled': ansi_enabled}, - error_message='SparkArithmeticException') + error_message='SparkArithmeticException' if is_before_spark_400() else "ArithmeticException") @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @pytest.mark.parametrize('ansi_enabled', ['false', 'true']) @@ -1224,7 +1257,7 @@ def test_add_overflow_with_ansi_enabled_day_time_interval(ansi_enabled): StructType([StructField('a', DayTimeIntervalType()), StructField('b', DayTimeIntervalType())]) ).selectExpr('a + b').collect(), conf={'spark.sql.ansi.enabled': ansi_enabled}, - error_message='SparkArithmeticException') + error_message=_arithmetic_exception_string) @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @pytest.mark.parametrize('ansi_enabled', ['false', 'true']) @@ -1244,7 +1277,7 @@ def test_subtraction_overflow_with_ansi_enabled_day_time_interval(ansi_enabled): StructType([StructField('a', DayTimeIntervalType()), StructField('b', DayTimeIntervalType())]) ).selectExpr('a - b').collect(), conf={'spark.sql.ansi.enabled': ansi_enabled}, - error_message='SparkArithmeticException') + error_message='SparkArithmeticException' if is_before_spark_400() else "ArithmeticException") @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') def test_unary_positive_day_time_interval(): @@ -1303,7 +1336,8 @@ def _get_overflow_df_2cols(spark, data_types, values, expr): def test_day_time_interval_division_overflow(data_type, value_pair): exception_message = "SparkArithmeticException: Overflow in integral divide." \ if is_before_spark_340() and not is_databricks113_or_later() else \ - "SparkArithmeticException: [ARITHMETIC_OVERFLOW] Overflow in integral divide." + "SparkArithmeticException: [ARITHMETIC_OVERFLOW] Overflow in integral divide." if is_before_spark_400() else \ + "ArithmeticException: [ARITHMETIC_OVERFLOW] Overflow in integral divide." assert_gpu_and_cpu_error( df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(), conf={}, @@ -1338,7 +1372,8 @@ def test_day_time_interval_division_round_overflow(data_type, value_pair): def test_day_time_interval_divided_by_zero(data_type, value_pair): exception_message = "SparkArithmeticException: Division by zero." \ if is_before_spark_340() and not is_databricks113_or_later() else \ - "SparkArithmeticException: [INTERVAL_DIVIDED_BY_ZERO] Division by zero" + "SparkArithmeticException: [INTERVAL_DIVIDED_BY_ZERO] Division by zero" if is_before_spark_400() else \ + "ArithmeticException: [INTERVAL_DIVIDED_BY_ZERO] Division by zero" assert_gpu_and_cpu_error( df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(), conf={}, @@ -1349,7 +1384,8 @@ def test_day_time_interval_divided_by_zero(data_type, value_pair): def test_day_time_interval_divided_by_zero_scalar(zero_literal): exception_message = "SparkArithmeticException: Division by zero." \ if is_before_spark_340() and not is_databricks113_or_later() else \ - "SparkArithmeticException: [INTERVAL_DIVIDED_BY_ZERO] Division by zero." + "SparkArithmeticException: [INTERVAL_DIVIDED_BY_ZERO] Division by zero." if is_before_spark_400() else \ + "ArithmeticException: [INTERVAL_DIVIDED_BY_ZERO] Division by zero" assert_gpu_and_cpu_error( df_fun=lambda spark: _get_overflow_df_1col(spark, DayTimeIntervalType(), [timedelta(seconds=1)], 'a / ' + zero_literal).collect(), conf={}, @@ -1369,7 +1405,8 @@ def test_day_time_interval_divided_by_zero_scalar(zero_literal): def test_day_time_interval_scalar_divided_by_zero(data_type, value): exception_message = "SparkArithmeticException: Division by zero." \ if is_before_spark_340() and not is_databricks113_or_later() else \ - "SparkArithmeticException: [INTERVAL_DIVIDED_BY_ZERO] Division by zero." + "SparkArithmeticException: [INTERVAL_DIVIDED_BY_ZERO] Division by zero." if is_before_spark_400() else \ + "ArithmeticException: [INTERVAL_DIVIDED_BY_ZERO] Division by zero" assert_gpu_and_cpu_error( df_fun=lambda spark: _get_overflow_df_1col(spark, data_type, [value], 'INTERVAL 1 SECOND / a').collect(), conf={}, diff --git a/integration_tests/src/main/python/conftest.py b/integration_tests/src/main/python/conftest.py index 1adeb6964fd..6af40b99768 100644 --- a/integration_tests/src/main/python/conftest.py +++ b/integration_tests/src/main/python/conftest.py @@ -54,6 +54,7 @@ def array_columns_to_sort_locally(): _allow_any_non_gpu = False _non_gpu_allowed = [] +_per_test_ansi_mode_enabled = None def is_allowing_any_non_gpu(): return _allow_any_non_gpu @@ -61,6 +62,11 @@ def is_allowing_any_non_gpu(): def get_non_gpu_allowed(): return _non_gpu_allowed + +def is_per_test_ansi_mode_enabled(): + return _per_test_ansi_mode_enabled + + def get_validate_execs_in_gpu_plan(): return _validate_execs_in_gpu_plan @@ -210,10 +216,14 @@ def pytest_runtest_setup(item): global _allow_any_non_gpu global _non_gpu_allowed + global _per_test_ansi_mode_enabled _non_gpu_allowed_databricks = [] _allow_any_non_gpu_databricks = False non_gpu_databricks = item.get_closest_marker('allow_non_gpu_databricks') non_gpu = item.get_closest_marker('allow_non_gpu') + _per_test_ansi_mode_enabled = None if item.get_closest_marker('disable_ansi_mode') is None \ + else not item.get_closest_marker('disable_ansi_mode') + if non_gpu_databricks: if is_databricks_runtime(): if non_gpu_databricks.kwargs and non_gpu_databricks.kwargs['any']: diff --git a/integration_tests/src/main/python/marks.py b/integration_tests/src/main/python/marks.py index 1f326a75505..9a0bde11113 100644 --- a/integration_tests/src/main/python/marks.py +++ b/integration_tests/src/main/python/marks.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ allow_non_gpu_databricks = pytest.mark.allow_non_gpu_databricks allow_non_gpu = pytest.mark.allow_non_gpu +disable_ansi_mode = pytest.mark.disable_ansi_mode validate_execs_in_gpu_plan = pytest.mark.validate_execs_in_gpu_plan approximate_float = pytest.mark.approximate_float ignore_order = pytest.mark.ignore_order diff --git a/integration_tests/src/main/python/spark_session.py b/integration_tests/src/main/python/spark_session.py index c55f1976497..26388617fff 100644 --- a/integration_tests/src/main/python/spark_session.py +++ b/integration_tests/src/main/python/spark_session.py @@ -16,7 +16,7 @@ import calendar, time from datetime import date, datetime from contextlib import contextmanager, ExitStack -from conftest import is_allowing_any_non_gpu, get_non_gpu_allowed, get_validate_execs_in_gpu_plan, is_databricks_runtime, is_at_least_precommit_run, get_inject_oom_conf +from conftest import is_allowing_any_non_gpu, get_non_gpu_allowed, get_validate_execs_in_gpu_plan, is_databricks_runtime, is_at_least_precommit_run, get_inject_oom_conf, is_per_test_ansi_mode_enabled from pyspark.sql import DataFrame from pyspark.sql.types import TimestampType, DateType, _acceptable_types from spark_init_internal import get_spark_i_know_what_i_am_doing, spark_version @@ -41,7 +41,6 @@ def _from_scala_map(scala_map): # Many of these are redundant with default settings for the configs but are set here explicitly # to ensure any cluster settings do not interfere with tests that assume the defaults. _default_conf = { - 'spark.ansi.enabled': 'false', 'spark.rapids.sql.castDecimalToFloat.enabled': 'false', 'spark.rapids.sql.castFloatToDecimal.enabled': 'false', 'spark.rapids.sql.castFloatToIntegralTypes.enabled': 'false', @@ -127,6 +126,9 @@ def with_spark_session(func, conf={}): """Run func that takes a spark session as input with the given configs set.""" reset_spark_session_conf() _add_job_description(conf) + # Only set the ansi conf if not set by the test explicitly by setting the value in the dict + if "spark.sql.ansi.enabled" not in conf and is_per_test_ansi_mode_enabled() is not None: + conf["spark.sql.ansi.enabled"] = is_per_test_ansi_mode_enabled() _set_all_confs(conf) ret = func(_spark) _check_for_proper_return_values(ret) @@ -205,6 +207,9 @@ def is_before_spark_350(): def is_before_spark_351(): return spark_version() < "3.5.1" +def is_before_spark_400(): + return spark_version() < "4.0.0" + def is_spark_320_or_later(): return spark_version() >= "3.2.0" From 7a8690f5e2e4e9009c121e08c6429e2496b4f01c Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Tue, 25 Jun 2024 15:39:47 +0800 Subject: [PATCH 6/6] fix duplicate counted metrics like op time for GpuCoalesceBatches (#11062) * with call site print, not good because some test cases by design will dup Signed-off-by: Hongbin Ma (Mahone) * done Signed-off-by: Hongbin Ma (Mahone) * add file Signed-off-by: Hongbin Ma (Mahone) * fix comiple Signed-off-by: Hongbin Ma (Mahone) * address review comments Signed-off-by: Hongbin Ma (Mahone) --------- Signed-off-by: Hongbin Ma (Mahone) --- .../spark/rapids/GpuCoalesceBatches.scala | 6 +- .../com/nvidia/spark/rapids/GpuExec.scala | 30 ++++++-- .../nvidia/spark/rapids/NvtxWithMetrics.scala | 18 +++-- .../nvidia/spark/rapids/MetricsSuite.scala | 68 +++++++++++++++++++ 4 files changed, 109 insertions(+), 13 deletions(-) create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala index e6dc216d7e6..1afc03b177b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala @@ -462,7 +462,7 @@ abstract class AbstractGpuCoalesceIterator( // If we have reached the cuDF limit once, proactively filter batches // after that first limit is reached. GpuFilter.filterAndClose(cbFromIter, inputFilterTier.get, - NoopMetric, NoopMetric, opTime) + NoopMetric, NoopMetric, NoopMetric) } else { Iterator(cbFromIter) } @@ -499,7 +499,7 @@ abstract class AbstractGpuCoalesceIterator( var filteredBytes = 0L if (hasAnyToConcat) { val filteredDowIter = GpuFilter.filterAndClose(concatAllAndPutOnGPU(), - filterTier, NoopMetric, NoopMetric, opTime) + filterTier, NoopMetric, NoopMetric, NoopMetric) while (filteredDowIter.hasNext) { closeOnExcept(filteredDowIter.next()) { filteredDownCb => filteredNumRows += filteredDownCb.numRows() @@ -512,7 +512,7 @@ abstract class AbstractGpuCoalesceIterator( // filterAndClose takes ownership of CB so we should not close it on a failure // anymore... val filteredCbIter = GpuFilter.filterAndClose(cb.release, filterTier, - NoopMetric, NoopMetric, opTime) + NoopMetric, NoopMetric, NoopMetric) while (filteredCbIter.hasNext) { closeOnExcept(filteredCbIter.next()) { filteredCb => val filteredWouldBeRows = filteredNumRows + filteredCb.numRows() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala index ec87dd62d6c..d83f20113b2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala @@ -152,12 +152,34 @@ sealed abstract class GpuMetric extends Serializable { def +=(v: Long): Unit def add(v: Long): Unit + private var isTimerActive = false + + final def tryActivateTimer(): Boolean = { + if (!isTimerActive) { + isTimerActive = true + true + } else { + false + } + } + + final def deactivateTimer(duration: Long): Unit = { + if (isTimerActive) { + isTimerActive = false + add(duration) + } + } + final def ns[T](f: => T): T = { - val start = System.nanoTime() - try { + if (tryActivateTimer()) { + val start = System.nanoTime() + try { + f + } finally { + deactivateTimer(System.nanoTime() - start) + } + } else { f - } finally { - add(System.nanoTime() - start) } } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala index 92a11f56123..538f117e50f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,26 +32,32 @@ object NvtxWithMetrics { * by the amount of time spent in the range */ class NvtxWithMetrics(name: String, color: NvtxColor, val metrics: GpuMetric*) - extends NvtxRange(name, color) { + extends NvtxRange(name, color) { + val needTracks = metrics.map(_.tryActivateTimer()) private val start = System.nanoTime() override def close(): Unit = { val time = System.nanoTime() - start - metrics.foreach { metric => - metric += time + metrics.toSeq.zip(needTracks).foreach { pair => + if (pair._2) { + pair._1.deactivateTimer(time) + } } super.close() } } class MetricRange(val metrics: GpuMetric*) extends AutoCloseable { + val needTracks = metrics.map(_.tryActivateTimer()) private val start = System.nanoTime() override def close(): Unit = { val time = System.nanoTime() - start - metrics.foreach { metric => - metric += time + metrics.toSeq.zip(needTracks).foreach { pair => + if (pair._2) { + pair._1.deactivateTimer(time) + } } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala new file mode 100644 index 00000000000..580c5a2ed55 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import ai.rapids.cudf.NvtxColor +import com.nvidia.spark.rapids.Arm.withResource +import org.scalatest.funsuite.AnyFunSuite + +class MetricsSuite extends AnyFunSuite { + + test("GpuMetric.ns: duplicate timing on the same metrics") { + val m1 = new LocalGpuMetric() + m1.ns( + m1.ns( + Thread.sleep(100) + ) + ) + // if the timing is duplicated, the value should be around 200,000,000 + assert(m1.value < 100000000 * 1.5) + assert(m1.value > 100000000 * 0.5) + } + + test("MetricRange: duplicate timing on the same metrics") { + val m1 = new LocalGpuMetric() + val m2 = new LocalGpuMetric() + withResource(new MetricRange(m1, m2)) { _ => + withResource(new MetricRange(m2, m1)) { _ => + Thread.sleep(100) + } + } + + // if the timing is duplicated, the value should be around 200,000,000 + assert(m1.value < 100000000 * 1.5) + assert(m1.value > 100000000 * 0.5) + assert(m2.value < 100000000 * 1.5) + assert(m2.value > 100000000 * 0.5) + } + + test("NvtxWithMetrics: duplicate timing on the same metrics") { + val m1 = new LocalGpuMetric() + val m2 = new LocalGpuMetric() + withResource(new NvtxWithMetrics("a", NvtxColor.BLUE, m1, m2)) { _ => + withResource(new NvtxWithMetrics("b", NvtxColor.BLUE, m2, m1)) { _ => + Thread.sleep(100) + } + } + + // if the timing is duplicated, the value should be around 200,000,000 + assert(m1.value < 100000000 * 1.5) + assert(m1.value > 100000000 * 0.5) + assert(m2.value < 100000000 * 1.5) + assert(m2.value > 100000000 * 0.5) + } +}