diff --git a/aggregator/pom.xml b/aggregator/pom.xml
index a66c30c226a..f2fc06a370f 100644
--- a/aggregator/pom.xml
+++ b/aggregator/pom.xml
@@ -385,6 +385,12 @@
             <version>${project.version}</version>
             <classifier>${spark.version.classifier}</classifier>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-delta-23x_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <classifier>${spark.version.classifier}</classifier>
+        </dependency>
@@ -408,6 +414,12 @@
             <version>${project.version}</version>
             <classifier>${spark.version.classifier}</classifier>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-delta-23x_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <classifier>${spark.version.classifier}</classifier>
+        </dependency>
@@ -431,6 +443,12 @@
             <version>${project.version}</version>
             <classifier>${spark.version.classifier}</classifier>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-delta-23x_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <classifier>${spark.version.classifier}</classifier>
+        </dependency>
@@ -471,6 +489,12 @@
             <version>${project.version}</version>
             <classifier>${spark.version.classifier}</classifier>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-delta-23x_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <classifier>${spark.version.classifier}</classifier>
+        </dependency>
@@ -494,6 +518,12 @@
             <version>${project.version}</version>
             <classifier>${spark.version.classifier}</classifier>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-delta-23x_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <classifier>${spark.version.classifier}</classifier>
+        </dependency>
@@ -534,6 +564,12 @@
             <version>${project.version}</version>
             <classifier>${spark.version.classifier}</classifier>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-delta-23x_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <classifier>${spark.version.classifier}</classifier>
+        </dependency>
diff --git a/delta-lake/README.md b/delta-lake/README.md
index 945e3508f07..ff9d2553e31 100644
--- a/delta-lake/README.md
+++ b/delta-lake/README.md
@@ -14,6 +14,7 @@ and directory contains the corresponding support code.
| 2.0.x | Spark 3.2.x | `delta-20x` |
| 2.1.x | Spark 3.3.x | `delta-21x` |
| 2.2.x | Spark 3.3.x | `delta-22x` |
+| 2.3.x | Spark 3.3.x | `delta-23x` |
| 2.4.x | Spark 3.4.x | `delta-24x` |
| Databricks 10.4 | Databricks 10.4 | `delta-spark321db` |
| Databricks 11.3 | Databricks 11.3 | `delta-spark330db` |
diff --git a/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala
index 65c9ac54094..91c55d090e5 100644
--- a/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala
+++ b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala
@@ -53,7 +53,11 @@ object DeltaRuntimeShim {
Try {
DeltaUDF.getClass.getMethod("stringStringUdf", classOf[String => String])
}.map(_ => "org.apache.spark.sql.delta.rapids.delta21x.Delta21xRuntimeShim")
- .getOrElse("org.apache.spark.sql.delta.rapids.delta22x.Delta22xRuntimeShim")
+ .orElse {
+ Try {
+ classOf[DeltaLog].getMethod("assertRemovable")
+ }.map(_ => "org.apache.spark.sql.delta.rapids.delta22x.Delta22xRuntimeShim")
+ }.getOrElse("org.apache.spark.sql.delta.rapids.delta23x.Delta23xRuntimeShim")
} else if (VersionUtils.cmpSparkVersion(3, 5, 0) < 0) {
"org.apache.spark.sql.delta.rapids.delta24x.Delta24xRuntimeShim"
} else {
diff --git a/delta-lake/delta-23x/pom.xml b/delta-lake/delta-23x/pom.xml
new file mode 100644
index 00000000000..9b8cb489cb6
--- /dev/null
+++ b/delta-lake/delta-23x/pom.xml
@@ -0,0 +1,98 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  Copyright (c) 2023, NVIDIA CORPORATION.
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>com.nvidia</groupId>
+        <artifactId>rapids-4-spark-parent_2.12</artifactId>
+        <version>23.12.0-SNAPSHOT</version>
+        <relativePath>../../pom.xml</relativePath>
+    </parent>
+
+    <artifactId>rapids-4-spark-delta-23x_2.12</artifactId>
+    <name>RAPIDS Accelerator for Apache Spark Delta Lake 2.3.x Support</name>
+    <description>Delta Lake 2.3.x support for the RAPIDS Accelerator for Apache Spark</description>
+    <version>23.12.0-SNAPSHOT</version>
+
+    <properties>
+        <rapids.module>../delta-lake/delta-23x</rapids.module>
+        <rapids.compressed.artifact>false</rapids.compressed.artifact>
+        <rapids.default.jar.excludePattern>**/*</rapids.default.jar.excludePattern>
+        <rapids.shim.jar.phase>package</rapids.shim.jar.phase>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-sql_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <classifier>${spark.version.classifier}</classifier>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>io.delta</groupId>
+            <artifactId>delta-core_${scala.binary.version}</artifactId>
+            <version>2.3.0</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_${scala.binary.version}</artifactId>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.codehaus.mojo</groupId>
+                <artifactId>build-helper-maven-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <id>add-common-sources</id>
+                        <phase>generate-sources</phase>
+                        <goals>
+                            <goal>add-source</goal>
+                        </goals>
+                        <configuration>
+                            <sources>
+                                <source>${project.basedir}/../common/src/main/scala</source>
+                                <source>${project.basedir}/../common/src/main/delta-io/scala</source>
+                            </sources>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>net.alchim31.maven</groupId>
+                <artifactId>scala-maven-plugin</artifactId>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.rat</groupId>
+                <artifactId>apache-rat-plugin</artifactId>
+            </plugin>
+        </plugins>
+    </build>
+</project>
diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/DeleteCommandMeta.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/DeleteCommandMeta.scala
new file mode 100644
index 00000000000..49274eba3aa
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/DeleteCommandMeta.scala
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta.delta23x
+
+import com.nvidia.spark.rapids.{DataFromReplacementRule, RapidsConf, RapidsMeta, RunnableCommandMeta}
+import com.nvidia.spark.rapids.delta.RapidsDeltaUtils
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.delta.commands.{DeleteCommand, DeletionVectorUtils}
+import org.apache.spark.sql.delta.rapids.GpuDeltaLog
+import org.apache.spark.sql.delta.rapids.delta23x.GpuDeleteCommand
+import org.apache.spark.sql.delta.sources.DeltaSQLConf
+import org.apache.spark.sql.execution.command.RunnableCommand
+
+class DeleteCommandMeta(
+ deleteCmd: DeleteCommand,
+ conf: RapidsConf,
+ parent: Option[RapidsMeta[_, _, _]],
+ rule: DataFromReplacementRule)
+ extends RunnableCommandMeta[DeleteCommand](deleteCmd, conf, parent, rule) {
+
+ override def tagSelfForGpu(): Unit = {
+ if (!conf.isDeltaWriteEnabled) {
+ willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " +
+ s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
+ }
+ val dvFeatureEnabled = DeletionVectorUtils.deletionVectorsWritable(
+ deleteCmd.deltaLog.unsafeVolatileSnapshot)
+ if (dvFeatureEnabled && deleteCmd.conf.getConf(
+ DeltaSQLConf.DELETE_USE_PERSISTENT_DELETION_VECTORS)) {
+ // https://github.com/NVIDIA/spark-rapids/issues/8554
+ willNotWorkOnGpu("Deletion vectors are not supported on GPU")
+ }
+ RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog),
+ Map.empty, SparkSession.active)
+ }
+
+ override def convertToGpu(): RunnableCommand = {
+ GpuDeleteCommand(
+ new GpuDeltaLog(deleteCmd.deltaLog, conf),
+ deleteCmd.target,
+ deleteCmd.condition)
+ }
+}
diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/Delta23xProvider.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/Delta23xProvider.scala
new file mode 100644
index 00000000000..c32c0a0e02d
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/Delta23xProvider.scala
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta.delta23x
+
+import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta}
+import com.nvidia.spark.rapids.delta.DeltaIOProvider
+
+import org.apache.spark.sql.delta.DeltaParquetFileFormat
+import org.apache.spark.sql.delta.DeltaParquetFileFormat.{IS_ROW_DELETED_COLUMN_NAME, ROW_INDEX_COLUMN_NAME}
+import org.apache.spark.sql.delta.catalog.DeltaCatalog
+import org.apache.spark.sql.delta.commands.{DeleteCommand, MergeIntoCommand, UpdateCommand}
+import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim
+import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.command.RunnableCommand
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
+import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec}
+
+object Delta23xProvider extends DeltaIOProvider {
+
+ override def getRunnableCommandRules: Map[Class[_ <: RunnableCommand],
+ RunnableCommandRule[_ <: RunnableCommand]] = {
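+ // All three Delta DML commands are experimental and disabled by default; each one can be
+ // enabled individually through its corresponding spark.rapids.sql.command.* configuration.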
+ Seq(
+ GpuOverrides.runnableCmd[DeleteCommand](
+ "Delete rows from a Delta Lake table",
+ (a, conf, p, r) => new DeleteCommandMeta(a, conf, p, r))
+ .disabledByDefault("Delta Lake delete support is experimental"),
+ GpuOverrides.runnableCmd[MergeIntoCommand](
+ "Merge of a source query/table into a Delta table",
+ (a, conf, p, r) => new MergeIntoCommandMeta(a, conf, p, r))
+ .disabledByDefault("Delta Lake merge support is experimental"),
+ GpuOverrides.runnableCmd[UpdateCommand](
+ "Update rows in a Delta Lake table",
+ (a, conf, p, r) => new UpdateCommandMeta(a, conf, p, r))
+ .disabledByDefault("Delta Lake update support is experimental")
+ ).map(r => (r.getClassFor.asSubclass(classOf[RunnableCommand]), r)).toMap
+ }
+
+ override def tagSupportForGpuFileSourceScan(meta: SparkPlanMeta[FileSourceScanExec]): Unit = {
+ val format = meta.wrapped.relation.fileFormat
+ if (format.getClass == classOf[DeltaParquetFileFormat]) {
+ val deltaFormat = format.asInstanceOf[DeltaParquetFileFormat]
+ val requiredSchema = meta.wrapped.requiredSchema
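+ // Delta only asks for these metadata columns when deletion vectors must be applied during
+ // the scan, which the GPU reader does not support, so fall back to the CPU in that case.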
+ if (requiredSchema.exists(_.name == IS_ROW_DELETED_COLUMN_NAME)) {
+ meta.willNotWorkOnGpu(
+ s"reading metadata column $IS_ROW_DELETED_COLUMN_NAME is not supported")
+ }
+ if (requiredSchema.exists(_.name == ROW_INDEX_COLUMN_NAME)) {
+ meta.willNotWorkOnGpu(
+ s"reading metadata column $ROW_INDEX_COLUMN_NAME is not supported")
+ }
+ if (deltaFormat.hasDeletionVectorMap()) {
+ meta.willNotWorkOnGpu("deletion vectors are not supported")
+ }
+ GpuReadParquetFileFormat.tagSupport(meta)
+ } else {
+ meta.willNotWorkOnGpu(s"format ${format.getClass} is not supported")
+ }
+ }
+
+ override def getReadFileFormat(format: FileFormat): FileFormat = {
+ val cpuFormat = format.asInstanceOf[DeltaParquetFileFormat]
+ GpuDelta23xParquetFileFormat(cpuFormat.metadata, cpuFormat.isSplittable)
+ }
+
+ override def convertToGpu(
+ cpuExec: AtomicCreateTableAsSelectExec,
+ meta: AtomicCreateTableAsSelectExecMeta): GpuExec = {
+ val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog]
+ GpuAtomicCreateTableAsSelectExec(
+ DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf),
+ cpuExec.ident,
+ cpuExec.partitioning,
+ cpuExec.plan,
+ meta.childPlans.head.convertIfNeeded(),
+ cpuExec.tableSpec,
+ cpuExec.writeOptions,
+ cpuExec.ifNotExists)
+ }
+
+ override def convertToGpu(
+ cpuExec: AtomicReplaceTableAsSelectExec,
+ meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
+ val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog]
+ GpuAtomicReplaceTableAsSelectExec(
+ DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf),
+ cpuExec.ident,
+ cpuExec.partitioning,
+ cpuExec.plan,
+ meta.childPlans.head.convertIfNeeded(),
+ cpuExec.tableSpec,
+ cpuExec.writeOptions,
+ cpuExec.orCreate,
+ cpuExec.invalidateCache)
+ }
+}
diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDelta23xParquetFileFormat.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDelta23xParquetFileFormat.scala
new file mode 100644
index 00000000000..0466a01506a
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDelta23xParquetFileFormat.scala
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta.delta23x
+
+import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormat
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.delta.{DeltaColumnMappingMode, IdMapping}
+import org.apache.spark.sql.delta.actions.Metadata
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+case class GpuDelta23xParquetFileFormat(
+ metadata: Metadata,
+ isSplittable: Boolean) extends GpuDeltaParquetFileFormat {
+
+ override val columnMappingMode: DeltaColumnMappingMode = metadata.columnMappingMode
+ override val referenceSchema: StructType = metadata.schema
+
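+ // Delta id column mapping resolves columns by Parquet field id rather than by name, so both
+ // the Parquet field-id read and write settings must be enabled in the session.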
+ if (columnMappingMode == IdMapping) {
+ val requiredReadConf = SQLConf.PARQUET_FIELD_ID_READ_ENABLED
+ require(SparkSession.getActiveSession.exists(_.sessionState.conf.getConf(requiredReadConf)),
+ s"${requiredReadConf.key} must be enabled to support Delta id column mapping mode")
+ val requiredWriteConf = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED
+ require(SparkSession.getActiveSession.exists(_.sessionState.conf.getConf(requiredWriteConf)),
+ s"${requiredWriteConf.key} must be enabled to support Delta id column mapping mode")
+ }
+
+ override def isSplitable(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ path: Path): Boolean = isSplittable
+
+ /**
+ * We sometimes need to replace FileFormat within LogicalPlans, so we have to override
+ * `equals` to ensure file format changes are captured
+ */
+ override def equals(other: Any): Boolean = {
+ other match {
+ case ff: GpuDelta23xParquetFileFormat =>
+ ff.columnMappingMode == columnMappingMode &&
+ ff.referenceSchema == referenceSchema &&
+ ff.isSplittable == isSplittable
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = getClass.getCanonicalName.hashCode()
+}
diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDeltaCatalog.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDeltaCatalog.scala
new file mode 100644
index 00000000000..d10b114299d
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDeltaCatalog.scala
@@ -0,0 +1,180 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from DeltaDataSource.scala in the
+ * Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta.delta23x
+
+import com.nvidia.spark.rapids.RapidsConf
+import java.util
+
+import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, Table}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.delta.catalog.DeltaCatalog
+import org.apache.spark.sql.delta.commands.TableCreationModes
+import org.apache.spark.sql.delta.metering.DeltaLogging
+import org.apache.spark.sql.delta.rapids.GpuDeltaCatalogBase
+import org.apache.spark.sql.delta.rapids.delta23x.GpuCreateDeltaTableCommand
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
+import org.apache.spark.sql.types.StructType
+
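+/**
+ * GPU-aware version of the Delta catalog: table creation (including CTAS/RTAS) is routed
+ * through GpuCreateDeltaTableCommand so the data write runs on the GPU, while metadata-only
+ * operations such as table loading are delegated to the wrapped CPU DeltaCatalog.
+ */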
+class GpuDeltaCatalog(
+ override val cpuCatalog: DeltaCatalog,
+ override val rapidsConf: RapidsConf)
+ extends GpuDeltaCatalogBase with DeltaLogging {
+
+ override val spark: SparkSession = cpuCatalog.spark
+
+ override protected def buildGpuCreateDeltaTableCommand(
+ rapidsConf: RapidsConf,
+ table: CatalogTable,
+ existingTableOpt: Option[CatalogTable],
+ mode: SaveMode,
+ query: Option[LogicalPlan],
+ operation: TableCreationModes.CreationMode,
+ tableByPath: Boolean): LeafRunnableCommand = {
+ GpuCreateDeltaTableCommand(
+ table,
+ existingTableOpt,
+ mode,
+ query,
+ operation,
+ tableByPath = tableByPath
+ )(rapidsConf)
+ }
+
+ override protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = {
+ cpuCatalog.getExistingTableIfExists(table)
+ }
+
+ override protected def verifyTableAndSolidify(
+ tableDesc: CatalogTable,
+ query: Option[LogicalPlan]): CatalogTable = {
+ cpuCatalog.verifyTableAndSolidify(tableDesc, query)
+ }
+
+ override protected def createGpuStagedDeltaTableV2(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String],
+ operation: TableCreationModes.CreationMode): StagedTable = {
+ new GpuStagedDeltaTableV2WithLogging(ident, schema, partitions, properties, operation)
+ }
+
+ override def loadTable(ident: Identifier, timestamp: Long): Table = {
+ cpuCatalog.loadTable(ident, timestamp)
+ }
+
+ override def loadTable(ident: Identifier, version: String): Table = {
+ cpuCatalog.loadTable(ident, version)
+ }
+
+ /**
+ * Creates a Delta table using GPU for writing the data
+ *
+ * @param ident The identifier of the table
+ * @param schema The schema of the table
+ * @param partitions The partition transforms for the table
+ * @param allTableProperties The table properties that configure the behavior of the table or
+ * provide information about the table
+ * @param writeOptions Options specific to the write during table creation or replacement
+ * @param sourceQuery A query if this CREATE request came from a CTAS or RTAS
+ * @param operation The specific table creation mode, whether this is a
+ * Create/Replace/Create or Replace
+ */
+ override def createDeltaTable(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ allTableProperties: util.Map[String, String],
+ writeOptions: Map[String, String],
+ sourceQuery: Option[DataFrame],
+ operation: TableCreationModes.CreationMode
+ ): Table = recordFrameProfile(
+ "DeltaCatalog", "createDeltaTable") {
+ super.createDeltaTable(
+ ident,
+ schema,
+ partitions,
+ allTableProperties,
+ writeOptions,
+ sourceQuery,
+ operation)
+ }
+
+ override def createTable(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table =
+ recordFrameProfile("DeltaCatalog", "createTable") {
+ super.createTable(ident, schema, partitions, properties)
+ }
+
+ override def stageReplace(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): StagedTable =
+ recordFrameProfile("DeltaCatalog", "stageReplace") {
+ super.stageReplace(ident, schema, partitions, properties)
+ }
+
+ override def stageCreateOrReplace(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): StagedTable =
+ recordFrameProfile("DeltaCatalog", "stageCreateOrReplace") {
+ super.stageCreateOrReplace(ident, schema, partitions, properties)
+ }
+
+ override def stageCreate(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): StagedTable =
+ recordFrameProfile("DeltaCatalog", "stageCreate") {
+ super.stageCreate(ident, schema, partitions, properties)
+ }
+
+ /**
+ * A staged Delta table, which creates a HiveMetaStore entry and appends data if this was a
+ * CTAS/RTAS command. We have an ugly way of using this API right now, but it's the best way to
+ * maintain old behavior compatibility between Databricks Runtime and OSS Delta Lake.
+ */
+ protected class GpuStagedDeltaTableV2WithLogging(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String],
+ operation: TableCreationModes.CreationMode)
+ extends GpuStagedDeltaTableV2(ident, schema, partitions, properties, operation) {
+
+ override def commitStagedChanges(): Unit = recordFrameProfile(
+ "DeltaCatalog", "commitStagedChanges") {
+ super.commitStagedChanges()
+ }
+ }
+}
diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/MergeIntoCommandMeta.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/MergeIntoCommandMeta.scala
new file mode 100644
index 00000000000..8e1cd7dd490
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/MergeIntoCommandMeta.scala
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta.delta23x
+
+import com.nvidia.spark.rapids.{DataFromReplacementRule, RapidsConf, RapidsMeta, RunnableCommandMeta}
+import com.nvidia.spark.rapids.delta.RapidsDeltaUtils
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.delta.commands.MergeIntoCommand
+import org.apache.spark.sql.delta.rapids.GpuDeltaLog
+import org.apache.spark.sql.delta.rapids.delta23x.GpuMergeIntoCommand
+import org.apache.spark.sql.execution.command.RunnableCommand
+
+class MergeIntoCommandMeta(
+ mergeCmd: MergeIntoCommand,
+ conf: RapidsConf,
+ parent: Option[RapidsMeta[_, _, _]],
+ rule: DataFromReplacementRule)
+ extends RunnableCommandMeta[MergeIntoCommand](mergeCmd, conf, parent, rule) {
+
+ override def tagSelfForGpu(): Unit = {
+ if (!conf.isDeltaWriteEnabled) {
+ willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " +
+ s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
+ }
+ if (mergeCmd.notMatchedBySourceClauses.nonEmpty) {
+ // https://github.com/NVIDIA/spark-rapids/issues/8415
+ willNotWorkOnGpu("notMatchedBySourceClauses not supported on GPU")
+ }
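+ // migratedSchema is the target schema after any automatic schema evolution performed by the
+ // merge, so prefer it over the pre-merge target schema when tagging the write.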
+ val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema)
+ val deltaLog = mergeCmd.targetFileIndex.deltaLog
+ RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), Map.empty,
+ SparkSession.active)
+ }
+
+ override def convertToGpu(): RunnableCommand = {
+ GpuMergeIntoCommand(
+ mergeCmd.source,
+ mergeCmd.target,
+ new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf),
+ mergeCmd.condition,
+ mergeCmd.matchedClauses,
+ mergeCmd.notMatchedClauses,
+ mergeCmd.notMatchedBySourceClauses,
+ mergeCmd.migratedSchema)(conf)
+ }
+}
diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/UpdateCommandMeta.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/UpdateCommandMeta.scala
new file mode 100644
index 00000000000..1f1bcd93137
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/UpdateCommandMeta.scala
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta.delta23x
+
+import com.nvidia.spark.rapids.{DataFromReplacementRule, RapidsConf, RapidsMeta, RunnableCommandMeta}
+import com.nvidia.spark.rapids.delta.RapidsDeltaUtils
+
+import org.apache.spark.sql.delta.commands.UpdateCommand
+import org.apache.spark.sql.delta.rapids.GpuDeltaLog
+import org.apache.spark.sql.delta.rapids.delta23x.GpuUpdateCommand
+import org.apache.spark.sql.execution.command.RunnableCommand
+
+class UpdateCommandMeta(
+ updateCmd: UpdateCommand,
+ conf: RapidsConf,
+ parent: Option[RapidsMeta[_, _, _]],
+ rule: DataFromReplacementRule)
+ extends RunnableCommandMeta[UpdateCommand](updateCmd, conf, parent, rule) {
+
+ override def tagSelfForGpu(): Unit = {
+ if (!conf.isDeltaWriteEnabled) {
+ willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " +
+ s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
+ }
+ RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema,
+ Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark)
+ }
+
+ override def convertToGpu(): RunnableCommand = {
+ GpuUpdateCommand(
+ new GpuDeltaLog(updateCmd.tahoeFileIndex.deltaLog, conf),
+ updateCmd.tahoeFileIndex,
+ updateCmd.target,
+ updateCmd.updateExpressions,
+ updateCmd.condition
+ )
+ }
+}
diff --git a/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/ScalaStack.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala
similarity index 73%
rename from sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/ScalaStack.scala
rename to delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala
index b560de5463e..f4f9836e9f9 100644
--- a/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/ScalaStack.scala
+++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,9 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package com.nvidia.spark.rapids.delta.shims
 
-package com.nvidia.spark.rapids
+import org.apache.spark.sql.delta.stats.UsesMetadataFields
 
-import scala.collection.mutable.ArrayStack
-
-class ScalaStack[T] extends ArrayStack[T]
+trait ShimUsesMetadataFields extends UsesMetadataFields
\ No newline at end of file
diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/Delta23xRuntimeShim.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/Delta23xRuntimeShim.scala
new file mode 100644
index 00000000000..c12982d54fe
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/Delta23xRuntimeShim.scala
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.rapids.delta23x
+
+import com.nvidia.spark.rapids.RapidsConf
+import com.nvidia.spark.rapids.delta.DeltaProvider
+import com.nvidia.spark.rapids.delta.delta23x.{Delta23xProvider, GpuDeltaCatalog}
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.catalog.StagingTableCatalog
+import org.apache.spark.sql.delta.{DeltaLog, DeltaUDF, Snapshot}
+import org.apache.spark.sql.delta.catalog.DeltaCatalog
+import org.apache.spark.sql.delta.rapids.{DeltaRuntimeShim, GpuOptimisticTransactionBase}
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.expressions.UserDefinedFunction
+import org.apache.spark.util.Clock
+
+class Delta23xRuntimeShim extends DeltaRuntimeShim {
+ override def getDeltaProvider: DeltaProvider = Delta23xProvider
+
+ override def startTransaction(
+ log: DeltaLog,
+ conf: RapidsConf,
+ clock: Clock): GpuOptimisticTransactionBase = {
+ new GpuOptimisticTransaction(log, conf)(clock)
+ }
+
+ override def stringFromStringUdf(f: String => String): UserDefinedFunction = {
+ DeltaUDF.stringFromString(f)
+ }
+
+ override def unsafeVolatileSnapshotFromLog(deltaLog: DeltaLog): Snapshot = {
+ deltaLog.unsafeVolatileSnapshot
+ }
+
+ override def fileFormatFromLog(deltaLog: DeltaLog): FileFormat =
+ deltaLog.fileFormat(deltaLog.unsafeVolatileMetadata)
+
+ override def getTightBoundColumnOnFileInitDisabled(spark: SparkSession): Boolean = false
+
+ override def getGpuDeltaCatalog(
+ cpuCatalog: DeltaCatalog,
+ rapidsConf: RapidsConf): StagingTableCatalog = {
+ new GpuDeltaCatalog(cpuCatalog, rapidsConf)
+ }
+}
diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuCreateDeltaTableCommand.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuCreateDeltaTableCommand.scala
new file mode 100644
index 00000000000..83b6408d647
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuCreateDeltaTableCommand.scala
@@ -0,0 +1,480 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from CreateDeltaTableCommand.scala in the
+ * Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.rapids.delta23x
+
+import com.nvidia.spark.rapids.RapidsConf
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.delta._
+import org.apache.spark.sql.delta.actions.{Action, Metadata, Protocol}
+import org.apache.spark.sql.delta.commands.{DeltaCommand, TableCreationModes, WriteIntoDelta}
+import org.apache.spark.sql.delta.metering.DeltaLogging
+import org.apache.spark.sql.delta.rapids.GpuDeltaLog
+import org.apache.spark.sql.delta.schema.SchemaUtils
+import org.apache.spark.sql.delta.sources.DeltaSQLConf
+import org.apache.spark.sql.execution.command.{LeafRunnableCommand, RunnableCommand}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Single entry point for all write or declaration operations for Delta tables accessed through
+ * the table name.
+ *
+ * @param table The table identifier for the Delta table
+ * @param existingTableOpt The existing table for the same identifier if exists
+ * @param mode The save mode when writing data. Relevant when the query is empty or set to Ignore
+ * with `CREATE TABLE IF NOT EXISTS`.
+ * @param query The query to commit into the Delta table if it exists. This can come from
+ * - CTAS
+ * - saveAsTable
+ * @param protocol This is used to create a table with a specific protocol version
+ */
+case class GpuCreateDeltaTableCommand(
+ table: CatalogTable,
+ existingTableOpt: Option[CatalogTable],
+ mode: SaveMode,
+ query: Option[LogicalPlan],
+ operation: TableCreationModes.CreationMode = TableCreationModes.Create,
+ tableByPath: Boolean = false,
+ override val output: Seq[Attribute] = Nil,
+ protocol: Option[Protocol] = None)(@transient rapidsConf: RapidsConf)
+ extends LeafRunnableCommand
+ with DeltaCommand
+ with DeltaLogging {
+
+ override def otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf)
+
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val table = this.table
+
+ assert(table.tableType != CatalogTableType.VIEW)
+ assert(table.identifier.database.isDefined, "Database should've been fixed at analysis")
+ // There is a subtle race condition here, where the table can be created by someone else
+ // while this command is running. Nothing we can do about that though :(
+ val tableExists = existingTableOpt.isDefined
+ if (mode == SaveMode.Ignore && tableExists) {
+ // Early exit on ignore
+ return Nil
+ } else if (mode == SaveMode.ErrorIfExists && tableExists) {
+ throw DeltaErrors.tableAlreadyExists(table)
+ }
+
+ val tableWithLocation = if (tableExists) {
+ val existingTable = existingTableOpt.get
+ table.storage.locationUri match {
+ case Some(location) if location.getPath != existingTable.location.getPath =>
+ throw DeltaErrors.tableLocationMismatch(table, existingTable)
+ case _ =>
+ }
+ table.copy(
+ storage = existingTable.storage,
+ tableType = existingTable.tableType)
+ } else if (table.storage.locationUri.isEmpty) {
+ // We are defining a new managed table
+ assert(table.tableType == CatalogTableType.MANAGED)
+ val loc = sparkSession.sessionState.catalog.defaultTablePath(table.identifier)
+ table.copy(storage = table.storage.copy(locationUri = Some(loc)))
+ } else {
+ // 1. We are defining a new external table
+ // 2. It's a managed table which already has the location populated. This can happen in DSV2
+ // CTAS flow.
+ table
+ }
+
+ val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED
+ val tableLocation = new Path(tableWithLocation.location)
+ val gpuDeltaLog = GpuDeltaLog.forTable(sparkSession, tableLocation, rapidsConf)
+ val hadoopConf = gpuDeltaLog.deltaLog.newDeltaHadoopConf()
+ val fs = tableLocation.getFileSystem(hadoopConf)
+ val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf)
+ var result: Seq[Row] = Nil
+
+ recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.ddl.createTable") {
+ val txn = gpuDeltaLog.startTransaction()
+ val opStartTs = System.currentTimeMillis()
+ if (query.isDefined) {
+ // If the mode is Ignore or ErrorIfExists, the table must not exist, or we would return
+ // earlier. And the data should not exist either, to match the behavior of
+ // Ignore/ErrorIfExists mode. This means the table path must either not exist or be empty.
+ if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) {
+ assert(!tableExists)
+ // We may have failed a previous write. The retry should still succeed even if we have
+ // garbage data
+ if (txn.readVersion > -1 || !fs.exists(gpuDeltaLog.deltaLog.logPath)) {
+ assertPathEmpty(hadoopConf, tableWithLocation)
+ }
+ }
+
+ // Execute the write command for `deltaWriter` by
+ // - replacing the metadata for the new target table when the DataFrameWriterV2 writer
+ // issues a REPLACE or CREATE_OR_REPLACE command,
+ // - running the write procedure of the DataFrameWriter command and returning the
+ // newly created actions,
+ // - returning the Delta Operation type of this DataFrameWriter
+ def doDeltaWrite(
+ deltaWriter: WriteIntoDelta,
+ schema: StructType): (Seq[Action], DeltaOperations.Operation) = {
+ // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that
+ // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1.
+ if (!isV1Writer) {
+ replaceMetadataIfNecessary(
+ txn, tableWithLocation, options, schema)
+ }
+ val actions = deltaWriter.write(txn, sparkSession)
+ val op = getOperation(txn.metadata, isManagedTable, Some(options))
+ (actions, op)
+ }
+
+ // We are either appending/overwriting with saveAsTable or creating a new table with CTAS or
+ // we are creating a table as part of a RunnableCommand
+ query.get match {
+ case deltaWriter: WriteIntoDelta =>
+ if (!hasBeenExecuted(txn, sparkSession, Some(options))) {
+ val (actions, op) = doDeltaWrite(deltaWriter, deltaWriter.data.schema.asNullable)
+ txn.commit(actions, op)
+ }
+ case cmd: RunnableCommand =>
+ result = cmd.run(sparkSession)
+ case other =>
+ // When using V1 APIs, the `other` plan is not yet optimized, therefore, it is safe
+ // to once again go through analysis
+ val data = Dataset.ofRows(sparkSession, other)
+ val deltaWriter = WriteIntoDelta(
+ deltaLog = gpuDeltaLog.deltaLog,
+ mode = mode,
+ options,
+ partitionColumns = table.partitionColumnNames,
+ configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull),
+ data = data)
+ if (!hasBeenExecuted(txn, sparkSession, Some(options))) {
+ val (actions, op) = doDeltaWrite(deltaWriter, other.schema.asNullable)
+ txn.commit(actions, op)
+ }
+ }
+ } else {
+ def createTransactionLogOrVerify(): Unit = {
+ if (isManagedTable) {
+ // When creating a managed table, the table path must either not exist or be empty;
+ // otherwise users would be surprised to see the data, or to see the data directory
+ // being dropped after the table is dropped.
+ assertPathEmpty(hadoopConf, tableWithLocation)
+ }
+
+ // This is either a new table, or we never defined the schema of the table. While it is
+ // unexpected for `txn.metadata.schema` to be empty when txn.readVersion >= 0, we still
+ // guard against it in case of checkpoint corruption bugs.
+ val noExistingMetadata = txn.readVersion == -1 || txn.metadata.schema.isEmpty
+ if (noExistingMetadata) {
+ assertTableSchemaDefined(fs, tableLocation, tableWithLocation, txn, sparkSession)
+ assertPathEmpty(hadoopConf, tableWithLocation)
+ // This is a user-provided schema.
+ // It doesn't come from a query, so follow nullability invariants.
+ val newMetadata = getProvidedMetadata(tableWithLocation, table.schema.json)
+ txn.updateMetadataForNewTable(newMetadata)
+ protocol.foreach { protocol =>
+ txn.updateProtocol(protocol)
+ }
+ val op = getOperation(newMetadata, isManagedTable, None)
+ txn.commit(Nil, op)
+ } else {
+ verifyTableMetadata(txn, tableWithLocation)
+ }
+ }
+ // We are defining a table using the Create or Replace Table statements.
+ operation match {
+ case TableCreationModes.Create =>
+ require(!tableExists, "Can't recreate a table when it exists")
+ createTransactionLogOrVerify()
+
+ case TableCreationModes.CreateOrReplace if !tableExists =>
+ // If the table doesn't exist, CREATE OR REPLACE must provide a schema
+ if (tableWithLocation.schema.isEmpty) {
+ throw DeltaErrors.schemaNotProvidedException
+ }
+ createTransactionLogOrVerify()
+ case _ =>
+ // When the operation is a REPLACE or CREATE OR REPLACE, then the schema shouldn't be
+ // empty, since we'll use the entry to replace the schema
+ if (tableWithLocation.schema.isEmpty) {
+ throw DeltaErrors.schemaNotProvidedException
+ }
+ // We need to replace
+ replaceMetadataIfNecessary(txn, tableWithLocation, options, tableWithLocation.schema)
+ // Truncate the table
+ val operationTimestamp = System.currentTimeMillis()
+ val removes = txn.filterFiles().map(_.removeWithTimestamp(operationTimestamp))
+ val op = getOperation(txn.metadata, isManagedTable, None)
+ txn.commit(removes, op)
+ }
+ }
+
+ // We would have failed earlier on if we couldn't ignore the existence of the table.
+ // In addition, we might just be using saveAsTable to append to the table, so ignore the
+ // creation if it already exists.
+ // Note that someone may have dropped and recreated the table in a separate location in the
+ // meantime... Unfortunately we can't do anything there at the moment, because Hive sucks.
+ logInfo(s"Table is path-based table: $tableByPath. Update catalog with mode: $operation")
+ updateCatalog(
+ sparkSession,
+ tableWithLocation,
+ gpuDeltaLog.deltaLog.update(checkIfUpdatedSinceTs = Some(opStartTs)),
+ txn)
+
+ result
+ }
+ }
+
+ private def getProvidedMetadata(table: CatalogTable, schemaString: String): Metadata = {
+ Metadata(
+ description = table.comment.orNull,
+ schemaString = schemaString,
+ partitionColumns = table.partitionColumnNames,
+ configuration = table.properties,
+ createdTime = Some(System.currentTimeMillis()))
+ }
+
+ private def assertPathEmpty(
+ hadoopConf: Configuration,
+ tableWithLocation: CatalogTable): Unit = {
+ val path = new Path(tableWithLocation.location)
+ val fs = path.getFileSystem(hadoopConf)
+ // Verify that the table location associated with CREATE TABLE doesn't have any data. Note that
+ // we intentionally diverge here from the behavior of regular datasource tables (which
+ // silently overwrite any previous data)
+ if (fs.exists(path) && fs.listStatus(path).nonEmpty) {
+ throw DeltaErrors.createTableWithNonEmptyLocation(
+ tableWithLocation.identifier.toString,
+ tableWithLocation.location.toString)
+ }
+ }
+
+ private def assertTableSchemaDefined(
+ fs: FileSystem,
+ path: Path,
+ table: CatalogTable,
+ txn: OptimisticTransaction,
+ sparkSession: SparkSession): Unit = {
+ // If we allow creating an empty schema table and indeed the table is new, we just need to
+ // make sure:
+ // 1. txn.readVersion == -1 to read a new table
+ // 2. for external tables: the path must either not exist or be completely empty
+ val allowCreatingTableWithEmptySchema = sparkSession.sessionState
+ .conf.getConf(DeltaSQLConf.DELTA_ALLOW_CREATE_EMPTY_SCHEMA_TABLE) && txn.readVersion == -1
+
+ // Users did not specify the schema. We expect the schema exists in Delta.
+ if (table.schema.isEmpty) {
+ if (table.tableType == CatalogTableType.EXTERNAL) {
+ if (fs.exists(path) && fs.listStatus(path).nonEmpty) {
+ throw DeltaErrors.createExternalTableWithoutLogException(
+ path, table.identifier.quotedString, sparkSession)
+ } else {
+ if (allowCreatingTableWithEmptySchema) return
+ throw DeltaErrors.createExternalTableWithoutSchemaException(
+ path, table.identifier.quotedString, sparkSession)
+ }
+ } else {
+ if (allowCreatingTableWithEmptySchema) return
+ throw DeltaErrors.createManagedTableWithoutSchemaException(
+ table.identifier.quotedString, sparkSession)
+ }
+ }
+ }
+
+ /**
+ * Verify against our transaction metadata that the user specified the right metadata for the
+ * table.
+ */
+ private def verifyTableMetadata(
+ txn: OptimisticTransaction,
+ tableDesc: CatalogTable): Unit = {
+ val existingMetadata = txn.metadata
+ val path = new Path(tableDesc.location)
+
+ // The delta log already exists. If they give any configuration, we'll make sure it all matches.
+ // Otherwise we'll just go with the metadata already present in the log.
+ // The schema compatibility checks will be made in `WriteIntoDelta` for CreateTable
+ // with a query
+ if (txn.readVersion > -1) {
+ if (tableDesc.schema.nonEmpty) {
+ // We check exact alignment on create table if everything is provided
+ // However, if in column mapping mode, we can safely ignore the related metadata fields in
+ // existing metadata because new table desc will not have related metadata assigned yet
+ val differences = SchemaUtils.reportDifferences(
+ DeltaColumnMapping.dropColumnMappingMetadata(existingMetadata.schema),
+ tableDesc.schema)
+ if (differences.nonEmpty) {
+ throw DeltaErrors.createTableWithDifferentSchemaException(
+ path, tableDesc.schema, existingMetadata.schema, differences)
+ }
+ }
+
+ // If the schema is specified, we must make sure the partitioning matches, even if the
+ // partitioning is not specified.
+ if (tableDesc.schema.nonEmpty &&
+ tableDesc.partitionColumnNames != existingMetadata.partitionColumns) {
+ throw DeltaErrors.createTableWithDifferentPartitioningException(
+ path, tableDesc.partitionColumnNames, existingMetadata.partitionColumns)
+ }
+
+ if (tableDesc.properties.nonEmpty && tableDesc.properties != existingMetadata.configuration) {
+ throw DeltaErrors.createTableWithDifferentPropertiesException(
+ path, tableDesc.properties, existingMetadata.configuration)
+ }
+ }
+ }
+
+ /**
+ * Based on the table creation operation, and parameters, we can resolve to different operations.
+ * A lot of this is needed for legacy reasons in Databricks Runtime.
+ * @param metadata The table metadata, which we are creating or replacing
+ * @param isManagedTable Whether we are creating or replacing a managed table
+ * @param options Write options, if this was a CTAS/RTAS
+ */
+ private def getOperation(
+ metadata: Metadata,
+ isManagedTable: Boolean,
+ options: Option[DeltaOptions]): DeltaOperations.Operation = operation match {
+ // This is legacy saveAsTable behavior in Databricks Runtime
+ case TableCreationModes.Create if existingTableOpt.isDefined && query.isDefined =>
+ DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere,
+ options.flatMap(_.userMetadata))
+
+ // DataSourceV2 table creation
+ // CREATE TABLE (non-DataFrameWriter API) doesn't have options syntax
+ // (userMetadata uses SQLConf in this case)
+ case TableCreationModes.Create =>
+ DeltaOperations.CreateTable(metadata, isManagedTable, query.isDefined)
+
+ // DataSourceV2 table replace
+ // REPLACE TABLE (non-DataFrameWriter API) doesn't have options syntax
+ // (userMetadata uses SQLConf in this case)
+ case TableCreationModes.Replace =>
+ DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = false, query.isDefined)
+
+ // Legacy saveAsTable with Overwrite mode
+ case TableCreationModes.CreateOrReplace if options.exists(_.replaceWhere.isDefined) =>
+ DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere,
+ options.flatMap(_.userMetadata))
+
+ // New DataSourceV2 saveAsTable with overwrite mode behavior
+ case TableCreationModes.CreateOrReplace =>
+ DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = true, query.isDefined,
+ options.flatMap(_.userMetadata))
+ }
+
+ /**
+ * Similar to getOperation, here we disambiguate the catalog alterations we need to do based
+ * on the table operation, and whether we have reached here through legacy code or DataSourceV2
+ * code paths.
+ */
+ private def updateCatalog(
+ spark: SparkSession,
+ table: CatalogTable,
+ snapshot: Snapshot,
+ txn: OptimisticTransaction): Unit = {
+ val cleaned = cleanupTableDefinition(spark, table, snapshot)
+ operation match {
+ case _ if tableByPath => // do nothing with the metastore if this is by path
+ case TableCreationModes.Create =>
+ spark.sessionState.catalog.createTable(
+ cleaned,
+ ignoreIfExists = existingTableOpt.isDefined,
+ validateLocation = false)
+ case TableCreationModes.Replace | TableCreationModes.CreateOrReplace
+ if existingTableOpt.isDefined =>
+ spark.sessionState.catalog.alterTable(table)
+ case TableCreationModes.Replace =>
+ val ident = Identifier.of(table.identifier.database.toArray, table.identifier.table)
+ throw DeltaErrors.cannotReplaceMissingTableException(ident)
+ case TableCreationModes.CreateOrReplace =>
+ spark.sessionState.catalog.createTable(
+ cleaned,
+ ignoreIfExists = false,
+ validateLocation = false)
+ }
+ }
+
+ /** Clean up the information we pass on to store in the catalog. */
+ private def cleanupTableDefinition(spark: SparkSession, table: CatalogTable, snapshot: Snapshot)
+ : CatalogTable = {
+ // These actually have no effect on the usability of Delta, but feature flagging legacy
+ // behavior for now
+ val storageProps = if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) {
+ // Legacy behavior
+ table.storage
+ } else {
+ table.storage.copy(properties = Map.empty)
+ }
+
+ table.copy(
+ schema = new StructType(),
+ properties = Map.empty,
+ partitionColumnNames = Nil,
+ // Remove write specific options when updating the catalog
+ storage = storageProps,
+ tracksPartitionsInCatalog = true)
+ }
+
+ /**
+ * With DataFrameWriterV2, methods like `replace()` or `createOrReplace()` mean that the
+ * metadata of the table should be replaced. If overwriteSchema=false is provided with these
+ * methods, then we will verify that the metadata match exactly.
+ */
+ private def replaceMetadataIfNecessary(
+ txn: OptimisticTransaction,
+ tableDesc: CatalogTable,
+ options: DeltaOptions,
+ schema: StructType): Unit = {
+ val isReplace = (operation == TableCreationModes.CreateOrReplace ||
+ operation == TableCreationModes.Replace)
+ // If a user explicitly specifies not to overwrite the schema, during a replace, we should
+ // tell them that it's not supported
+ val dontOverwriteSchema = options.options.contains(DeltaOptions.OVERWRITE_SCHEMA_OPTION) &&
+ !options.canOverwriteSchema
+ if (isReplace && dontOverwriteSchema) {
+ throw DeltaErrors.illegalUsageException(DeltaOptions.OVERWRITE_SCHEMA_OPTION, "replacing")
+ }
+ if (txn.readVersion > -1L && isReplace && !dontOverwriteSchema) {
+ // When a table already exists, and we're using the DataFrameWriterV2 API to replace
+ // or createOrReplace a table, we blindly overwrite the metadata.
+ txn.updateMetadataForNewTable(getProvidedMetadata(table, schema.json))
+ }
+ }
+
+ /**
+ * Horrible hack to differentiate between DataFrameWriterV1 and V2 so that we can decide
+ * what to do with table metadata. In DataFrameWriterV1, mode("overwrite").saveAsTable,
+ * behaves as a CreateOrReplace table, but we have asked for "overwriteSchema" as an
+ * explicit option to overwrite partitioning or schema information. With DataFrameWriterV2,
+ * the behavior asked for by the user is clearer: .createOrReplace(), which means that we
+ * should overwrite schema and/or partitioning. Therefore we have this hack.
+ */
+ private def isV1Writer: Boolean = {
+ Thread.currentThread().getStackTrace.exists(_.toString.contains(
+ classOf[DataFrameWriter[_]].getCanonicalName + "."))
+ }
+}
diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuDeleteCommand.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuDeleteCommand.scala
new file mode 100644
index 00000000000..1ddd00bc046
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuDeleteCommand.scala
@@ -0,0 +1,381 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from DeleteCommand.scala
+ * in the Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.rapids.delta23x
+
+import com.nvidia.spark.rapids.delta.GpuDeltaMetricUpdateUDF
+
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, Not}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.delta.{DeltaConfigs, DeltaLog, DeltaOperations, DeltaTableUtils, DeltaUDF, OptimisticTransaction}
+import org.apache.spark.sql.delta.actions.{Action, AddCDCFile, FileAction}
+import org.apache.spark.sql.delta.commands.{DeleteCommandMetrics, DeleteMetric, DeletionVectorUtils, DeltaCommand}
+import org.apache.spark.sql.delta.commands.DeleteCommand.{rewritingFilesMsg, FINDING_TOUCHED_FILES_MSG}
+import org.apache.spark.sql.delta.commands.MergeIntoCommand.totalBytesAndDistinctPartitionValues
+import org.apache.spark.sql.delta.files.TahoeBatchFileIndex
+import org.apache.spark.sql.delta.rapids.GpuDeltaLog
+import org.apache.spark.sql.delta.sources.DeltaSQLConf
+import org.apache.spark.sql.delta.util.Utils
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
+import org.apache.spark.sql.functions.input_file_name
+import org.apache.spark.sql.types.LongType
+
+/**
+ * GPU version of Delta Lake DeleteCommand.
+ *
+ * Performs a Delete based on the search condition
+ *
+ * Algorithm:
+ * 1) Scan all the files and determine which files have
+ * the rows that need to be deleted.
+ * 2) Traverse the affected files and rebuild the touched files.
+ * 3) Use the Delta protocol to atomically write the remaining rows to new files and remove
+ * the affected files that are identified in step 1.
+ */
+case class GpuDeleteCommand(
+ gpuDeltaLog: GpuDeltaLog,
+ target: LogicalPlan,
+ condition: Option[Expression])
+ extends LeafRunnableCommand with DeltaCommand with DeleteCommandMetrics {
+
+ override def innerChildren: Seq[QueryPlan[_]] = Seq(target)
+
+ override val output: Seq[Attribute] = Seq(AttributeReference("num_affected_rows", LongType)())
+
+ override lazy val metrics = createMetrics
+
+ final override def run(sparkSession: SparkSession): Seq[Row] = {
+ val deltaLog = gpuDeltaLog.deltaLog
+ recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.dml.delete") {
+ gpuDeltaLog.withNewTransaction { txn =>
+ DeltaLog.assertRemovable(txn.snapshot)
+ if (hasBeenExecuted(txn, sparkSession)) {
+ sendDriverMetrics(sparkSession, metrics)
+ return Seq.empty
+ }
+
+ val deleteActions = performDelete(sparkSession, deltaLog, txn)
+ txn.commitIfNeeded(deleteActions, DeltaOperations.Delete(condition.map(_.sql).toSeq))
+ }
+ // Re-cache all cached plans (including this relation itself, if it's cached) that refer to
+ // this data source relation.
+ sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target)
+ }
+
+ // Adjust for deletes at partition boundaries. A delete at partition boundaries is a
+ // metadata-only operation, so we don't actually have any information about how many rows
+ // were deleted. While this info may exist in the file statistics, it's not guaranteed that
+ // we have these statistics. To avoid any performance regressions, we currently just return
+ // -1 in such cases.
+ if (metrics("numRemovedFiles").value > 0 && metrics("numDeletedRows").value == 0) {
+ Seq(Row(-1L))
+ } else {
+ Seq(Row(metrics("numDeletedRows").value))
+ }
+ }
+
+ def performDelete(
+ sparkSession: SparkSession,
+ deltaLog: DeltaLog,
+ txn: OptimisticTransaction): Seq[Action] = {
+ import org.apache.spark.sql.delta.implicits._
+
+ var numRemovedFiles: Long = 0
+ var numAddedFiles: Long = 0
+ var numAddedChangeFiles: Long = 0
+ var scanTimeMs: Long = 0
+ var rewriteTimeMs: Long = 0
+ var numAddedBytes: Long = 0
+ var changeFileBytes: Long = 0
+ var numRemovedBytes: Long = 0
+ var numFilesBeforeSkipping: Long = 0
+ var numBytesBeforeSkipping: Long = 0
+ var numFilesAfterSkipping: Long = 0
+ var numBytesAfterSkipping: Long = 0
+ var numPartitionsAfterSkipping: Option[Long] = None
+ var numPartitionsRemovedFrom: Option[Long] = None
+ var numPartitionsAddedTo: Option[Long] = None
+ var numDeletedRows: Option[Long] = None
+ var numCopiedRows: Option[Long] = None
+
+ val startTime = System.nanoTime()
+ val numFilesTotal = txn.snapshot.numOfFiles
+
+ val deleteActions: Seq[Action] = condition match {
+ case None =>
+ // Case 1: Delete the whole table if the condition is true
+ val reportRowLevelMetrics = conf.getConf(DeltaSQLConf.DELTA_DML_METRICS_FROM_METADATA)
+ val allFiles = txn.filterFiles(Nil, keepNumRecords = reportRowLevelMetrics)
+
+ numRemovedFiles = allFiles.size
+ scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+ val (numBytes, numPartitions) = totalBytesAndDistinctPartitionValues(allFiles)
+ numRemovedBytes = numBytes
+ numFilesBeforeSkipping = numRemovedFiles
+ numBytesBeforeSkipping = numBytes
+ numFilesAfterSkipping = numRemovedFiles
+ numBytesAfterSkipping = numBytes
+ numDeletedRows = getDeletedRowsFromAddFilesAndUpdateMetrics(allFiles)
+
+ if (txn.metadata.partitionColumns.nonEmpty) {
+ numPartitionsAfterSkipping = Some(numPartitions)
+ numPartitionsRemovedFrom = Some(numPartitions)
+ numPartitionsAddedTo = Some(0)
+ }
+ val operationTimestamp = System.currentTimeMillis()
+ allFiles.map(_.removeWithTimestamp(operationTimestamp))
+ case Some(cond) =>
+ val (metadataPredicates, otherPredicates) =
+ DeltaTableUtils.splitMetadataAndDataPredicates(
+ cond, txn.metadata.partitionColumns, sparkSession)
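+ // Illustrative sketch with hypothetical column names: for a table partitioned by `date`,
+ // a condition such as `date = '2023-01-01' AND value > 100` splits into
+ // metadataPredicates = [date = '2023-01-01'] and otherPredicates = [value > 100].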
+
+ numFilesBeforeSkipping = txn.snapshot.numOfFiles
+ numBytesBeforeSkipping = txn.snapshot.sizeInBytes
+
+ if (otherPredicates.isEmpty) {
+ // Case 2: The condition can be evaluated using metadata only.
+ // Delete a set of files without the need of scanning any data files.
+ val operationTimestamp = System.currentTimeMillis()
+ val reportRowLevelMetrics = conf.getConf(DeltaSQLConf.DELTA_DML_METRICS_FROM_METADATA)
+ val candidateFiles =
+ txn.filterFiles(metadataPredicates, keepNumRecords = reportRowLevelMetrics)
+
+ scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+ numRemovedFiles = candidateFiles.size
+ numRemovedBytes = candidateFiles.map(_.size).sum
+ numFilesAfterSkipping = candidateFiles.size
+ val (numCandidateBytes, numCandidatePartitions) =
+ totalBytesAndDistinctPartitionValues(candidateFiles)
+ numBytesAfterSkipping = numCandidateBytes
+ numDeletedRows = getDeletedRowsFromAddFilesAndUpdateMetrics(candidateFiles)
+
+ if (txn.metadata.partitionColumns.nonEmpty) {
+ numPartitionsAfterSkipping = Some(numCandidatePartitions)
+ numPartitionsRemovedFrom = Some(numCandidatePartitions)
+ numPartitionsAddedTo = Some(0)
+ }
+ candidateFiles.map(_.removeWithTimestamp(operationTimestamp))
+ } else {
+ // Case 3: Delete the rows based on the condition.
+
+ // Should we write the DVs to represent the deleted rows?
+ val shouldWriteDVs = shouldWritePersistentDeletionVectors(sparkSession, txn)
+
+ val candidateFiles = txn.filterFiles(
+ metadataPredicates ++ otherPredicates,
+ keepNumRecords = shouldWriteDVs)
+ // `candidateFiles` contains the files filtered using statistics and the delete condition.
+ // They may or may not contain any rows that need to be deleted.
+
+ numFilesAfterSkipping = candidateFiles.size
+ val (numCandidateBytes, numCandidatePartitions) =
+ totalBytesAndDistinctPartitionValues(candidateFiles)
+ numBytesAfterSkipping = numCandidateBytes
+ if (txn.metadata.partitionColumns.nonEmpty) {
+ numPartitionsAfterSkipping = Some(numCandidatePartitions)
+ }
+
+ val nameToAddFileMap = generateCandidateFileMap(deltaLog.dataPath, candidateFiles)
+
+ val fileIndex = new TahoeBatchFileIndex(
+ sparkSession, "delete", candidateFiles, deltaLog, deltaLog.dataPath, txn.snapshot)
+ if (shouldWriteDVs) {
+ // this should be unreachable because we fall back to CPU
+ // if deletion vectors are enabled. The tracking issue for adding deletion vector
+ // support is https://github.com/NVIDIA/spark-rapids/issues/8554
+ throw new IllegalStateException("Deletion vectors are not supported on GPU")
+ } else {
+ // Keep everything from the resolved target except a new TahoeFileIndex
+ // that only involves the affected files instead of all files.
+ val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex)
+ val data = Dataset.ofRows(sparkSession, newTarget)
+ val deletedRowCount = metrics("numDeletedRows")
+ val deletedRowUdf = DeltaUDF.boolean {
+ new GpuDeltaMetricUpdateUDF(deletedRowCount)
+ }.asNondeterministic()
+ val filesToRewrite =
+ withStatusCode("DELTA", FINDING_TOUCHED_FILES_MSG) {
+ if (candidateFiles.isEmpty) {
+ Array.empty[String]
+ } else {
+ data.filter(new Column(cond))
+ .select(input_file_name())
+ .filter(deletedRowUdf())
+ .distinct()
+ .as[String]
+ .collect()
+ }
+ }
+
+ numRemovedFiles = filesToRewrite.length
+ scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+ if (filesToRewrite.isEmpty) {
+ // Case 3.1: no row matches and no delete will be triggered
+ if (txn.metadata.partitionColumns.nonEmpty) {
+ numPartitionsRemovedFrom = Some(0)
+ numPartitionsAddedTo = Some(0)
+ }
+ Nil
+ } else {
+ // Case 3.2: some files need to be rewritten to remove the deleted rows
+ // Do the second pass and just read the affected files
+ val baseRelation = buildBaseRelation(
+ sparkSession, txn, "delete", deltaLog.dataPath, filesToRewrite, nameToAddFileMap)
+ // Keep everything from the resolved target except a new TahoeFileIndex
+ // that only involves the affected files instead of all files.
+ val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location)
+ val targetDF = Dataset.ofRows(sparkSession, newTarget)
+ val filterCond = Not(EqualNullSafe(cond, Literal.TrueLiteral))
+ val rewrittenActions = rewriteFiles(txn, targetDF, filterCond, filesToRewrite.length)
+ val (changeFiles, rewrittenFiles) = rewrittenActions
+ .partition(_.isInstanceOf[AddCDCFile])
+ numAddedFiles = rewrittenFiles.size
+ val removedFiles = filesToRewrite.map(f =>
+ getTouchedFile(deltaLog.dataPath, f, nameToAddFileMap))
+ val (removedBytes, removedPartitions) =
+ totalBytesAndDistinctPartitionValues(removedFiles)
+ numRemovedBytes = removedBytes
+ val (rewrittenBytes, rewrittenPartitions) =
+ totalBytesAndDistinctPartitionValues(rewrittenFiles)
+ numAddedBytes = rewrittenBytes
+ if (txn.metadata.partitionColumns.nonEmpty) {
+ numPartitionsRemovedFrom = Some(removedPartitions)
+ numPartitionsAddedTo = Some(rewrittenPartitions)
+ }
+ numAddedChangeFiles = changeFiles.size
+ changeFileBytes = changeFiles.collect { case f: AddCDCFile => f.size }.sum
+ rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs
+ numDeletedRows = Some(metrics("numDeletedRows").value)
+ numCopiedRows =
+ Some(metrics("numTouchedRows").value - metrics("numDeletedRows").value)
+
+ val operationTimestamp = System.currentTimeMillis()
+ removeFilesFromPaths(
+ deltaLog, nameToAddFileMap, filesToRewrite, operationTimestamp) ++ rewrittenActions
+ }
+ }
+ }
+ }
+ metrics("numRemovedFiles").set(numRemovedFiles)
+ metrics("numAddedFiles").set(numAddedFiles)
+ val executionTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+ metrics("executionTimeMs").set(executionTimeMs)
+ metrics("scanTimeMs").set(scanTimeMs)
+ metrics("rewriteTimeMs").set(rewriteTimeMs)
+ metrics("numAddedChangeFiles").set(numAddedChangeFiles)
+ metrics("changeFileBytes").set(changeFileBytes)
+ metrics("numAddedBytes").set(numAddedBytes)
+ metrics("numRemovedBytes").set(numRemovedBytes)
+ metrics("numFilesBeforeSkipping").set(numFilesBeforeSkipping)
+ metrics("numBytesBeforeSkipping").set(numBytesBeforeSkipping)
+ metrics("numFilesAfterSkipping").set(numFilesAfterSkipping)
+ metrics("numBytesAfterSkipping").set(numBytesAfterSkipping)
+ numPartitionsAfterSkipping.foreach(metrics("numPartitionsAfterSkipping").set)
+ numPartitionsAddedTo.foreach(metrics("numPartitionsAddedTo").set)
+ numPartitionsRemovedFrom.foreach(metrics("numPartitionsRemovedFrom").set)
+ numCopiedRows.foreach(metrics("numCopiedRows").set)
+ txn.registerSQLMetrics(sparkSession, metrics)
+ sendDriverMetrics(sparkSession, metrics)
+
+ recordDeltaEvent(
+ deltaLog,
+ "delta.dml.delete.stats",
+ data = DeleteMetric(
+ condition = condition.map(_.sql).getOrElse("true"),
+ numFilesTotal,
+ numFilesAfterSkipping,
+ numAddedFiles,
+ numRemovedFiles,
+ numAddedFiles,
+ numAddedChangeFiles = numAddedChangeFiles,
+ numFilesBeforeSkipping,
+ numBytesBeforeSkipping,
+ numFilesAfterSkipping,
+ numBytesAfterSkipping,
+ numPartitionsAfterSkipping,
+ numPartitionsAddedTo,
+ numPartitionsRemovedFrom,
+ numCopiedRows,
+ numDeletedRows,
+ numAddedBytes,
+ numRemovedBytes,
+ changeFileBytes = changeFileBytes,
+ scanTimeMs,
+ rewriteTimeMs)
+ )
+
+ if (deleteActions.nonEmpty) {
+ createSetTransaction(sparkSession, deltaLog).toSeq ++ deleteActions
+ } else {
+ Seq.empty
+ }
+ }
+
+ /**
+ * Returns the list of `AddFile`s and `AddCDCFile`s that have been re-written.
+ */
+ private def rewriteFiles(
+ txn: OptimisticTransaction,
+ baseData: DataFrame,
+ filterCondition: Expression,
+ numFilesToRewrite: Long): Seq[FileAction] = {
+ val shouldWriteCdc = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(txn.metadata)
+
+ // total number of rows that we have seen, i.e. rows that are either copied or deleted (sum of both).
+ val numTouchedRows = metrics("numTouchedRows")
+ val numTouchedRowsUdf = DeltaUDF.boolean {
+ new GpuDeltaMetricUpdateUDF(numTouchedRows)
+ }.asNondeterministic()
+
+ withStatusCode(
+ "DELTA", rewritingFilesMsg(numFilesToRewrite)) {
+ val dfToWrite = if (shouldWriteCdc) {
+ import org.apache.spark.sql.delta.commands.cdc.CDCReader._
+ // The logic here ends up being surprisingly elegant, with all source rows ending up in
+ // the output. Recall that we flipped the user-provided delete condition earlier, before the
+ // call to `rewriteFiles`. All rows which match this latest `filterCondition` are retained
+ // as table data, while all rows which don't match are removed from the rewritten table data
+ // but do get included in the output as CDC events.
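+ // Illustrative sketch with a hypothetical condition: for DELETE ... WHERE value > 100, the
+ // flipped filterCondition is NOT (value > 100 <=> true), so rows with value > 100 are tagged
+ // CDC_TYPE_DELETE below while every other row is kept as regular table data (CDC_TYPE_NOT_CDC).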
+ baseData
+ .filter(numTouchedRowsUdf())
+ .withColumn(
+ CDC_TYPE_COLUMN_NAME,
+ new Column(If(filterCondition, CDC_TYPE_NOT_CDC, CDC_TYPE_DELETE))
+ )
+ } else {
+ baseData
+ .filter(numTouchedRowsUdf())
+ .filter(new Column(filterCondition))
+ }
+
+ txn.writeFiles(dfToWrite)
+ }
+ }
+
+ def shouldWritePersistentDeletionVectors(
+ spark: SparkSession, txn: OptimisticTransaction): Boolean = {
+ // DELETE with DVs only enabled for tests.
+ Utils.isTesting &&
+ spark.conf.get(DeltaSQLConf.DELETE_USE_PERSISTENT_DELETION_VECTORS) &&
+ DeletionVectorUtils.deletionVectorsWritable(txn.snapshot)
+ }
+}
diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuMergeIntoCommand.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuMergeIntoCommand.scala
new file mode 100644
index 00000000000..c4eef959208
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuMergeIntoCommand.scala
@@ -0,0 +1,1331 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from MergeIntoCommand.scala
+ * in the Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.rapids.delta23x
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize
+import com.nvidia.spark.rapids.{BaseExprMeta, GpuOverrides, RapidsConf}
+import com.nvidia.spark.rapids.delta._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, BasePredicate, Expression, Literal, NamedExpression, PredicateHelper, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.delta._
+import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction}
+import org.apache.spark.sql.delta.commands.DeltaCommand
+import org.apache.spark.sql.delta.commands.merge.MergeIntoMaterializeSource
+import org.apache.spark.sql.delta.rapids.GpuDeltaLog
+import org.apache.spark.sql.delta.schema.{ImplicitMetadataOperation, SchemaUtils}
+import org.apache.spark.sql.delta.sources.DeltaSQLConf
+import org.apache.spark.sql.delta.util.{AnalysisHelper, SetAccumulator}
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DataTypes, LongType, StringType, StructType}
+
+case class GpuMergeDataSizes(
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ rows: Option[Long] = None,
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ files: Option[Long] = None,
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ bytes: Option[Long] = None,
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ partitions: Option[Long] = None)
+
+/**
+ * Represents the state of a single merge clause:
+ * - merge clause's (optional) predicate
+ * - action type (insert, update, delete)
+ * - action's expressions
+ */
+case class GpuMergeClauseStats(
+ condition: Option[String],
+ actionType: String,
+ actionExpr: Seq[String])
+
+object GpuMergeClauseStats {
+ def apply(mergeClause: DeltaMergeIntoClause): GpuMergeClauseStats = {
+ GpuMergeClauseStats(
+ condition = mergeClause.condition.map(_.sql),
+ mergeClause.clauseType.toLowerCase(),
+ actionExpr = mergeClause.actions.map(_.sql))
+ }
+}
+
+/** State for a GPU merge operation */
+case class GpuMergeStats(
+ // Merge condition expression
+ conditionExpr: String,
+
+ // Expressions used in old MERGE stats, now always Null
+ updateConditionExpr: String,
+ updateExprs: Seq[String],
+ insertConditionExpr: String,
+ insertExprs: Seq[String],
+ deleteConditionExpr: String,
+
+ // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED/NOT MATCHED BY SOURCE
+ matchedStats: Seq[GpuMergeClauseStats],
+ notMatchedStats: Seq[GpuMergeClauseStats],
+ notMatchedBySourceStats: Seq[GpuMergeClauseStats],
+
+ // Timings
+ executionTimeMs: Long,
+ scanTimeMs: Long,
+ rewriteTimeMs: Long,
+
+ // Data sizes of source and target at different stages of processing
+ source: GpuMergeDataSizes,
+ targetBeforeSkipping: GpuMergeDataSizes,
+ targetAfterSkipping: GpuMergeDataSizes,
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ sourceRowsInSecondScan: Option[Long],
+
+ // Data change sizes
+ targetFilesRemoved: Long,
+ targetFilesAdded: Long,
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetChangeFilesAdded: Option[Long],
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetChangeFileBytes: Option[Long],
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetBytesRemoved: Option[Long],
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetBytesAdded: Option[Long],
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetPartitionsRemovedFrom: Option[Long],
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetPartitionsAddedTo: Option[Long],
+ targetRowsCopied: Long,
+ targetRowsUpdated: Long,
+ targetRowsMatchedUpdated: Long,
+ targetRowsNotMatchedBySourceUpdated: Long,
+ targetRowsInserted: Long,
+ targetRowsDeleted: Long,
+ targetRowsMatchedDeleted: Long,
+ targetRowsNotMatchedBySourceDeleted: Long,
+
+ // MergeMaterializeSource stats
+ materializeSourceReason: Option[String] = None,
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ materializeSourceAttempts: Option[Long] = None
+)
+
+object GpuMergeStats {
+
+ def fromMergeSQLMetrics(
+ metrics: Map[String, SQLMetric],
+ condition: Expression,
+ matchedClauses: Seq[DeltaMergeIntoMatchedClause],
+ notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause],
+ notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause],
+ isPartitioned: Boolean): GpuMergeStats = {
+
+ def metricValueIfPartitioned(metricName: String): Option[Long] = {
+ if (isPartitioned) Some(metrics(metricName).value) else None
+ }
+
+ GpuMergeStats(
+ // Merge condition expression
+ conditionExpr = condition.sql,
+
+ // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED/
+ // NOT MATCHED BY SOURCE
+ matchedStats = matchedClauses.map(GpuMergeClauseStats(_)),
+ notMatchedStats = notMatchedClauses.map(GpuMergeClauseStats(_)),
+ notMatchedBySourceStats = notMatchedBySourceClauses.map(GpuMergeClauseStats(_)),
+
+ // Timings
+ executionTimeMs = metrics("executionTimeMs").value,
+ scanTimeMs = metrics("scanTimeMs").value,
+ rewriteTimeMs = metrics("rewriteTimeMs").value,
+
+ // Data sizes of source and target at different stages of processing
+ source = GpuMergeDataSizes(rows = Some(metrics("numSourceRows").value)),
+ targetBeforeSkipping =
+ GpuMergeDataSizes(
+ files = Some(metrics("numTargetFilesBeforeSkipping").value),
+ bytes = Some(metrics("numTargetBytesBeforeSkipping").value)),
+ targetAfterSkipping =
+ GpuMergeDataSizes(
+ files = Some(metrics("numTargetFilesAfterSkipping").value),
+ bytes = Some(metrics("numTargetBytesAfterSkipping").value),
+ partitions = metricValueIfPartitioned("numTargetPartitionsAfterSkipping")),
+ sourceRowsInSecondScan =
+ metrics.get("numSourceRowsInSecondScan").map(_.value).filter(_ >= 0),
+
+ // Data change sizes
+ targetFilesAdded = metrics("numTargetFilesAdded").value,
+ targetChangeFilesAdded = metrics.get("numTargetChangeFilesAdded").map(_.value),
+ targetChangeFileBytes = metrics.get("numTargetChangeFileBytes").map(_.value),
+ targetFilesRemoved = metrics("numTargetFilesRemoved").value,
+ targetBytesAdded = Some(metrics("numTargetBytesAdded").value),
+ targetBytesRemoved = Some(metrics("numTargetBytesRemoved").value),
+ targetPartitionsRemovedFrom = metricValueIfPartitioned("numTargetPartitionsRemovedFrom"),
+ targetPartitionsAddedTo = metricValueIfPartitioned("numTargetPartitionsAddedTo"),
+ targetRowsCopied = metrics("numTargetRowsCopied").value,
+ targetRowsUpdated = metrics("numTargetRowsUpdated").value,
+ targetRowsMatchedUpdated = metrics("numTargetRowsMatchedUpdated").value,
+ targetRowsNotMatchedBySourceUpdated = metrics("numTargetRowsNotMatchedBySourceUpdated").value,
+ targetRowsInserted = metrics("numTargetRowsInserted").value,
+ targetRowsDeleted = metrics("numTargetRowsDeleted").value,
+ targetRowsMatchedDeleted = metrics("numTargetRowsMatchedDeleted").value,
+ targetRowsNotMatchedBySourceDeleted = metrics("numTargetRowsNotMatchedBySourceDeleted").value,
+
+ // Deprecated fields
+ updateConditionExpr = null,
+ updateExprs = null,
+ insertConditionExpr = null,
+ insertExprs = null,
+ deleteConditionExpr = null)
+ }
+}
+
+/**
+ * GPU version of Delta Lake's MergeIntoCommand.
+ *
+ * Performs a merge of a source query/table into a Delta table.
+ *
+ * Issues an error message when the ON search_condition of the MERGE statement can match
+ * a single row from the target table with multiple rows of the source table-reference.
+ *
+ * Algorithm:
+ *
+ * Phase 1: Find the input files in target that are touched by the rows that satisfy
+ * the condition and verify that no two source rows match with the same target row.
+ * This is implemented as an inner-join using the given condition. See [[findTouchedFiles]]
+ * for more details.
+ *
+ * Phase 2: Read the touched files again and write new files with updated and/or inserted rows.
+ *
+ * Phase 3: Use the Delta protocol to atomically remove the touched files and add the new files.
+ *
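+ * As an illustrative sketch only (table paths, column names and `sourceDF` are hypothetical,
+ * and this assumes the Delta 2.3 `DeltaTable` merge builder API), a merge that exercises
+ * matched, not-matched and not-matched-by-source clauses looks roughly like:
+ * {{{
+ *   import io.delta.tables.DeltaTable
+ *   DeltaTable.forPath(spark, "/tmp/target").as("t")
+ *     .merge(sourceDF.as("s"), "t.id = s.id")
+ *     .whenMatched("s.op = 'delete'").delete()
+ *     .whenMatched().updateAll()
+ *     .whenNotMatched().insertAll()
+ *     .whenNotMatchedBySource().delete()
+ *     .execute()
+ * }}}
+ *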
+ * @param source Source data to merge from
+ * @param target Target table to merge into
+ * @param gpuDeltaLog Delta log to use
+ * @param condition Condition for a source row to match with a target row
+ * @param matchedClauses All info related to matched clauses.
+ * @param notMatchedClauses All info related to not matched clauses.
+ * @param notMatchedBySourceClauses All info related to not matched by source clauses.
+ * @param migratedSchema The final schema of the target - may be changed by schema
+ * evolution.
+ */
+case class GpuMergeIntoCommand(
+ @transient source: LogicalPlan,
+ @transient target: LogicalPlan,
+ @transient gpuDeltaLog: GpuDeltaLog,
+ condition: Expression,
+ matchedClauses: Seq[DeltaMergeIntoMatchedClause],
+ notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause],
+ notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause],
+ migratedSchema: Option[StructType])(
+ @transient val rapidsConf: RapidsConf)
+ extends LeafRunnableCommand
+ with DeltaCommand
+ with PredicateHelper
+ with AnalysisHelper
+ with ImplicitMetadataOperation
+ with MergeIntoMaterializeSource {
+
+ import GpuMergeIntoCommand._
+
+ import SQLMetrics._
+ import org.apache.spark.sql.delta.commands.cdc.CDCReader._
+
+ override val otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf)
+
+ override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE)
+ override val canOverwriteSchema: Boolean = false
+
+ override val output: Seq[Attribute] = Seq(
+ AttributeReference("num_affected_rows", LongType)(),
+ AttributeReference("num_updated_rows", LongType)(),
+ AttributeReference("num_deleted_rows", LongType)(),
+ AttributeReference("num_inserted_rows", LongType)())
+
+ @transient private lazy val sc: SparkContext = SparkContext.getOrCreate()
+ @transient private lazy val targetDeltaLog: DeltaLog = gpuDeltaLog.deltaLog
+ /**
+ * Map to get target output attributes by name.
+ * The case sensitivity of the map is set accordingly to Spark configuration.
+ */
+ @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = {
+ val attrMap: Map[String, Attribute] = target
+ .outputSet.view
+ .map(attr => attr.name -> attr).toMap
+ if (conf.caseSensitiveAnalysis) {
+ attrMap
+ } else {
+ CaseInsensitiveMap(attrMap)
+ }
+ }
+
+ /** Whether this merge statement has only a single insert (NOT MATCHED) clause. */
+ private def isSingleInsertOnly: Boolean =
+ matchedClauses.isEmpty && notMatchedBySourceClauses.isEmpty && notMatchedClauses.length == 1
+ /** Whether this merge statement has no insert (NOT MATCHED) clause. */
+ private def hasNoInserts: Boolean = notMatchedClauses.isEmpty
+
+ // We over-count numTargetRowsDeleted when there are multiple matches;
+ // this is the amount of the overcount, so we can subtract it to get a correct final metric.
+ private var multipleMatchDeleteOnlyOvercount: Option[Long] = None
+
+ override lazy val metrics = Map[String, SQLMetric](
+ "numSourceRows" -> createMetric(sc, "number of source rows"),
+ "numSourceRowsInSecondScan" ->
+ createMetric(sc, "number of source rows (during repeated scan)"),
+ "numTargetRowsCopied" -> createMetric(sc, "number of target rows rewritten unmodified"),
+ "numTargetRowsInserted" -> createMetric(sc, "number of inserted rows"),
+ "numTargetRowsUpdated" -> createMetric(sc, "number of updated rows"),
+ "numTargetRowsMatchedUpdated" ->
+ createMetric(sc, "number of rows updated by a matched clause"),
+ "numTargetRowsNotMatchedBySourceUpdated" ->
+ createMetric(sc, "number of rows updated by a not matched by source clause"),
+ "numTargetRowsDeleted" -> createMetric(sc, "number of deleted rows"),
+ "numTargetRowsMatchedDeleted" ->
+ createMetric(sc, "number of rows deleted by a matched clause"),
+ "numTargetRowsNotMatchedBySourceDeleted" ->
+ createMetric(sc, "number of rows deleted by a not matched by source clause"),
+ "numTargetFilesBeforeSkipping" -> createMetric(sc, "number of target files before skipping"),
+ "numTargetFilesAfterSkipping" -> createMetric(sc, "number of target files after skipping"),
+ "numTargetFilesRemoved" -> createMetric(sc, "number of files removed to target"),
+ "numTargetFilesAdded" -> createMetric(sc, "number of files added to target"),
+ "numTargetChangeFilesAdded" ->
+ createMetric(sc, "number of change data capture files generated"),
+ "numTargetChangeFileBytes" ->
+ createMetric(sc, "total size of change data capture files generated"),
+ "numTargetBytesBeforeSkipping" -> createMetric(sc, "number of target bytes before skipping"),
+ "numTargetBytesAfterSkipping" -> createMetric(sc, "number of target bytes after skipping"),
+ "numTargetBytesRemoved" -> createMetric(sc, "number of target bytes removed"),
+ "numTargetBytesAdded" -> createMetric(sc, "number of target bytes added"),
+ "numTargetPartitionsAfterSkipping" ->
+ createMetric(sc, "number of target partitions after skipping"),
+ "numTargetPartitionsRemovedFrom" ->
+ createMetric(sc, "number of target partitions from which files were removed"),
+ "numTargetPartitionsAddedTo" ->
+ createMetric(sc, "number of target partitions to which files were added"),
+ "executionTimeMs" ->
+ createTimingMetric(sc, "time taken to execute the entire operation"),
+ "scanTimeMs" ->
+ createTimingMetric(sc, "time taken to scan the files for matches"),
+ "rewriteTimeMs" ->
+ createTimingMetric(sc, "time taken to rewrite the matched files"))
+
+ override def run(spark: SparkSession): Seq[Row] = {
+ metrics("executionTimeMs").set(0)
+ metrics("scanTimeMs").set(0)
+ metrics("rewriteTimeMs").set(0)
+
+ if (migratedSchema.isDefined) {
+ // Block writes of void columns in the Delta log. Currently void columns are not properly
+ // supported and are dropped on read, but this is not enough for the MERGE command, which
+ // also reads the schema from the Delta log. Until proper support is added, we prefer to fail
+ // merge queries that add void columns.
+ val newNullColumn = SchemaUtils.findNullTypeColumn(migratedSchema.get)
+ if (newNullColumn.isDefined) {
+ throw new AnalysisException(
+ s"""Cannot add column '${newNullColumn.get}' with type 'void'. Please explicitly specify a
+ |non-void type.""".stripMargin.replaceAll("\n", " ")
+ )
+ }
+ }
+ val (materializeSource, _) = shouldMaterializeSource(spark, source, isSingleInsertOnly)
+ if (!materializeSource) {
+ runMerge(spark)
+ } else {
+ // If it is determined that source should be materialized, wrap the execution with retries,
+ // in case the data of the materialized source is lost.
+ runWithMaterializedSourceLostRetries(
+ spark, targetDeltaLog, metrics, runMerge)
+ }
+ }
+
+ protected def runMerge(spark: SparkSession): Seq[Row] = {
+ recordDeltaOperation(targetDeltaLog, "delta.dml.merge") {
+ val startTime = System.nanoTime()
+ gpuDeltaLog.withNewTransaction { deltaTxn =>
+ if (hasBeenExecuted(deltaTxn, spark)) {
+ sendDriverMetrics(spark, metrics)
+ return Seq.empty
+ }
+ if (target.schema.size != deltaTxn.metadata.schema.size) {
+ throw DeltaErrors.schemaChangedSinceAnalysis(
+ atAnalysis = target.schema, latestSchema = deltaTxn.metadata.schema)
+ }
+
+ if (canMergeSchema) {
+ updateMetadata(
+ spark, deltaTxn, migratedSchema.getOrElse(target.schema),
+ deltaTxn.metadata.partitionColumns, deltaTxn.metadata.configuration,
+ isOverwriteMode = false, rearrangeOnly = false)
+ }
+
+ // If materialized, prepare the DF reading the materialize source
+ // Otherwise, prepare a regular DF from source plan.
+ val materializeSourceReason = prepareSourceDFAndReturnMaterializeReason(
+ spark,
+ source,
+ condition,
+ matchedClauses,
+ notMatchedClauses,
+ isSingleInsertOnly)
+
+ val deltaActions = {
+ if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) {
+ writeInsertsOnlyWhenNoMatchedClauses(spark, deltaTxn)
+ } else {
+ val filesToRewrite = findTouchedFiles(spark, deltaTxn)
+ val newWrittenFiles = withStatusCode("DELTA", "Writing merged data") {
+ writeAllChanges(spark, deltaTxn, filesToRewrite)
+ }
+ filesToRewrite.map(_.remove) ++ newWrittenFiles
+ }
+ }
+
+ val finalActions = createSetTransaction(spark, targetDeltaLog).toSeq ++ deltaActions
+ // Metrics should be recorded before commit (where they are written to delta logs).
+ metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000)
+ deltaTxn.registerSQLMetrics(spark, metrics)
+
+ // This is a best-effort sanity check.
+ if (metrics("numSourceRowsInSecondScan").value >= 0 &&
+ metrics("numSourceRows").value != metrics("numSourceRowsInSecondScan").value) {
+ log.warn(s"Merge source has ${metrics("numSourceRows")} rows in initial scan but " +
+ s"${metrics("numSourceRowsInSecondScan")} rows in second scan")
+ if (conf.getConf(DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED)) {
+ throw DeltaErrors.sourceNotDeterministicInMergeException(spark)
+ }
+ }
+
+ deltaTxn.commitIfNeeded(
+ finalActions,
+ DeltaOperations.Merge(
+ Option(condition.sql),
+ matchedClauses.map(DeltaOperations.MergePredicate(_)),
+ notMatchedClauses.map(DeltaOperations.MergePredicate(_)),
+ notMatchedBySourceClauses.map(DeltaOperations.MergePredicate(_))))
+
+ // Record metrics
+ var stats = GpuMergeStats.fromMergeSQLMetrics(
+ metrics,
+ condition,
+ matchedClauses,
+ notMatchedClauses,
+ notMatchedBySourceClauses,
+ deltaTxn.metadata.partitionColumns.nonEmpty)
+ stats = stats.copy(
+ materializeSourceReason = Some(materializeSourceReason.toString),
+ materializeSourceAttempts = Some(attempt))
+
+ recordDeltaEvent(targetDeltaLog, "delta.dml.merge.stats", data = stats)
+
+ }
+ spark.sharedState.cacheManager.recacheByPlan(spark, target)
+ }
+ sendDriverMetrics(spark, metrics)
+ Seq(Row(metrics("numTargetRowsUpdated").value + metrics("numTargetRowsDeleted").value +
+ metrics("numTargetRowsInserted").value, metrics("numTargetRowsUpdated").value,
+ metrics("numTargetRowsDeleted").value, metrics("numTargetRowsInserted").value))
+ }
+
+ /**
+ * Find the target table files that contain the rows that satisfy the merge condition. This is
+ * implemented as an inner-join between the source query/table and the target table using
+ * the merge condition.
+ */
+ private def findTouchedFiles(
+ spark: SparkSession,
+ deltaTxn: OptimisticTransaction
+ ): Seq[AddFile] = recordMergeOperation(sqlMetricName = "scanTimeMs") {
+
+ // Accumulator to collect all the distinct touched files
+ val touchedFilesAccum = new SetAccumulator[String]()
+ spark.sparkContext.register(touchedFilesAccum, TOUCHED_FILES_ACCUM_NAME)
+
+ // UDF to record touched file names and add them to the accumulator
+ val recordTouchedFileName = DeltaUDF.intFromString(
+ new GpuDeltaRecordTouchedFileNameUDF(touchedFilesAccum)).asNondeterministic()
+
+ // Prune non-matching files if we don't need to collect them for NOT MATCHED BY SOURCE clauses.
+ val dataSkippedFiles =
+ if (notMatchedBySourceClauses.isEmpty) {
+ val targetOnlyPredicates =
+ splitConjunctivePredicates(condition).filter(_.references.subsetOf(target.outputSet))
+ deltaTxn.filterFiles(targetOnlyPredicates)
+ } else {
+ deltaTxn.filterFiles()
+ }
+
+ // UDF to increment metrics
+ val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows")
+ val sourceDF = getSourceDF()
+ .filter(new Column(incrSourceRowCountExpr))
+
+ // Join the source and target table using the merge condition to find touched files. An inner
+ // join collects all candidate files for MATCHED clauses, a right outer join also includes
+ // candidates for NOT MATCHED BY SOURCE clauses.
+ // In addition, we attach two columns:
+ // - a monotonically increasing row id for target rows, to later identify whether the same
+ // target row is modified by multiple source rows or not
+ // - the target file name the row is from, to later identify the files touched by matched rows
+ val joinType = if (notMatchedBySourceClauses.isEmpty) "inner" else "right_outer"
+ val targetDF = buildTargetPlanWithFiles(spark, deltaTxn, dataSkippedFiles)
+ .withColumn(ROW_ID_COL, monotonically_increasing_id())
+ .withColumn(FILE_NAME_COL, input_file_name())
+ val joinToFindTouchedFiles = sourceDF.join(targetDF, new Column(condition), joinType)
+
+ // Process the matches from the inner join to record touched files and find multiple matches
+ val collectTouchedFiles = joinToFindTouchedFiles
+ .select(col(ROW_ID_COL), recordTouchedFileName(col(FILE_NAME_COL)).as("one"))
+
+ // Calculate frequency of matches per source row
+ val matchedRowCounts = collectTouchedFiles.groupBy(ROW_ID_COL).agg(sum("one").as("count"))
+
+ // Get multiple matches and simultaneously collect (using touchedFilesAccum) the file names
+ // multipleMatchCount = # of target rows with more than 1 matching source row (duplicate match)
+ // multipleMatchSum = total # of duplicate matched rows
+ import org.apache.spark.sql.delta.implicits._
+ val (multipleMatchCount, multipleMatchSum) = matchedRowCounts
+ .filter("count > 1")
+ .select(coalesce(count(new Column("*")), lit(0)), coalesce(sum("count"), lit(0)))
+ .as[(Long, Long)]
+ .collect()
+ .head
+
+ val hasMultipleMatches = multipleMatchCount > 0
+
+ // Throw error if multiple matches are ambiguous or cannot be computed correctly.
+ val canBeComputedUnambiguously = {
+ // Multiple matches are not ambiguous when there is only one unconditional delete as
+ // all the matched row pairs in the 2nd join in `writeAllChanges` will get deleted.
+ val isUnconditionalDelete = matchedClauses.headOption match {
+ case Some(DeltaMergeIntoMatchedDeleteClause(None)) => true
+ case _ => false
+ }
+ matchedClauses.size == 1 && isUnconditionalDelete
+ }
+
+ if (hasMultipleMatches && !canBeComputedUnambiguously) {
+ throw DeltaErrors.multipleSourceRowMatchingTargetRowInMergeException(spark)
+ }
+
+ if (hasMultipleMatches) {
+ // This is only allowed for delete-only queries.
+ // This query will count the duplicates for numTargetRowsDeleted in Job 2,
+ // because we count matches after the join and not just the target rows.
+ // We have to compensate for this by subtracting the duplicates later,
+ // so we need to record them here.
+ val duplicateCount = multipleMatchSum - multipleMatchCount
+ multipleMatchDeleteOnlyOvercount = Some(duplicateCount)
+ }
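+ // Illustrative sketch: if two target rows are each matched by three source rows,
+ // multipleMatchCount = 2 and multipleMatchSum = 6, so the write job will bump
+ // numTargetRowsDeleted six times for what are really two deleted rows; the overcount
+ // recorded above (6 - 2 = 4) is subtracted again in writeAllChanges.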
+
+ // Get the AddFiles using the touched file names.
+ val touchedFileNames = touchedFilesAccum.value.iterator().asScala.toSeq
+ logTrace(s"findTouchedFiles: matched files:\n\t${touchedFileNames.mkString("\n\t")}")
+
+ val nameToAddFileMap = generateCandidateFileMap(targetDeltaLog.dataPath, dataSkippedFiles)
+ val touchedAddFiles = touchedFileNames.map(f =>
+ getTouchedFile(targetDeltaLog.dataPath, f, nameToAddFileMap))
+
+ // When the target table is empty and the optimizer has optimized away the join entirely,
+ // numSourceRows will incorrectly be 0. We need to scan the source table once to get the
+ // correct metric here.
+ if (metrics("numSourceRows").value == 0 &&
+ (dataSkippedFiles.isEmpty || targetDF.take(1).isEmpty)) {
+ val numSourceRows = sourceDF.count()
+ metrics("numSourceRows").set(numSourceRows)
+ }
+
+ // Update metrics
+ metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles
+ metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes
+ val (afterSkippingBytes, afterSkippingPartitions) =
+ totalBytesAndDistinctPartitionValues(dataSkippedFiles)
+ metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size
+ metrics("numTargetBytesAfterSkipping") += afterSkippingBytes
+ metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions
+ val (removedBytes, removedPartitions) = totalBytesAndDistinctPartitionValues(touchedAddFiles)
+ metrics("numTargetFilesRemoved") += touchedAddFiles.size
+ metrics("numTargetBytesRemoved") += removedBytes
+ metrics("numTargetPartitionsRemovedFrom") += removedPartitions
+ touchedAddFiles
+ }
+
+ /**
+ * This is an optimization of the case when there is no update clause for the merge.
+ * We perform a left anti-join on the source data to find the rows to be inserted.
+ *
+ * This will currently only optimize for the case when there is a _single_ notMatchedClause.
+ */
+ private def writeInsertsOnlyWhenNoMatchedClauses(
+ spark: SparkSession,
+ deltaTxn: OptimisticTransaction
+ ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") {
+
+ // UDFs to update metrics
+ val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows")
+ val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted")
+
+ val outputColNames = getTargetOutputCols(deltaTxn).map(_.name)
+ // we use head here since we know there is only a single notMatchedClause
+ val outputExprs = notMatchedClauses.head.resolvedActions.map(_.expr)
+ val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) =>
+ new Column(Alias(expr, name)())
+ }
+
+ // source DataFrame
+ val sourceDF = getSourceDF()
+ .filter(new Column(incrSourceRowCountExpr))
+ .filter(new Column(notMatchedClauses.head.condition.getOrElse(Literal.TrueLiteral)))
+
+ // Skip data based on the merge condition
+ val conjunctivePredicates = splitConjunctivePredicates(condition)
+ val targetOnlyPredicates =
+ conjunctivePredicates.filter(_.references.subsetOf(target.outputSet))
+ val dataSkippedFiles = deltaTxn.filterFiles(targetOnlyPredicates)
+
+ // target DataFrame
+ val targetDF = buildTargetPlanWithFiles(spark, deltaTxn, dataSkippedFiles)
+
+ val insertDf = sourceDF.join(targetDF, new Column(condition), "leftanti")
+ .select(outputCols: _*)
+ .filter(new Column(incrInsertedCountExpr))
+
+ val newFiles = deltaTxn
+ .writeFiles(repartitionIfNeeded(spark, insertDf, deltaTxn.metadata.partitionColumns))
+ .filter {
+ // In some cases (e.g. insert-only when all rows are matched, insert-only with an empty
+ // source, insert-only with an unsatisfied condition) we can write out an empty insertDf.
+ // This is hard to catch before the write without collecting the DF ahead of time. Instead,
+ // we just accept only the AddFiles that actually add rows, or those for which we don't
+ // know the number of records.
+ case a: AddFile => a.numLogicalRecords.forall(_ > 0)
+ case _ => true
+ }
+
+ // Update metrics
+ metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles
+ metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes
+ val (afterSkippingBytes, afterSkippingPartitions) =
+ totalBytesAndDistinctPartitionValues(dataSkippedFiles)
+ metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size
+ metrics("numTargetBytesAfterSkipping") += afterSkippingBytes
+ metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions
+ metrics("numTargetFilesRemoved") += 0
+ metrics("numTargetBytesRemoved") += 0
+ metrics("numTargetPartitionsRemovedFrom") += 0
+ val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles)
+ metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile])
+ metrics("numTargetBytesAdded") += addedBytes
+ metrics("numTargetPartitionsAddedTo") += addedPartitions
+ newFiles
+ }
+
+ /**
+ * Write new files by reading the touched files and updating/inserting data using the source
+ * query/table. This is implemented using a full|right-outer-join using the merge condition.
+ *
+ * Note that unlike the insert-only code paths with just one control column INCR_ROW_COUNT_COL,
+ * this method has two additional control columns ROW_DROPPED_COL for dropping deleted rows and
+ * CDC_TYPE_COLUMN_NAME used for handling CDC when enabled.
+ */
+ private def writeAllChanges(
+ spark: SparkSession,
+ deltaTxn: OptimisticTransaction,
+ filesToRewrite: Seq[AddFile]
+ ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") {
+ import org.apache.spark.sql.catalyst.expressions.Literal.{TrueLiteral, FalseLiteral}
+
+ val cdcEnabled = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(deltaTxn.metadata)
+
+ var targetOutputCols = getTargetOutputCols(deltaTxn)
+ var outputRowSchema = deltaTxn.metadata.schema
+
+ // When we have duplicate matches (only allowed when the whenMatchedCondition is a delete with
+ // no match condition) we will incorrectly generate duplicate CDC rows.
+ // Duplicate matches can be due to:
+ // - Duplicate rows in the source w.r.t. the merge condition
+ // - A target-only or source-only merge condition, which essentially turns our join into a cross
+ // join with the target/source satisfying the merge condition.
+ // These duplicate matches are dropped from the main data output since this is a delete
+ // operation, but the duplicate CDC rows are not removed by default.
+ // See https://github.com/delta-io/delta/issues/1274
+
+ // We address this specific scenario by adding row ids to the target before performing our join.
+ // There should only be one CDC delete row per target row so we can use these row ids to dedupe
+ // the duplicate CDC delete rows.
+
+ // We also need to address the scenario when there are duplicate matches with delete and we
+ // insert duplicate rows. Here we need to additionally add row ids to the source before the
+ // join to avoid dropping these valid duplicate inserted rows and their corresponding cdc rows.
+
+ // When there is an insert clause, we set SOURCE_ROW_ID_COL=null for all delete rows because we
+ // need to drop the duplicate matches.
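+ // Illustrative sketch with hypothetical data: under an unconditional WHEN MATCHED DELETE, a
+ // target row matched by two source rows appears twice in the joined output; both copies are
+ // dropped from the main data, but without the row-id dedupe below they would emit two
+ // identical CDC delete rows for that single target row.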
+ val isDeleteWithDuplicateMatchesAndCdc = multipleMatchDeleteOnlyOvercount.nonEmpty && cdcEnabled
+
+ // Generate a new target dataframe that has same output attributes exprIds as the target plan.
+ // This allows us to apply the existing resolved update/insert expressions.
+ val baseTargetDF = buildTargetPlanWithFiles(spark, deltaTxn, filesToRewrite)
+ val joinType = if (hasNoInserts &&
+ spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED)) {
+ "rightOuter"
+ } else {
+ "fullOuter"
+ }
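+ // When there is no insert clause and the matched-only optimization is enabled, unmatched
+ // source rows can never contribute output rows, so a right outer join that preserves only
+ // the target side suffices; otherwise a full outer join is needed so that unmatched source
+ // rows survive for the NOT MATCHED (insert) clauses.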
+
+ logDebug(s"""writeAllChanges using $joinType join:
+ | source.output: ${source.outputSet}
+ | target.output: ${target.outputSet}
+ | condition: $condition
+ | newTarget.output: ${baseTargetDF.queryExecution.logical.outputSet}
+ """.stripMargin)
+
+ // UDFs to update metrics
+ // Make UDFs that appear in the custom join processor node deterministic, as they always
+ // return true and update a metric. Catalyst only allows non-deterministic UDFs in a very
+ // specific set of nodes (Project, Filter, Window, Aggregate), and the custom join processor
+ // node is not one of them.
+ val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRowsInSecondScan",
+ deterministic = true)
+ val incrUpdatedCountExpr = makeMetricUpdateUDF("numTargetRowsUpdated", deterministic = true)
+ val incrUpdatedMatchedCountExpr = makeMetricUpdateUDF("numTargetRowsMatchedUpdated",
+ deterministic = true)
+ val incrUpdatedNotMatchedBySourceCountExpr =
+ makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceUpdated", deterministic = true)
+ val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted", deterministic = true)
+ val incrNoopCountExpr = makeMetricUpdateUDF("numTargetRowsCopied", deterministic = true)
+ val incrDeletedCountExpr = makeMetricUpdateUDF("numTargetRowsDeleted", deterministic = true)
+ val incrDeletedMatchedCountExpr = makeMetricUpdateUDF("numTargetRowsMatchedDeleted",
+ deterministic = true)
+ val incrDeletedNotMatchedBySourceCountExpr =
+ makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceDeleted", deterministic = true)
+
+ // Apply an outer join to find both matches and non-matches. We are adding two boolean fields
+ // with value `true`, one to each side of the join. Whether this field is null or not after
+ // the outer join allows us to identify whether the resultant joined row was a
+ // matched inner result or an unmatched result with null on one side.
+ // We add row IDs to the targetDF if we have a delete-when-matched clause with duplicate
+ // matches and CDC is enabled, and additionally add row IDs to the source if we also have an
+ // insert clause. See above at isDeleteWithDuplicateMatchesAndCdc definition for more details.
+ var sourceDF = getSourceDF()
+ .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr))
+ var targetDF = baseTargetDF
+ .withColumn(TARGET_ROW_PRESENT_COL, lit(true))
+ if (isDeleteWithDuplicateMatchesAndCdc) {
+ targetDF = targetDF.withColumn(TARGET_ROW_ID_COL, monotonically_increasing_id())
+ if (notMatchedClauses.nonEmpty) { // insert clause
+ sourceDF = sourceDF.withColumn(SOURCE_ROW_ID_COL, monotonically_increasing_id())
+ }
+ }
+ val joinedDF = sourceDF.join(targetDF, new Column(condition), joinType)
+ val joinedPlan = joinedDF.queryExecution.analyzed
+
+ def resolveOnJoinedPlan(exprs: Seq[Expression]): Seq[Expression] = {
+ tryResolveReferencesForExpressions(spark, exprs, joinedPlan)
+ }
+
+ // ==== Generate the expressions to process full-outer join output and generate target rows ====
+ // If there are N columns in the target table, there will be N + 3 columns after processing
+ // - N columns for target table
+ // - ROW_DROPPED_COL to define whether the generated row should be dropped or written
+ // - INCR_ROW_COUNT_COL containing a UDF to update the output row counter
+ // - CDC_TYPE_COLUMN_NAME containing the type of change being performed in a particular row
+
+ // To generate these N + 3 columns, we will generate N + 3 expressions and apply them to the
+ // rows in the joinedDF. The CDC column will be either used for CDC generation or dropped before
+ // performing the final write, and the other two will always be dropped after executing the
+ // metrics UDF and filtering on ROW_DROPPED_COL.
+
+ // We produce rows for both the main table data (with CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC),
+ // and rows for the CDC data which will be output to CDCReader.CDC_LOCATION.
+ // See [[CDCReader]] for general details on how partitioning on the CDC type column works.
+
+ // In the following functions `updateOutput`, `deleteOutput` and `insertOutput`, we
+ // produce a Seq[Expression] for each intended output row.
+ // Depending on the clause and whether CDC is enabled, we output between 0 and 3 rows, as a
+ // Seq[Seq[Expression]]
+
+ // There is one corner case outlined above at isDeleteWithDuplicateMatchesAndCdc definition.
+ // When we have a delete-ONLY merge with duplicate matches we have N + 4 columns:
+ // N target cols, TARGET_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL, CDC_TYPE_COLUMN_NAME
+ // When we have a delete-when-matched merge with duplicate matches + an insert clause, we have
+ // N + 5 columns:
+ // N target cols, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL,
+ // CDC_TYPE_COLUMN_NAME
+ // These ROW_ID_COL will always be dropped before the final write.
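+ // As an illustrative sketch with a hypothetical two-column target (id, value) and CDC enabled,
+ // each generated row has the shape (id, value, ROW_DROPPED_COL, INCR_ROW_COUNT_COL,
+ // CDC_TYPE_COLUMN_NAME): an update emits the main row plus preimage/postimage CDC rows, a
+ // delete emits the main (dropped) row plus a CDC_TYPE_DELETE row, and an insert emits the
+ // main row plus a CDC_TYPE_INSERT row.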
+
+ if (isDeleteWithDuplicateMatchesAndCdc) {
+ targetOutputCols = targetOutputCols :+ UnresolvedAttribute(TARGET_ROW_ID_COL)
+ outputRowSchema = outputRowSchema.add(TARGET_ROW_ID_COL, DataTypes.LongType)
+ if (notMatchedClauses.nonEmpty) { // there is an insert clause, make SRC_ROW_ID_COL=null
+ targetOutputCols = targetOutputCols :+ Alias(Literal(null), SOURCE_ROW_ID_COL)()
+ outputRowSchema = outputRowSchema.add(SOURCE_ROW_ID_COL, DataTypes.LongType)
+ }
+ }
+
+ if (cdcEnabled) {
+ outputRowSchema = outputRowSchema
+ .add(ROW_DROPPED_COL, DataTypes.BooleanType)
+ .add(INCR_ROW_COUNT_COL, DataTypes.BooleanType)
+ .add(CDC_TYPE_COLUMN_NAME, DataTypes.StringType)
+ }
+
+ def updateOutput(resolvedActions: Seq[DeltaMergeAction], incrMetricExpr: Expression)
+ : Seq[Seq[Expression]] = {
+ val updateExprs = {
+ // Generate update expressions and set ROW_DELETED_COL = false and
+ // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC
+ val mainDataOutput = resolvedActions.map(_.expr) :+ FalseLiteral :+
+ incrMetricExpr :+ CDC_TYPE_NOT_CDC_LITERAL
+ if (cdcEnabled) {
+ // For the update preimage, we have to do a no-op copy with ROW_DELETED_COL = false and
+ // CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_PREIMAGE and INCR_ROW_COUNT_COL as a no-op
+ // (because the metric will be incremented in `mainDataOutput`)
+ val preImageOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+
+ Literal(CDC_TYPE_UPDATE_PREIMAGE)
+ // For update postimage, we have the same expressions as for mainDataOutput but with
+ // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in
+ // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_POSTIMAGE
+ val postImageOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+
+ Literal(CDC_TYPE_UPDATE_POSTIMAGE)
+ Seq(mainDataOutput, preImageOutput, postImageOutput)
+ } else {
+ Seq(mainDataOutput)
+ }
+ }
+ updateExprs.map(resolveOnJoinedPlan)
+ }
+
+ def deleteOutput(incrMetricExpr: Expression): Seq[Seq[Expression]] = {
+ val deleteExprs = {
+ // Generate expressions to set the ROW_DELETED_COL = true and CDC_TYPE_COLUMN_NAME =
+ // CDC_TYPE_NOT_CDC
+ val mainDataOutput = targetOutputCols :+ TrueLiteral :+ incrMetricExpr :+
+ CDC_TYPE_NOT_CDC_LITERAL
+ if (cdcEnabled) {
+ // For delete we do a no-op copy with ROW_DELETED_COL = false, INCR_ROW_COUNT_COL as a
+ // no-op (because the metric will be incremented in `mainDataOutput`) and
+ // CDC_TYPE_COLUMN_NAME = CDC_TYPE_DELETE
+ val deleteCdcOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+ CDC_TYPE_DELETE
+ Seq(mainDataOutput, deleteCdcOutput)
+ } else {
+ Seq(mainDataOutput)
+ }
+ }
+ deleteExprs.map(resolveOnJoinedPlan)
+ }
+
+ def insertOutput(resolvedActions: Seq[DeltaMergeAction], incrMetricExpr: Expression)
+ : Seq[Seq[Expression]] = {
+ // Generate insert expressions and set ROW_DELETED_COL = false and
+ // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC
+ val insertExprs = resolvedActions.map(_.expr)
+ val mainDataOutput = resolveOnJoinedPlan(
+ if (isDeleteWithDuplicateMatchesAndCdc) {
+ // Must be delete-when-matched merge with duplicate matches + insert clause
+ // Therefore we must keep the target row id and source row id. Since this is a not-matched
+ // clause we know the target row-id will be null. See above at
+ // isDeleteWithDuplicateMatchesAndCdc definition for more details.
+ insertExprs :+
+ Alias(Literal(null), TARGET_ROW_ID_COL)() :+ UnresolvedAttribute(SOURCE_ROW_ID_COL) :+
+ FalseLiteral :+ incrMetricExpr :+ CDC_TYPE_NOT_CDC_LITERAL
+ } else {
+ insertExprs :+ FalseLiteral :+ incrMetricExpr :+ CDC_TYPE_NOT_CDC_LITERAL
+ }
+ )
+ if (cdcEnabled) {
+ // For insert we have the same expressions as for mainDataOutput, but with
+ // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in
+ // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_INSERT
+ val insertCdcOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+ Literal(CDC_TYPE_INSERT)
+ Seq(mainDataOutput, insertCdcOutput)
+ } else {
+ Seq(mainDataOutput)
+ }
+ }
+
+ def clauseOutput(clause: DeltaMergeIntoClause): Seq[Seq[Expression]] = clause match {
+ case u: DeltaMergeIntoMatchedUpdateClause =>
+ updateOutput(u.resolvedActions, And(incrUpdatedCountExpr, incrUpdatedMatchedCountExpr))
+ case _: DeltaMergeIntoMatchedDeleteClause =>
+ deleteOutput(And(incrDeletedCountExpr, incrDeletedMatchedCountExpr))
+ case i: DeltaMergeIntoNotMatchedInsertClause =>
+ insertOutput(i.resolvedActions, incrInsertedCountExpr)
+ case u: DeltaMergeIntoNotMatchedBySourceUpdateClause =>
+ updateOutput(
+ u.resolvedActions,
+ And(incrUpdatedCountExpr, incrUpdatedNotMatchedBySourceCountExpr))
+ case _: DeltaMergeIntoNotMatchedBySourceDeleteClause =>
+ deleteOutput(And(incrDeletedCountExpr, incrDeletedNotMatchedBySourceCountExpr))
+ }
+
+ def clauseCondition(clause: DeltaMergeIntoClause): Expression = {
+ // if condition is None, then expression always evaluates to true
+ val condExpr = clause.condition.getOrElse(TrueLiteral)
+ resolveOnJoinedPlan(Seq(condExpr)).head
+ }
+
+ val targetRowHasNoMatch = resolveOnJoinedPlan(Seq(col(SOURCE_ROW_PRESENT_COL).isNull.expr)).head
+ val sourceRowHasNoMatch = resolveOnJoinedPlan(Seq(col(TARGET_ROW_PRESENT_COL).isNull.expr)).head
+ val matchedConditions = matchedClauses.map(clauseCondition)
+ val matchedOutputs = matchedClauses.map(clauseOutput)
+ val notMatchedConditions = notMatchedClauses.map(clauseCondition)
+ val notMatchedOutputs = notMatchedClauses.map(clauseOutput)
+ val notMatchedBySourceConditions = notMatchedBySourceClauses.map(clauseCondition)
+ val notMatchedBySourceOutputs = notMatchedBySourceClauses.map(clauseOutput)
+ val noopCopyOutput =
+ resolveOnJoinedPlan(targetOutputCols :+ FalseLiteral :+ incrNoopCountExpr :+
+ CDC_TYPE_NOT_CDC_LITERAL)
+ val deleteRowOutput =
+ resolveOnJoinedPlan(targetOutputCols :+ TrueLiteral :+ TrueLiteral :+
+ CDC_TYPE_NOT_CDC_LITERAL)
+ var outputDF = addMergeJoinProcessor(spark, joinedPlan, outputRowSchema,
+ targetRowHasNoMatch = targetRowHasNoMatch,
+ sourceRowHasNoMatch = sourceRowHasNoMatch,
+ matchedConditions = matchedConditions,
+ matchedOutputs = matchedOutputs,
+ notMatchedConditions = notMatchedConditions,
+ notMatchedOutputs = notMatchedOutputs,
+ notMatchedBySourceConditions = notMatchedBySourceConditions,
+ notMatchedBySourceOutputs = notMatchedBySourceOutputs,
+ noopCopyOutput = noopCopyOutput,
+ deleteRowOutput = deleteRowOutput)
+
+ if (isDeleteWithDuplicateMatchesAndCdc) {
+ // When we have a delete when matched clause with duplicate matches we have to remove
+ // duplicate CDC rows. This scenario is further explained at
+ // isDeleteWithDuplicateMatchesAndCdc definition.
+
+ // To remove duplicate CDC rows generated by the duplicate matches we dedupe by
+ // TARGET_ROW_ID_COL since there should only be one CDC delete row per target row.
+ // When there is an insert clause in addition to the delete clause we additionally dedupe by
+ // SOURCE_ROW_ID_COL and CDC_TYPE_COLUMN_NAME to avoid dropping valid duplicate inserted rows
+ // and their corresponding CDC rows.
+ val columnsToDedupeBy = if (notMatchedClauses.nonEmpty) { // insert clause
+ Seq(TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, CDC_TYPE_COLUMN_NAME)
+ } else {
+ Seq(TARGET_ROW_ID_COL)
+ }
+ outputDF = outputDF
+ .dropDuplicates(columnsToDedupeBy)
+ .drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL)
+ } else {
+ outputDF = outputDF.drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL)
+ }
+
+ logDebug("writeAllChanges: join output plan:\n" + outputDF.queryExecution)
+
+ // Write to Delta
+ val newFiles = deltaTxn
+ .writeFiles(repartitionIfNeeded(spark, outputDF, deltaTxn.metadata.partitionColumns))
+
+ // Update metrics
+ val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles)
+ metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile])
+ metrics("numTargetChangeFilesAdded") += newFiles.count(_.isInstanceOf[AddCDCFile])
+ metrics("numTargetChangeFileBytes") += newFiles.collect{ case f: AddCDCFile => f.size }.sum
+ metrics("numTargetBytesAdded") += addedBytes
+ metrics("numTargetPartitionsAddedTo") += addedPartitions
+ if (multipleMatchDeleteOnlyOvercount.isDefined) {
+ // Compensate for counting duplicates during the query.
+ val actualRowsDeleted =
+ metrics("numTargetRowsDeleted").value - multipleMatchDeleteOnlyOvercount.get
+ assert(actualRowsDeleted >= 0)
+ metrics("numTargetRowsDeleted").set(actualRowsDeleted)
+ val actualRowsMatchedDeleted =
+ metrics("numTargetRowsMatchedDeleted").value - multipleMatchDeleteOnlyOvercount.get
+ assert(actualRowsMatchedDeleted >= 0)
+ metrics("numTargetRowsMatchedDeleted").set(actualRowsMatchedDeleted)
+ }
+
+ newFiles
+ }
+
+ private def addMergeJoinProcessor(
+ spark: SparkSession,
+ joinedPlan: LogicalPlan,
+ outputRowSchema: StructType,
+ targetRowHasNoMatch: Expression,
+ sourceRowHasNoMatch: Expression,
+ matchedConditions: Seq[Expression],
+ matchedOutputs: Seq[Seq[Seq[Expression]]],
+ notMatchedConditions: Seq[Expression],
+ notMatchedOutputs: Seq[Seq[Seq[Expression]]],
+ notMatchedBySourceConditions: Seq[Expression],
+ notMatchedBySourceOutputs: Seq[Seq[Seq[Expression]]],
+ noopCopyOutput: Seq[Expression],
+ deleteRowOutput: Seq[Expression]): Dataset[Row] = {
+ def wrap(e: Expression): BaseExprMeta[Expression] = {
+ GpuOverrides.wrapExpr(e, rapidsConf, None)
+ }
+
+ val targetRowHasNoMatchMeta = wrap(targetRowHasNoMatch)
+ val sourceRowHasNoMatchMeta = wrap(sourceRowHasNoMatch)
+ val matchedConditionsMetas = matchedConditions.map(wrap)
+ val matchedOutputsMetas = matchedOutputs.map(_.map(_.map(wrap)))
+ val notMatchedConditionsMetas = notMatchedConditions.map(wrap)
+ val notMatchedOutputsMetas = notMatchedOutputs.map(_.map(_.map(wrap)))
+ val notMatchedBySourceConditionsMetas = notMatchedBySourceConditions.map(wrap)
+ val notMatchedBySourceOutputsMetas = notMatchedBySourceOutputs.map(_.map(_.map(wrap)))
+ val noopCopyOutputMetas = noopCopyOutput.map(wrap)
+ val deleteRowOutputMetas = deleteRowOutput.map(wrap)
+ val allMetas = Seq(targetRowHasNoMatchMeta, sourceRowHasNoMatchMeta) ++
+ matchedConditionsMetas ++ matchedOutputsMetas.flatten.flatten ++
+ notMatchedConditionsMetas ++ notMatchedOutputsMetas.flatten.flatten ++
+ notMatchedBySourceConditionsMetas ++ notMatchedBySourceOutputsMetas.flatten.flatten ++
+ noopCopyOutputMetas ++ deleteRowOutputMetas
+ allMetas.foreach(_.tagForGpu())
+ val canReplace = allMetas.forall(_.canExprTreeBeReplaced) && rapidsConf.isOperatorEnabled(
+ "spark.rapids.sql.exec.RapidsProcessDeltaMergeJoinExec", false, false)
+ if (rapidsConf.shouldExplainAll || (rapidsConf.shouldExplain && !canReplace)) {
+ val exprExplains = allMetas.map(_.explain(rapidsConf.shouldExplainAll))
+ val execWorkInfo = if (canReplace) {
+ "will run on GPU"
+ } else {
+ "cannot run on GPU because not all merge processing expressions can be replaced"
+ }
+ logWarning(s" $execWorkInfo:\n" +
+ s" ${exprExplains.mkString(" ")}")
+ }
+
+ if (canReplace) {
+ val processedJoinPlan = RapidsProcessDeltaMergeJoin(
+ joinedPlan,
+ outputRowSchema.toAttributes,
+ targetRowHasNoMatch = targetRowHasNoMatch,
+ sourceRowHasNoMatch = sourceRowHasNoMatch,
+ matchedConditions = matchedConditions,
+ matchedOutputs = matchedOutputs,
+ notMatchedConditions = notMatchedConditions,
+ notMatchedOutputs = notMatchedOutputs,
+ notMatchedBySourceConditions = notMatchedBySourceConditions,
+ notMatchedBySourceOutputs = notMatchedBySourceOutputs,
+ noopCopyOutput = noopCopyOutput,
+ deleteRowOutput = deleteRowOutput)
+ Dataset.ofRows(spark, processedJoinPlan)
+ } else {
+ val joinedRowEncoder = RowEncoder(joinedPlan.schema)
+ val outputRowEncoder = RowEncoder(outputRowSchema).resolveAndBind()
+
+ val processor = new JoinedRowProcessor(
+ targetRowHasNoMatch = targetRowHasNoMatch,
+ sourceRowHasNoMatch = sourceRowHasNoMatch,
+ matchedConditions = matchedConditions,
+ matchedOutputs = matchedOutputs,
+ notMatchedConditions = notMatchedConditions,
+ notMatchedOutputs = notMatchedOutputs,
+ notMatchedBySourceConditions = notMatchedBySourceConditions,
+ notMatchedBySourceOutputs = notMatchedBySourceOutputs,
+ noopCopyOutput = noopCopyOutput,
+ deleteRowOutput = deleteRowOutput,
+ joinedAttributes = joinedPlan.output,
+ joinedRowEncoder = joinedRowEncoder,
+ outputRowEncoder = outputRowEncoder)
+
+ Dataset.ofRows(spark, joinedPlan).mapPartitions(processor.processPartition)(outputRowEncoder)
+ }
+ }
+
+ /**
+ * Build a new logical plan using the given `files` that has the same output columns (exprIds)
+ * as the `target` logical plan, so that existing update/insert expressions can be applied
+ * on this new plan.
+ */
+ private def buildTargetPlanWithFiles(
+ spark: SparkSession,
+ deltaTxn: OptimisticTransaction,
+ files: Seq[AddFile]): DataFrame = {
+ val targetOutputCols = getTargetOutputCols(deltaTxn)
+ val targetOutputColsMap = {
+ val colsMap: Map[String, NamedExpression] = targetOutputCols.view
+ .map(col => col.name -> col).toMap
+ if (conf.caseSensitiveAnalysis) {
+ colsMap
+ } else {
+ CaseInsensitiveMap(colsMap)
+ }
+ }
+
+ val plan = {
+ // We have to do surgery to use the attributes from `targetOutputCols` to scan the table.
+ // In cases of schema evolution, they may not be the same type as the original attributes.
+ val original =
+ deltaTxn.deltaLog.createDataFrame(deltaTxn.snapshot, files).queryExecution.analyzed
+ val transformed = original.transform {
+ case LogicalRelation(base, _, catalogTbl, isStreaming) =>
+ LogicalRelation(
+ base,
+ // We can ignore the new columns which aren't yet AttributeReferences.
+ targetOutputCols.collect { case a: AttributeReference => a },
+ catalogTbl,
+ isStreaming)
+ }
+
+ // In case of schema evolution & column mapping, we would also need to rebuild the file format
+ // because under column mapping, the reference schema within DeltaParquetFileFormat
+ // that is used to populate metadata needs to be updated
+ if (deltaTxn.metadata.columnMappingMode != NoMapping) {
+ val updatedFileFormat = deltaTxn.deltaLog.fileFormat(deltaTxn.metadata)
+ DeltaTableUtils.replaceFileFormat(transformed, updatedFileFormat)
+ } else {
+ transformed
+ }
+ }
+
+ // For each plan output column, find the corresponding target output column (by name) and
+ // create an alias
+ val aliases = plan.output.map {
+ case newAttrib: AttributeReference =>
+ val existingTargetAttrib = targetOutputColsMap.get(newAttrib.name)
+ .getOrElse {
+ throw DeltaErrors.failedFindAttributeInOutputColumns(
+ newAttrib.name, targetOutputCols.mkString(","))
+ }.asInstanceOf[AttributeReference]
+
+ if (existingTargetAttrib.exprId == newAttrib.exprId) {
+ // It's not valid to alias an expression to its own exprId (this is considered a
+ // non-unique exprId by the analyzer), so we just use the attribute directly.
+ newAttrib
+ } else {
+ Alias(newAttrib, existingTargetAttrib.name)(exprId = existingTargetAttrib.exprId)
+ }
+ }
+
+ Dataset.ofRows(spark, Project(aliases, plan))
+ }
+
+ /** Expression to increment the SQL metric with the given name */
+ private def makeMetricUpdateUDF(name: String, deterministic: Boolean = false): Expression = {
+ // only capture the needed metric in a local variable
+ val metric = metrics(name)
+ var u = DeltaUDF.boolean(new GpuDeltaMetricUpdateUDF(metric))
+ if (!deterministic) {
+ u = u.asNondeterministic()
+ }
+ u.apply().expr
+ }
+
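+ /**
+ * Map the table schema to the corresponding target output attributes by name, aliasing a
+ * null literal for any column that has no matching attribute in the target output.
+ */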
+ private def getTargetOutputCols(txn: OptimisticTransaction): Seq[NamedExpression] = {
+ txn.metadata.schema.map { col =>
+ targetOutputAttributesMap
+ .get(col.name)
+ .map { a =>
+ AttributeReference(col.name, col.dataType, col.nullable)(a.exprId)
+ }
+ .getOrElse(Alias(Literal(null), col.name)())
+ }
+ }
+
+ /**
+ * Repartitions the output DataFrame by the partition columns if the table is partitioned
+ * and `merge.repartitionBeforeWrite.enabled` is set to true.
+ */
+ protected def repartitionIfNeeded(
+ spark: SparkSession,
+ df: DataFrame,
+ partitionColumns: Seq[String]): DataFrame = {
+ if (partitionColumns.nonEmpty && spark.conf.get(DeltaSQLConf.MERGE_REPARTITION_BEFORE_WRITE)) {
+ df.repartition(partitionColumns.map(col): _*)
+ } else {
+ df
+ }
+ }
+
+ /**
+ * Execute the given `thunk` and return its result while recording the time taken to do it.
+ *
+ * @param sqlMetricName name of SQL metric to update with the time taken by the thunk
+ * @param thunk the code to execute
+ */
+ private def recordMergeOperation[A](sqlMetricName: String)(thunk: => A): A = {
+ val startTimeNs = System.nanoTime()
+ val r = thunk
+ val timeTakenMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)
+ if (sqlMetricName != null && timeTakenMs > 0) {
+ metrics(sqlMetricName) += timeTakenMs
+ }
+ r
+ }
+}
+
+object GpuMergeIntoCommand {
+ /**
+ * The Spark UI tracks all normal accumulators along with Spark tasks to show them on the Web UI.
+ * However, the accumulator used by `MergeIntoCommand` can store a very large value since it
+ * tracks all files that need to be rewritten. We should ask the Spark UI not to remember it,
+ * otherwise the UI data may consume lots of memory. Hence, we use the prefix `internal.metrics.`
+ * to make this an internal accumulator that is not tracked by the Spark UI.
+ */
+ val TOUCHED_FILES_ACCUM_NAME = "internal.metrics.MergeIntoDelta.touchedFiles"
+
+ val ROW_ID_COL = "_row_id_"
+ val TARGET_ROW_ID_COL = "_target_row_id_"
+ val SOURCE_ROW_ID_COL = "_source_row_id_"
+ val FILE_NAME_COL = "_file_name_"
+ val SOURCE_ROW_PRESENT_COL = "_source_row_present_"
+ val TARGET_ROW_PRESENT_COL = "_target_row_present_"
+ val ROW_DROPPED_COL = GpuDeltaMergeConstants.ROW_DROPPED_COL
+ val INCR_ROW_COUNT_COL = "_incr_row_count_"
+
+ // Some Delta versions use Literal(null) which translates to a literal of NullType instead
+ // of the Literal(null, StringType) which is needed, so using a fixed version here
+ // rather than the version from Delta Lake.
+ val CDC_TYPE_NOT_CDC_LITERAL = Literal(null, StringType)
+
+ /**
+ * @param targetRowHasNoMatch whether a joined row is a target row with no match in the source
+ * table
+ * @param sourceRowHasNoMatch whether a joined row is a source row with no match in the target
+ * table
+ * @param matchedConditions condition for each match clause
+ * @param matchedOutputs corresponding output for each match clause. for each clause, we
+ * have 1-3 output rows, each of which is a sequence of expressions
+ * to apply to the joined row
+ * @param notMatchedConditions condition for each not-matched clause
+ * @param notMatchedOutputs corresponding output for each not-matched clause. for each clause,
+ * we have 1-2 output rows, each of which is a sequence of
+ * expressions to apply to the joined row
+ * @param notMatchedBySourceConditions condition for each not-matched-by-source clause
+ * @param notMatchedBySourceOutputs corresponding output for each not-matched-by-source
+ * clause. for each clause, we have 1-3 output rows, each of
+ * which is a sequence of expressions to apply to the joined
+ * row
+ * @param noopCopyOutput no-op expression to copy a target row to the output
+ * @param deleteRowOutput expression to drop a row from the final output. this is used for
+ * source rows that don't match any not-matched clauses
+ * @param joinedAttributes schema of our outer-joined dataframe
+ * @param joinedRowEncoder joinedDF row encoder
+ * @param outputRowEncoder final output row encoder
+ */
+ class JoinedRowProcessor(
+ targetRowHasNoMatch: Expression,
+ sourceRowHasNoMatch: Expression,
+ matchedConditions: Seq[Expression],
+ matchedOutputs: Seq[Seq[Seq[Expression]]],
+ notMatchedConditions: Seq[Expression],
+ notMatchedOutputs: Seq[Seq[Seq[Expression]]],
+ notMatchedBySourceConditions: Seq[Expression],
+ notMatchedBySourceOutputs: Seq[Seq[Seq[Expression]]],
+ noopCopyOutput: Seq[Expression],
+ deleteRowOutput: Seq[Expression],
+ joinedAttributes: Seq[Attribute],
+ joinedRowEncoder: ExpressionEncoder[Row],
+ outputRowEncoder: ExpressionEncoder[Row]) extends Serializable {
+
+ private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = {
+ UnsafeProjection.create(exprs, joinedAttributes)
+ }
+
+ private def generatePredicate(expr: Expression): BasePredicate = {
+ GeneratePredicate.generate(expr, joinedAttributes)
+ }
+
+ def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = {
+
+ val targetRowHasNoMatchPred = generatePredicate(targetRowHasNoMatch)
+ val sourceRowHasNoMatchPred = generatePredicate(sourceRowHasNoMatch)
+ val matchedPreds = matchedConditions.map(generatePredicate)
+ val matchedProjs = matchedOutputs.map(_.map(generateProjection))
+ val notMatchedPreds = notMatchedConditions.map(generatePredicate)
+ val notMatchedProjs = notMatchedOutputs.map(_.map(generateProjection))
+ val notMatchedBySourcePreds = notMatchedBySourceConditions.map(generatePredicate)
+ val notMatchedBySourceProjs = notMatchedBySourceOutputs.map(_.map(generateProjection))
+ val noopCopyProj = generateProjection(noopCopyOutput)
+ val deleteRowProj = generateProjection(deleteRowOutput)
+ val outputProj = UnsafeProjection.create(outputRowEncoder.schema)
+
+ // This reads ROW_DROPPED_COL. If ROW_DROPPED_COL is not in outputRowEncoder.schema then
+ // CDC must be disabled, and the column appears immediately after our output columns.
+ def shouldDeleteRow(row: InternalRow): Boolean = {
+ row.getBoolean(
+ outputRowEncoder.schema.getFieldIndex(ROW_DROPPED_COL)
+ .getOrElse(outputRowEncoder.schema.fields.size)
+ )
+ }
+
+ def processRow(inputRow: InternalRow): Iterator[InternalRow] = {
+ // Identify which set of clauses to execute: matched, not-matched or not-matched-by-source
+ val (predicates, projections, noopAction) = if (targetRowHasNoMatchPred.eval(inputRow)) {
+ // Target row did not match any source row, so update the target row.
+ (notMatchedBySourcePreds, notMatchedBySourceProjs, noopCopyProj)
+ } else if (sourceRowHasNoMatchPred.eval(inputRow)) {
+ // Source row did not match with any target row, so insert the new source row
+ (notMatchedPreds, notMatchedProjs, deleteRowProj)
+ } else {
+ // Source row matched with target row, so update the target row
+ (matchedPreds, matchedProjs, noopCopyProj)
+ }
+
+ // find (predicate, projection) pair whose predicate satisfies inputRow
+ val pair = (predicates zip projections).find {
+ case (predicate, _) => predicate.eval(inputRow)
+ }
+
+ pair match {
+ case Some((_, projections)) =>
+ projections.map(_.apply(inputRow)).iterator
+ case None => Iterator(noopAction.apply(inputRow))
+ }
+ }
+
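+ // Serialize each input row, apply the merge clause processing, drop any rows flagged for
+ // deletion, and project the survivors to the output schema.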
+ val toRow = joinedRowEncoder.createSerializer()
+ val fromRow = outputRowEncoder.createDeserializer()
+ rowIterator
+ .map(toRow)
+ .flatMap(processRow)
+ .filter(!shouldDeleteRow(_))
+ .map { notDeletedInternalRow =>
+ fromRow(outputProj(notDeletedInternalRow))
+ }
+ }
+ }
+
+ /** Compute the total bytes and the number of distinct partition values among the given AddFiles. */
+ def totalBytesAndDistinctPartitionValues(files: Seq[FileAction]): (Long, Int) = {
+ val distinctValues = new mutable.HashSet[Map[String, String]]()
+ var bytes = 0L
+ val iter = files.collect { case a: AddFile => a }.iterator
+ while (iter.hasNext) {
+ val file = iter.next()
+ distinctValues += file.partitionValues
+ bytes += file.size
+ }
+ // If the only distinct value map is an empty map, then it must be an unpartitioned table.
+ // Return 0 in that case.
+ val numDistinctValues =
+ if (distinctValues.size == 1 && distinctValues.head.isEmpty) 0 else distinctValues.size
+ (bytes, numDistinctValues)
+ }
+}
diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuOptimisticTransaction.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuOptimisticTransaction.scala
new file mode 100644
index 00000000000..38ee8a786c0
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuOptimisticTransaction.scala
@@ -0,0 +1,274 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from OptimisticTransaction.scala and TransactionalWrite.scala
+ * in the Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.rapids.delta23x
+
+import java.net.URI
+
+import scala.collection.mutable.ListBuffer
+
+import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.delta._
+import org.apache.commons.lang3.exception.ExceptionUtils
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.delta._
+import org.apache.spark.sql.delta.actions.{AddFile, FileAction}
+import org.apache.spark.sql.delta.constraints.{Constraint, Constraints}
+import org.apache.spark.sql.delta.rapids.GpuOptimisticTransactionBase
+import org.apache.spark.sql.delta.schema.InvariantViolationException
+import org.apache.spark.sql.delta.sources.DeltaSQLConf
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormatWriter}
+import org.apache.spark.sql.functions.to_json
+import org.apache.spark.sql.rapids.{BasicColumnarWriteJobStatsTracker, ColumnarWriteJobStatsTracker, GpuFileFormatWriter, GpuWriteJobStatsTracker}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.{Clock, SerializableConfiguration}
+
+/**
+ * Used to perform a set of reads in a transaction and then commit a set of updates to the
+ * state of the log. All reads from the DeltaLog MUST go through this instance rather
+ * than directly to the DeltaLog, otherwise they will not be checked for logical conflicts
+ * with concurrent updates.
+ *
+ * This class is not thread-safe.
+ *
+ * @param deltaLog The Delta Log for the table this transaction is modifying.
+ * @param snapshot The snapshot that this transaction is reading at.
+ * @param rapidsConf RAPIDS Accelerator config settings.
+ */
+class GpuOptimisticTransaction
+ (deltaLog: DeltaLog, snapshot: Snapshot, rapidsConf: RapidsConf)
+ (implicit clock: Clock)
+ extends GpuOptimisticTransactionBase(deltaLog, snapshot, rapidsConf)(clock) {
+
+ /** Creates a new OptimisticTransaction.
+ *
+ * @param deltaLog The Delta Log for the table this transaction is modifying.
+ * @param rapidsConf RAPIDS Accelerator config settings
+ */
+ def this(deltaLog: DeltaLog, rapidsConf: RapidsConf)(implicit clock: Clock) = {
+ this(deltaLog, deltaLog.update(), rapidsConf)
+ }
+
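+ /**
+ * Build the expression that serializes the collected statistics to JSON by resolving a
+ * to_json projection of the stats collector against a dummy local relation.
+ */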
+ private def getGpuStatsColExpr(
+ statsDataSchema: Seq[Attribute],
+ statsCollection: GpuStatisticsCollection): Expression = {
+ Dataset.ofRows(spark, LocalRelation(statsDataSchema))
+ .select(to_json(statsCollection.statsCollector))
+ .queryExecution.analyzed.expressions.head
+ }
+
+ /** Return the pair of optional stats tracker and stats collection class */
+ private def getOptionalGpuStatsTrackerAndStatsCollection(
+ output: Seq[Attribute],
+ partitionSchema: StructType, data: DataFrame): (
+ Option[GpuDeltaJobStatisticsTracker],
+ Option[GpuStatisticsCollection]) = {
+ if (spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_COLLECT_STATS)) {
+
+ val (statsDataSchema, statsCollectionSchema) = getStatsSchema(output, partitionSchema)
+
+ val indexedCols = DeltaConfigs.DATA_SKIPPING_NUM_INDEXED_COLS.fromMetaData(metadata)
+ val prefixLength =
+ spark.sessionState.conf.getConf(DeltaSQLConf.DATA_SKIPPING_STRING_PREFIX_LENGTH)
+ val tableSchema = {
+ // If collecting stats using the table schema, then pass in statsCollectionSchema.
+ // Otherwise pass in statsDataSchema to collect stats using the DataFrame schema.
+ if (spark.sessionState.conf.getConf(DeltaSQLConf
+ .DELTA_COLLECT_STATS_USING_TABLE_SCHEMA)) {
+ statsCollectionSchema.toStructType
+ } else {
+ statsDataSchema.toStructType
+ }
+ }
+
+ val _spark = spark
+
+ val statsCollection = new GpuStatisticsCollection {
+ override val spark = _spark
+ override val deletionVectorsSupported = false
+ override val tableDataSchema = tableSchema
+ override val dataSchema = statsDataSchema.toStructType
+ override val numIndexedCols = indexedCols
+ override val stringPrefixLength: Int = prefixLength
+ }
+
+ val statsColExpr = getGpuStatsColExpr(statsDataSchema, statsCollection)
+
+ val statsSchema = statsCollection.statCollectionSchema
+ val explodedDataSchema = statsCollection.explodedDataSchema
+ val batchStatsToRow = (batch: ColumnarBatch, row: InternalRow) => {
+ GpuStatisticsCollection.batchStatsToRow(statsSchema, explodedDataSchema, batch, row)
+ }
+ (Some(new GpuDeltaJobStatisticsTracker(statsDataSchema, statsColExpr, batchStatsToRow)),
+ Some(statsCollection))
+ } else {
+ (None, None)
+ }
+ }
+
+ override def writeFiles(
+ inputData: Dataset[_],
+ writeOptions: Option[DeltaOptions],
+ additionalConstraints: Seq[Constraint]): Seq[FileAction] = {
+ hasWritten = true
+
+ val spark = inputData.sparkSession
+ val (data, partitionSchema) = performCDCPartition(inputData)
+ val outputPath = deltaLog.dataPath
+
+ val (normalizedQueryExecution, output, generatedColumnConstraints, _) =
+ normalizeData(deltaLog, data)
+
+ // Build a new plan with a stub GpuDeltaWrite node to work around undesired transitions between
+ // columns and rows when AQE is involved. Without this node in the plan, AdaptiveSparkPlanExec
+ // could be the root node of the plan. In that case we do not have enough context to know
+ // whether the AdaptiveSparkPlanExec should be columnar or not, since the GPU overrides do not
+ // see how the parent is using the AdaptiveSparkPlanExec outputs. By using this stub node that
+ // appears to be a data writing node to AQE (it derives from V2CommandExec), the
+ // AdaptiveSparkPlanExec will be planned as a child of this new node. That provides enough
+ // context to plan the AQE sub-plan properly with respect to columnar and row transitions.
+ // We could force the AQE node to be columnar here by explicitly replacing the node, but that
+ // breaks the connection between the queryExecution and the node that will actually execute.
+ val gpuWritePlan = Dataset.ofRows(spark, RapidsDeltaWrite(normalizedQueryExecution.logical))
+ val queryExecution = gpuWritePlan.queryExecution
+
+ val partitioningColumns = getPartitioningColumns(partitionSchema, output)
+
+ val committer = getCommitter(outputPath)
+
+ // If statistics collection is enabled, create a stats tracker that will be injected during
+ // the GpuFileFormatWriter.write call below and will collect per-file stats using
+ // GpuStatisticsCollection
+ val (optionalStatsTracker, _) = getOptionalGpuStatsTrackerAndStatsCollection(output,
+ partitionSchema, data)
+
+ val constraints =
+ Constraints.getAll(metadata, spark) ++ generatedColumnConstraints ++ additionalConstraints
+
+ SQLExecution.withNewExecutionId(queryExecution, Option("deltaTransactionalWrite")) {
+ val outputSpec = FileFormatWriter.OutputSpec(
+ outputPath.toString,
+ Map.empty,
+ output)
+
+ // Remove any unnecessary row conversions added as part of Spark planning
+ val queryPhysicalPlan = queryExecution.executedPlan match {
+ case GpuColumnarToRowExec(child, _) => child
+ case p => p
+ }
+ val gpuRapidsWrite = queryPhysicalPlan match {
+ case g: GpuRapidsDeltaWriteExec => Some(g)
+ case _ => None
+ }
+
+ val empty2NullPlan = convertEmptyToNullIfNeeded(queryPhysicalPlan,
+ partitioningColumns, constraints)
+ val planWithInvariants = addInvariantChecks(empty2NullPlan, constraints)
+ val physicalPlan = convertToGpu(planWithInvariants)
+
+ val statsTrackers: ListBuffer[ColumnarWriteJobStatsTracker] = ListBuffer()
+
+ val hadoopConf = spark.sessionState.newHadoopConfWithOptions(
+ metadata.configuration ++ deltaLog.options)
+
+ if (spark.conf.get(DeltaSQLConf.DELTA_HISTORY_METRICS_ENABLED)) {
+ val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+ val basicWriteJobStatsTracker = new BasicColumnarWriteJobStatsTracker(
+ serializableHadoopConf,
+ BasicWriteJobStatsTracker.metrics)
+ registerSQLMetrics(spark, basicWriteJobStatsTracker.driverSideMetrics)
+ statsTrackers.append(basicWriteJobStatsTracker)
+ gpuRapidsWrite.foreach { grw =>
+ val tracker = new GpuWriteJobStatsTracker(serializableHadoopConf,
+ grw.basicMetrics, grw.taskMetrics)
+ statsTrackers.append(tracker)
+ }
+ }
+
+ // Retain only a minimal selection of Spark writer options to avoid any potential
+ // compatibility issues
+ val options = writeOptions match {
+ case None => Map.empty[String, String]
+ case Some(writeOptions) =>
+ writeOptions.options.filterKeys { key =>
+ key.equalsIgnoreCase(DeltaOptions.MAX_RECORDS_PER_FILE) ||
+ key.equalsIgnoreCase(DeltaOptions.COMPRESSION)
+ }.toMap
+ }
+
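+ // Delta writes its data files as Parquet, so swap in the GPU Parquet writer and fail fast
+ // for any other file format.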
+ val deltaFileFormat = deltaLog.fileFormat(metadata)
+ val gpuFileFormat = if (deltaFileFormat.getClass == classOf[DeltaParquetFileFormat]) {
+ new GpuParquetFileFormat
+ } else {
+ throw new IllegalStateException(s"file format $deltaFileFormat is not supported")
+ }
+
+ try {
+ GpuFileFormatWriter.write(
+ sparkSession = spark,
+ plan = physicalPlan,
+ fileFormat = gpuFileFormat,
+ committer = committer,
+ outputSpec = outputSpec,
+ hadoopConf = hadoopConf,
+ partitionColumns = partitioningColumns,
+ bucketSpec = None,
+ statsTrackers = optionalStatsTracker.toSeq ++ statsTrackers,
+ options = options,
+ rapidsConf.stableSort,
+ rapidsConf.concurrentWriterPartitionFlushSize)
+ } catch {
+ case s: SparkException =>
+ // Pull an InvariantViolationException up to the top level if it was the root cause.
+ val violationException = ExceptionUtils.getRootCause(s)
+ if (violationException.isInstanceOf[InvariantViolationException]) {
+ throw violationException
+ } else {
+ throw s
+ }
+ }
+ }
+
+ val resultFiles = committer.addedStatuses.map { a =>
+ a.copy(stats = optionalStatsTracker.map(
+ _.recordedStats(new Path(new URI(a.path)).getName)).getOrElse(a.stats))
+ }.filter {
+ // In some cases, we can write out an empty `inputData`. Some examples of this (though they
+ // may be fixed in the future) are the MERGE command when you delete with empty source, or
+ // empty target, or on disjoint tables. This is hard to catch before the write without
+ // collecting the DF ahead of time. Instead, we can return only the AddFiles that
+ // a) actually add rows, or
+ // b) don't have any stats so we don't know the number of rows at all
+ case a: AddFile => a.numLogicalRecords.forall(_ > 0)
+ case _ => true
+ }
+
+ resultFiles.toSeq ++ committer.changeFiles
+ }
+}
diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuUpdateCommand.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuUpdateCommand.scala
new file mode 100644
index 00000000000..dd425f044d6
--- /dev/null
+++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuUpdateCommand.scala
@@ -0,0 +1,275 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from UpdateCommand.scala
+ * in the Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.rapids.delta23x
+
+import com.nvidia.spark.rapids.delta.GpuDeltaMetricUpdateUDF
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{Column, Dataset, Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.delta.{DeltaLog, DeltaOperations, DeltaTableUtils, DeltaUDF, OptimisticTransaction}
+import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction}
+import org.apache.spark.sql.delta.commands.{DeltaCommand, UpdateCommand, UpdateMetric}
+import org.apache.spark.sql.delta.files.{TahoeBatchFileIndex, TahoeFileIndex}
+import org.apache.spark.sql.delta.rapids.GpuDeltaLog
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetrics.{createMetric, createTimingMetric}
+import org.apache.spark.sql.functions.input_file_name
+import org.apache.spark.sql.types.LongType
+
+case class GpuUpdateCommand(
+ gpuDeltaLog: GpuDeltaLog,
+ tahoeFileIndex: TahoeFileIndex,
+ target: LogicalPlan,
+ updateExpressions: Seq[Expression],
+ condition: Option[Expression])
+ extends LeafRunnableCommand with DeltaCommand {
+
+ override val output: Seq[Attribute] = {
+ Seq(AttributeReference("num_affected_rows", LongType)())
+ }
+
+ override def innerChildren: Seq[QueryPlan[_]] = Seq(target)
+
+ @transient private lazy val sc: SparkContext = SparkContext.getOrCreate()
+
+ override lazy val metrics = Map[String, SQLMetric](
+ "numAddedFiles" -> createMetric(sc, "number of files added."),
+ "numAddedBytes" -> createMetric(sc, "number of bytes added."),
+ "numRemovedFiles" -> createMetric(sc, "number of files removed."),
+ "numRemovedBytes" -> createMetric(sc, "number of bytes removed."),
+ "numUpdatedRows" -> createMetric(sc, "number of rows updated."),
+ "numCopiedRows" -> createMetric(sc, "number of rows copied."),
+ "executionTimeMs" ->
+ createTimingMetric(sc, "time taken to execute the entire operation"),
+ "scanTimeMs" ->
+ createTimingMetric(sc, "time taken to scan the files for matches"),
+ "rewriteTimeMs" ->
+ createTimingMetric(sc, "time taken to rewrite the matched files"),
+ "numAddedChangeFiles" -> createMetric(sc, "number of change data capture files generated"),
+ "changeFileBytes" -> createMetric(sc, "total size of change data capture files generated"),
+ "numTouchedRows" -> createMetric(sc, "number of rows touched (copied + updated)")
+ )
+
+ final override def run(sparkSession: SparkSession): Seq[Row] = {
+ recordDeltaOperation(tahoeFileIndex.deltaLog, "delta.dml.update") {
+ val deltaLog = tahoeFileIndex.deltaLog
+ gpuDeltaLog.withNewTransaction { txn =>
+ DeltaLog.assertRemovable(txn.snapshot)
+ if (hasBeenExecuted(txn, sparkSession)) {
+ sendDriverMetrics(sparkSession, metrics)
+ return Seq.empty
+ }
+ performUpdate(sparkSession, deltaLog, txn)
+ }
+ // Re-cache all cached plans (including this relation itself, if it's cached) that refer to
+ // this data source relation.
+ sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target)
+ }
+ Seq(Row(metrics("numUpdatedRows").value))
+ }
+
+ private def performUpdate(
+ sparkSession: SparkSession, deltaLog: DeltaLog, txn: OptimisticTransaction): Unit = {
+ import org.apache.spark.sql.delta.implicits._
+
+ var numTouchedFiles: Long = 0
+ var numRewrittenFiles: Long = 0
+ var numAddedBytes: Long = 0
+ var numRemovedBytes: Long = 0
+ var numAddedChangeFiles: Long = 0
+ var changeFileBytes: Long = 0
+ var scanTimeMs: Long = 0
+ var rewriteTimeMs: Long = 0
+
+ val startTime = System.nanoTime()
+ val numFilesTotal = txn.snapshot.numOfFiles
+
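+ // An UPDATE without a WHERE clause touches every row, so default the condition to true.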
+ val updateCondition = condition.getOrElse(Literal.TrueLiteral)
+ val (metadataPredicates, dataPredicates) =
+ DeltaTableUtils.splitMetadataAndDataPredicates(
+ updateCondition, txn.metadata.partitionColumns, sparkSession)
+ val candidateFiles = txn.filterFiles(metadataPredicates ++ dataPredicates)
+ val nameToAddFile = generateCandidateFileMap(deltaLog.dataPath, candidateFiles)
+
+ scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+
+ val filesToRewrite: Seq[AddFile] = if (candidateFiles.isEmpty) {
+ // Case 1: Do nothing if no row qualifies the partition predicates
+ // that are part of the UPDATE condition
+ Nil
+ } else if (dataPredicates.isEmpty) {
+ // Case 2: Update all the rows from the files that are in the specified partitions
+ // when the data filter is empty
+ candidateFiles
+ } else {
+ // Case 3: Find all the affected files using the user-specified condition
+ val fileIndex = new TahoeBatchFileIndex(
+ sparkSession, "update", candidateFiles, deltaLog, tahoeFileIndex.path, txn.snapshot)
+ // Keep everything from the resolved target except a new TahoeFileIndex
+ // that only involves the affected files instead of all files.
+ val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex)
+ val data = Dataset.ofRows(sparkSession, newTarget)
+ val updatedRowCount = metrics("numUpdatedRows")
+ val updatedRowUdf = DeltaUDF.boolean {
+ new GpuDeltaMetricUpdateUDF(updatedRowCount)
+ }.asNondeterministic()
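+ // Find the files that contain rows matching the update condition by scanning the candidate
+ // files, counting matched rows through the metric UDF, and collecting the distinct input
+ // file names.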
+ val pathsToRewrite =
+ withStatusCode("DELTA", UpdateCommand.FINDING_TOUCHED_FILES_MSG) {
+ data.filter(new Column(updateCondition))
+ .select(input_file_name())
+ .filter(updatedRowUdf())
+ .distinct()
+ .as[String]
+ .collect()
+ }
+
+ scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+
+ pathsToRewrite.map(getTouchedFile(deltaLog.dataPath, _, nameToAddFile)).toSeq
+ }
+
+ numTouchedFiles = filesToRewrite.length
+
+ val newActions = if (filesToRewrite.isEmpty) {
+ // Do nothing if no row qualifies the UPDATE condition
+ Nil
+ } else {
+ // Generate the new files containing the updated values
+ withStatusCode("DELTA", UpdateCommand.rewritingFilesMsg(filesToRewrite.size)) {
+ rewriteFiles(sparkSession, txn, tahoeFileIndex.path,
+ filesToRewrite.map(_.path), nameToAddFile, updateCondition)
+ }
+ }
+
+ rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs
+
+ val (changeActions, addActions) = newActions.partition(_.isInstanceOf[AddCDCFile])
+ numRewrittenFiles = addActions.size
+ numAddedBytes = addActions.map(_.getFileSize).sum
+ numAddedChangeFiles = changeActions.size
+ changeFileBytes = changeActions.collect { case f: AddCDCFile => f.size }.sum
+
+ val totalActions = if (filesToRewrite.isEmpty) {
+ // Do nothing if no row qualifies the UPDATE condition
+ Nil
+ } else {
+ // Delete the old files and return those delete actions along with the new AddFile actions for
+ // files containing the updated values
+ val operationTimestamp = System.currentTimeMillis()
+ val deleteActions = filesToRewrite.map(_.removeWithTimestamp(operationTimestamp))
+
+ numRemovedBytes = filesToRewrite.map(_.getFileSize).sum
+ deleteActions ++ newActions
+ }
+
+ metrics("numAddedFiles").set(numRewrittenFiles)
+ metrics("numAddedBytes").set(numAddedBytes)
+ metrics("numAddedChangeFiles").set(numAddedChangeFiles)
+ metrics("changeFileBytes").set(changeFileBytes)
+ metrics("numRemovedFiles").set(numTouchedFiles)
+ metrics("numRemovedBytes").set(numRemovedBytes)
+ metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000)
+ metrics("scanTimeMs").set(scanTimeMs)
+ metrics("rewriteTimeMs").set(rewriteTimeMs)
+ // In the case where the numUpdatedRows is not captured, we can siphon out the metrics from
+ // the BasicWriteStatsTracker. This is for case 2 where the update condition contains only
+ // metadata predicates and so the entire partition is re-written.
+ val outputRows = txn.getMetric("numOutputRows").map(_.value).getOrElse(-1L)
+ if (metrics("numUpdatedRows").value == 0 && outputRows != 0 &&
+ metrics("numCopiedRows").value == 0) {
+ // We know that numTouchedRows = numCopiedRows + numUpdatedRows.
+ // Since an entire partition was re-written, no rows were copied.
+ // So numTouchedRows == numUpdatedRows
+ metrics("numUpdatedRows").set(metrics("numTouchedRows").value)
+ } else {
+ // This is for case 3 where the update condition contains both metadata and data predicates
+ // so relevant files will have some rows updated and some rows copied. We don't need to
+ // consider case 1 here, where no files match the update condition, as we know that
+ // `totalActions` is empty.
+ metrics("numCopiedRows").set(
+ metrics("numTouchedRows").value - metrics("numUpdatedRows").value)
+ }
+ txn.registerSQLMetrics(sparkSession, metrics)
+ val finalActions = createSetTransaction(sparkSession, deltaLog).toSeq ++ totalActions
+ txn.commitIfNeeded(finalActions, DeltaOperations.Update(condition.map(_.toString)))
+ sendDriverMetrics(sparkSession, metrics)
+
+ recordDeltaEvent(
+ deltaLog,
+ "delta.dml.update.stats",
+ data = UpdateMetric(
+ condition = condition.map(_.sql).getOrElse("true"),
+ numFilesTotal,
+ numTouchedFiles,
+ numRewrittenFiles,
+ numAddedChangeFiles,
+ changeFileBytes,
+ scanTimeMs,
+ rewriteTimeMs)
+ )
+ }
+
+ /**
+ * Scan all the affected files and write out the updated files.
+ *
+ * When CDF is enabled, includes the generation of CDC preimage and postimage columns for
+ * changed rows.
+ *
+ * @return the list of [[AddFile]]s and [[AddCDCFile]]s that have been written.
+ */
+ private def rewriteFiles(
+ spark: SparkSession,
+ txn: OptimisticTransaction,
+ rootPath: Path,
+ inputLeafFiles: Seq[String],
+ nameToAddFileMap: Map[String, AddFile],
+ condition: Expression): Seq[FileAction] = {
+ // nameToAddFileMap contains the mapping from relative file paths to their AddFile actions
+ val baseRelation = buildBaseRelation(
+ spark, txn, "update", rootPath, inputLeafFiles, nameToAddFileMap)
+ val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location)
+ val targetDf = Dataset.ofRows(spark, newTarget)
+
+ // Total number of rows that we have seen, i.e. rows that are either copied or updated (sum of both).
+ // This will be used later, along with numUpdatedRows, to determine numCopiedRows.
+ val numTouchedRows = metrics("numTouchedRows")
+ val numTouchedRowsUdf = DeltaUDF.boolean {
+ new GpuDeltaMetricUpdateUDF(numTouchedRows)
+ }.asNondeterministic()
+
+ val updatedDataFrame = UpdateCommand.withUpdatedColumns(
+ target,
+ updateExpressions,
+ condition,
+ targetDf
+ .filter(numTouchedRowsUdf())
+ .withColumn(UpdateCommand.CONDITION_COLUMN_NAME, new Column(condition)),
+ UpdateCommand.shouldOutputCdc(txn))
+
+ txn.writeFiles(updatedDataFrame)
+ }
+}
diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py
index f4696f3af71..4e6303c0b23 100644
--- a/integration_tests/src/main/python/window_function_test.py
+++ b/integration_tests/src/main/python/window_function_test.py
@@ -501,7 +501,8 @@ def test_window_running_no_part(b_gen, batch_size):
if isinstance(b_gen.data_type, NumericType) and not isinstance(b_gen, FloatGen) and not isinstance(b_gen, DoubleGen):
query_parts.append('sum(b) over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as sum_col')
- if spark_version() > "3.1.1":
+ # The option to IGNORE NULLS in NTH_VALUE is not available prior to Spark 3.2.1.
+ if spark_version() >= "3.2.1":
query_parts.append('NTH_VALUE(b, 1) IGNORE NULLS OVER '
'(ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS nth_1_ignore_nulls')
@@ -620,8 +621,8 @@ def test_window_running(b_gen, c_gen, batch_size):
if isinstance(c_gen.data_type, NumericType) and (not isinstance(c_gen, FloatGen)) and (not isinstance(c_gen, DoubleGen)) and (not isinstance(c_gen, DecimalGen)):
query_parts.append('sum(c) over (partition by b order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as sum_col')
- # The option to IGNORE NULLS in NTH_VALUE is not available prior to Spark 3.1.2.
- if spark_version() > "3.1.1":
+ # The option to IGNORE NULLS in NTH_VALUE is not available prior to Spark 3.2.1.
+ if spark_version() >= "3.2.1":
query_parts.append('NTH_VALUE(c, 1) IGNORE NULLS OVER '
'(PARTITION BY b ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS nth_1_ignore_nulls')
diff --git a/pom.xml b/pom.xml
index a4a72f7ef31..13d16667476 100644
--- a/pom.xml
+++ b/pom.xml
@@ -306,6 +306,7 @@
delta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
@@ -326,6 +327,7 @@
delta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
@@ -346,6 +348,7 @@
delta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
@@ -366,6 +369,7 @@
delta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
@@ -428,6 +432,7 @@
shim-deps/clouderadelta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
@@ -450,6 +455,7 @@
shim-deps/clouderadelta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
diff --git a/scala2.13/aggregator/pom.xml b/scala2.13/aggregator/pom.xml
index 1000c73c219..4b6aca7d716 100644
--- a/scala2.13/aggregator/pom.xml
+++ b/scala2.13/aggregator/pom.xml
@@ -385,6 +385,12 @@
${project.version}${spark.version.classifier}
+
+ com.nvidia
+ rapids-4-spark-delta-23x_${scala.binary.version}
+ ${project.version}
+ ${spark.version.classifier}
+
@@ -408,6 +414,12 @@
${project.version}${spark.version.classifier}
+
+ com.nvidia
+ rapids-4-spark-delta-23x_${scala.binary.version}
+ ${project.version}
+ ${spark.version.classifier}
+
@@ -431,6 +443,12 @@
${project.version}${spark.version.classifier}
+
+ com.nvidia
+ rapids-4-spark-delta-23x_${scala.binary.version}
+ ${project.version}
+ ${spark.version.classifier}
+
@@ -471,6 +489,12 @@
${project.version}${spark.version.classifier}
+
+ com.nvidia
+ rapids-4-spark-delta-23x_${scala.binary.version}
+ ${project.version}
+ ${spark.version.classifier}
+
@@ -494,6 +518,12 @@
${project.version}${spark.version.classifier}
+
+ com.nvidia
+ rapids-4-spark-delta-23x_${scala.binary.version}
+ ${project.version}
+ ${spark.version.classifier}
+
@@ -534,6 +564,12 @@
${project.version}${spark.version.classifier}
+
+ com.nvidia
+ rapids-4-spark-delta-23x_${scala.binary.version}
+ ${project.version}
+ ${spark.version.classifier}
+
diff --git a/scala2.13/delta-lake/delta-23x/pom.xml b/scala2.13/delta-lake/delta-23x/pom.xml
new file mode 100644
index 00000000000..6193d34ab44
--- /dev/null
+++ b/scala2.13/delta-lake/delta-23x/pom.xml
@@ -0,0 +1,98 @@
+
+
+
+ 4.0.0
+
+
+ com.nvidia
+ rapids-4-spark-parent_2.13
+ 23.12.0-SNAPSHOT
+ ../../pom.xml
+
+
+ rapids-4-spark-delta-23x_2.13
+ RAPIDS Accelerator for Apache Spark Delta Lake 2.3.x Support
+ Delta Lake 2.3.x support for the RAPIDS Accelerator for Apache Spark
+ 23.12.0-SNAPSHOT
+
+
+ ../delta-lake/delta-23x
+ false
+ **/*
+ package
+
+
+
+
+ com.nvidia
+ rapids-4-spark-sql_${scala.binary.version}
+ ${project.version}
+ ${spark.version.classifier}
+ provided
+
+
+ io.delta
+ delta-core_${scala.binary.version}
+ 2.3.0
+ provided
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+
+
+
+
+
+
+ org.codehaus.mojo
+ build-helper-maven-plugin
+
+
+ add-common-sources
+ generate-sources
+
+ add-source
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+
+
+ org.apache.rat
+ apache-rat-plugin
+
+
+
+
diff --git a/scala2.13/pom.xml b/scala2.13/pom.xml
index 69422e58a2c..31b70ec2cd6 100644
--- a/scala2.13/pom.xml
+++ b/scala2.13/pom.xml
@@ -306,6 +306,7 @@
delta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
@@ -326,6 +327,7 @@
delta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
@@ -346,6 +348,7 @@
delta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
@@ -366,6 +369,7 @@
delta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
@@ -428,6 +432,7 @@
shim-deps/clouderadelta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
@@ -450,6 +455,7 @@
shim-deps/clouderadelta-lake/delta-21xdelta-lake/delta-22x
+ delta-lake/delta-23x
diff --git a/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/RapidsStack.scala b/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/RapidsStack.scala
new file mode 100644
index 00000000000..e89485e752b
--- /dev/null
+++ b/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/RapidsStack.scala
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import scala.collection.mutable.ArrayStack
+
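+/**
+ * Thin wrapper around the mutable ArrayStack so shared code can use one RapidsStack API on
+ * both Scala builds; the Scala 2.13 version provides the same API on top of mutable.Stack.
+ */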
+class RapidsStack[T] extends Proxy {
+ private val stack = new ArrayStack[T]()
+
+ override def self = stack
+
+ def push(elem1: T): Unit = {
+ self.push(elem1)
+ }
+
+ def pop(): T = {
+ self.pop()
+ }
+
+ def isEmpty: Boolean = {
+ self.isEmpty
+ }
+
+ def nonEmpty: Boolean = {
+ self.nonEmpty
+ }
+
+ def size: Int = {
+ self.size
+ }
+
+ def toSeq: Seq[T] = {
+ self.toSeq
+ }
+
+ def clear(): Unit = {
+ self.clear()
+ }
+}
diff --git a/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/ScalaStack.scala b/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/RapidsStack.scala
similarity index 60%
rename from sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/ScalaStack.scala
rename to sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/RapidsStack.scala
index ad23ab8ffb4..9bdf35526a2 100644
--- a/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/ScalaStack.scala
+++ b/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/RapidsStack.scala
@@ -18,4 +18,36 @@ package com.nvidia.spark.rapids
import scala.collection.mutable.Stack
-class ScalaStack[T] extends Stack[T]
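+/**
+ * Thin wrapper around the mutable Stack that mirrors the Scala 2.12 RapidsStack, so shared
+ * code can use one stack API on both Scala builds.
+ */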
+class RapidsStack[T] extends Proxy {
+ private val stack = new Stack[T]()
+ override def self = stack
+
+ def push(elem1: T): RapidsStack[T] = {
+ self.push(elem1)
+ this
+ }
+
+ def pop(): T = {
+ self.pop()
+ }
+
+ def isEmpty: Boolean = {
+ self.isEmpty
+ }
+
+ def nonEmpty: Boolean = {
+ self.nonEmpty
+ }
+
+ def size(): Int = {
+ self.size
+ }
+
+ def toSeq(): Seq[T] = {
+ self.toSeq
+ }
+
+ def clear(): Unit = {
+ self.clear()
+ }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala
index 76f6c1ed99d..77c3bbae7b3 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala
@@ -198,7 +198,7 @@ abstract class SplittableJoinIterator(
// If the join explodes this holds batches from the stream side split into smaller pieces.
private val pendingSplits = scala.collection.mutable.Queue[LazySpillableColumnarBatch]()
- protected def computeNumJoinRows(cb: ColumnarBatch): Long
+ protected def computeNumJoinRows(cb: LazySpillableColumnarBatch): Long
/**
* Create a join gatherer.
@@ -225,7 +225,7 @@ abstract class SplittableJoinIterator(
}
opTime.ns {
withResource(scb) { scb =>
- val numJoinRows = computeNumJoinRows(scb.getBatch)
+ val numJoinRows = computeNumJoinRows(scb)
// We want the gather maps size to be around the target size. There are two gather maps
// that are made up of ints, so compute how many rows on the stream side will produce the
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala
index aa7f4f2e48f..368c99a548a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala
@@ -450,8 +450,8 @@ case class GpuOutOfCoreSortIterator(
while (!pending.isEmpty && sortedSize < targetSize) {
// Keep going until we have enough data to return
var bytesLeftToFetch = targetSize
- val pendingSort = new ScalaStack[SpillableColumnarBatch]()
- closeOnExcept(pendingSort) { _ =>
+ val pendingSort = new RapidsStack[SpillableColumnarBatch]()
+ closeOnExcept(pendingSort.toSeq) { _ =>
while (!pending.isEmpty &&
(bytesLeftToFetch - pending.peek().buffer.sizeInBytes >= 0 || pendingSort.isEmpty)) {
val buffer = pending.poll().buffer
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala
index 3a55ad94bfd..80ec540ff84 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala
@@ -19,6 +19,7 @@ package com.nvidia.spark.rapids
import ai.rapids.cudf.{ColumnVector, ColumnView, DeviceMemoryBuffer, DType, GatherMap, NvtxColor, NvtxRange, OrderByArg, OutOfBoundsPolicy, Scalar, Table}
import com.nvidia.spark.Retryable
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
+import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
import org.apache.spark.TaskContext
import org.apache.spark.sql.types._
@@ -419,7 +420,7 @@ abstract class BaseCrossJoinGatherMap(leftCount: Int, rightCount: Int)
extends LazySpillableGatherMap {
override val getRowCount: Long = leftCount.toLong * rightCount.toLong
- override def toColumnView(startRow: Long, numRows: Int): ColumnView = {
+ override def toColumnView(startRow: Long, numRows: Int): ColumnView = withRetryNoSplit {
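+ // Retry the sequence generation and gather-map computation on a retryable GPU OOM without
+ // splitting the requested rows.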
withResource(GpuScalar.from(startRow, LongType)) { startScalar =>
withResource(ai.rapids.cudf.ColumnVector.sequence(startScalar, numRows)) { rowNum =>
compute(rowNum)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala
index 7b463da3bf5..ecc9e9cb462 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala
@@ -242,9 +242,9 @@ class GpuSorter(
* @return the sorted data.
*/
final def mergeSortAndCloseWithRetry(
- spillableBatches: ScalaStack[SpillableColumnarBatch],
+ spillableBatches: RapidsStack[SpillableColumnarBatch],
sortTime: GpuMetric): SpillableColumnarBatch = {
- closeOnExcept(spillableBatches) { _ =>
+ closeOnExcept(spillableBatches.toSeq) { _ =>
assert(spillableBatches.nonEmpty)
}
withResource(new NvtxWithMetrics("merge sort", NvtxColor.DARK_GREEN, sortTime)) { _ =>
@@ -277,9 +277,9 @@ class GpuSorter(
}
}
} else {
- closeOnExcept(spillableBatches) { _ =>
- val batchesToMerge = new ScalaStack[SpillableColumnarBatch]()
- closeOnExcept(batchesToMerge) { _ =>
+ closeOnExcept(spillableBatches.toSeq) { _ =>
+ val batchesToMerge = new RapidsStack[SpillableColumnarBatch]()
+ closeOnExcept(batchesToMerge.toSeq) { _ =>
while (spillableBatches.nonEmpty || batchesToMerge.size > 1) {
// pop a spillable batch if there is one, and add it to `batchesToMerge`.
if (spillableBatches.nonEmpty) {
@@ -299,7 +299,7 @@ class GpuSorter(
// we no longer care about the old batches, we closed them
closeOnExcept(merged) { _ =>
- batchesToMerge.safeClose()
+ batchesToMerge.toSeq.safeClose()
batchesToMerge.clear()
}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala
index 01d720d5186..e20c84b2b88 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala
@@ -178,20 +178,26 @@ class ConditionalNestedLoopJoinIterator(
}
}
- override def computeNumJoinRows(cb: ColumnarBatch): Long = {
- withResource(GpuColumnVector.from(builtBatch.getBatch)) { builtTable =>
- withResource(GpuColumnVector.from(cb)) { streamTable =>
- val (left, right) = buildSide match {
- case GpuBuildLeft => (builtTable, streamTable)
- case GpuBuildRight => (streamTable, builtTable)
- }
- joinType match {
- case _: InnerLike =>left.conditionalInnerJoinRowCount(right, condition)
- case LeftOuter => left.conditionalLeftJoinRowCount(right, condition)
- case RightOuter => right.conditionalLeftJoinRowCount(left, condition)
- case LeftSemi => left.conditionalLeftSemiJoinRowCount(right, condition)
- case LeftAnti => left.conditionalLeftAntiJoinRowCount(right, condition)
- case _ => throw new IllegalStateException(s"Unsupported join type $joinType")
+ override def computeNumJoinRows(scb: LazySpillableColumnarBatch): Long = {
+ scb.checkpoint()
+ builtBatch.checkpoint()
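+ // Checkpoint both batches so withRestoreOnRetry can bring them back to this state if the
+ // row-count computation hits a retryable GPU OOM.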
+ withRetryNoSplit {
+ withRestoreOnRetry(Seq(builtBatch, scb)) {
+ withResource(GpuColumnVector.from(builtBatch.getBatch)) { builtTable =>
+ withResource(GpuColumnVector.from(scb.getBatch)) { streamTable =>
+ val (left, right) = buildSide match {
+ case GpuBuildLeft => (builtTable, streamTable)
+ case GpuBuildRight => (streamTable, builtTable)
+ }
+ joinType match {
+ case _: InnerLike => left.conditionalInnerJoinRowCount(right, condition)
+ case LeftOuter => left.conditionalLeftJoinRowCount(right, condition)
+ case RightOuter => right.conditionalLeftJoinRowCount(left, condition)
+ case LeftSemi => left.conditionalLeftSemiJoinRowCount(right, condition)
+ case LeftAnti => left.conditionalLeftAntiJoinRowCount(right, condition)
+ case _ => throw new IllegalStateException(s"Unsupported join type $joinType")
+ }
+ }
}
}
}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
index 234936769b3..cbaa1cbe47c 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
@@ -292,14 +292,14 @@ abstract class BaseHashJoinIterator(
1.0
}
- override def computeNumJoinRows(cb: ColumnarBatch): Long = {
+ override def computeNumJoinRows(cb: LazySpillableColumnarBatch): Long = {
// TODO: Replace this estimate with exact join row counts using the corresponding cudf APIs
// being added in https://github.com/rapidsai/cudf/issues/9053.
joinType match {
// Full Outer join is implemented via LeftOuter/RightOuter, so use same estimate.
case _: InnerLike | LeftOuter | RightOuter | FullOuter =>
- Math.ceil(cb.numRows() * streamMagnificationFactor).toLong
- case _ => cb.numRows()
+ Math.ceil(cb.numRows * streamMagnificationFactor).toLong
+ case _ => cb.numRows
}
}