diff --git a/delta-lake/README.md b/delta-lake/README.md
index ff9d2553e31..f48c20fb7c3 100644
--- a/delta-lake/README.md
+++ b/delta-lake/README.md
@@ -19,6 +19,7 @@ and directory contains the corresponding support code.
| Databricks 10.4 | Databricks 10.4 | `delta-spark321db` |
| Databricks 11.3 | Databricks 11.3 | `delta-spark330db` |
| Databricks 12.2 | Databricks 12.2 | `delta-spark332db` |
+| Databricks 13.3 | Databricks 13.3 | `delta-spark341db` |
Delta Lake is not supported on all Spark versions, and for Spark versions where it is not
supported the `delta-stub` project is used.
diff --git a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala
index 083ebd09979..2ea31f0ae06 100644
--- a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala
+++ b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala
@@ -25,7 +25,7 @@ import com.databricks.sql.managedcatalog.UnityCatalogV2Proxy
import com.databricks.sql.transaction.tahoe.{DeltaLog, DeltaOptions, DeltaParquetFileFormat}
import com.databricks.sql.transaction.tahoe.catalog.{DeltaCatalog, DeltaTableV2}
import com.databricks.sql.transaction.tahoe.commands.{DeleteCommand, DeleteCommandEdge, MergeIntoCommand, MergeIntoCommandEdge, UpdateCommand, UpdateCommandEdge, WriteIntoDelta}
-import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaCatalog, GpuDeltaLog, GpuWriteIntoDelta}
+import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaLog, GpuWriteIntoDelta}
import com.databricks.sql.transaction.tahoe.sources.{DeltaDataSource, DeltaSourceUtils}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.delta.shims.DeltaLogShim
@@ -38,15 +38,15 @@ import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.{FileFormat, LogicalRelation, SaveIntoDataSourceCommand}
import org.apache.spark.sql.execution.datasources.v2.{AppendDataExecV1, AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec, OverwriteByExpressionExecV1}
-import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.ExternalSource
import org.apache.spark.sql.sources.{CreatableRelationProvider, InsertableRelation}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* Common implementation of the DeltaProvider interface for all Databricks versions.
*/
-object DatabricksDeltaProvider extends DeltaProviderImplBase {
+trait DatabricksDeltaProviderBase extends DeltaProviderImplBase {
override def getCreatableRelationRules: Map[Class[_ <: CreatableRelationProvider],
CreatableRelationProviderRule[_ <: CreatableRelationProvider]] = {
Seq(
@@ -116,6 +116,15 @@ object DatabricksDeltaProvider extends DeltaProviderImplBase {
catalogClass == classOf[DeltaCatalog] || catalogClass == classOf[UnityCatalogV2Proxy]
}
+ private def getWriteOptions(options: Any): Map[String, String] = {
+ // For Databricks 13.3, AtomicCreateTableAsSelectExec.writeOptions is a plain Map[String, String],
+ // while on all other Databricks versions it is a CaseInsensitiveStringMap
+ options match {
+ case c: CaseInsensitiveStringMap => c.asCaseSensitiveMap().asScala.toMap
+ case _ => options.asInstanceOf[Map[String, String]]
+ }
+ }
+
override def tagForGpu(
cpuExec: AtomicCreateTableAsSelectExec,
meta: AtomicCreateTableAsSelectExecMeta): Unit = {
@@ -131,22 +140,7 @@ object DatabricksDeltaProvider extends DeltaProviderImplBase {
meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider")
}
RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None,
- cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session)
- }
-
- override def convertToGpu(
- cpuExec: AtomicCreateTableAsSelectExec,
- meta: AtomicCreateTableAsSelectExecMeta): GpuExec = {
- GpuAtomicCreateTableAsSelectExec(
- cpuExec.output,
- new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
- cpuExec.ident,
- cpuExec.partitioning,
- cpuExec.plan,
- meta.childPlans.head.convertIfNeeded(),
- cpuExec.tableSpec,
- cpuExec.writeOptions,
- cpuExec.ifNotExists)
+ getWriteOptions(cpuExec.writeOptions), cpuExec.session)
}
override def tagForGpu(
@@ -164,23 +158,7 @@ object DatabricksDeltaProvider extends DeltaProviderImplBase {
meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider")
}
RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None,
- cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session)
- }
-
- override def convertToGpu(
- cpuExec: AtomicReplaceTableAsSelectExec,
- meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
- GpuAtomicReplaceTableAsSelectExec(
- cpuExec.output,
- new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
- cpuExec.ident,
- cpuExec.partitioning,
- cpuExec.plan,
- meta.childPlans.head.convertIfNeeded(),
- cpuExec.tableSpec,
- cpuExec.writeOptions,
- cpuExec.orCreate,
- cpuExec.invalidateCache)
+ getWriteOptions(cpuExec.writeOptions), cpuExec.session)
}
private case class DeltaWriteV1Config(
@@ -360,13 +338,4 @@ class DeltaCreatableRelationProviderMeta(
}
override def convertToGpu(): GpuCreatableRelationProvider = new GpuDeltaDataSource(conf)
-}
-
-/**
- * Implements the Delta Probe interface for probing the Delta Lake provider on Databricks.
- * @note This is instantiated via reflection from ShimLoader.
- */
-class DeltaProbeImpl extends DeltaProbe {
- // Delta Lake is built-in for Databricks instances, so no probing is necessary.
- override def getDeltaProvider: DeltaProvider = DatabricksDeltaProvider
-}
+}
\ No newline at end of file
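The `getWriteOptions` helper added above normalizes the two shapes that `writeOptions` can take across Databricks shims. Below is a minimal, self-contained sketch of the same normalization; the `normalizeWriteOptions` name and the `mergeSchema` option are illustrative only:

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.util.CaseInsensitiveStringMap

// On Databricks 13.3 the exec exposes writeOptions as a plain Map[String, String];
// older Databricks shims expose a CaseInsensitiveStringMap, so both are folded into a Scala Map.
def normalizeWriteOptions(options: Any): Map[String, String] = options match {
  case c: CaseInsensitiveStringMap => c.asCaseSensitiveMap().asScala.toMap
  case m => m.asInstanceOf[Map[String, String]]
}

// Both shapes produce the same result:
val fromStringMap = normalizeWriteOptions(new CaseInsensitiveStringMap(Map("mergeSchema" -> "true").asJava))
val fromScalaMap = normalizeWriteOptions(Map("mergeSchema" -> "true"))
assert(fromStringMap == fromScalaMap)
```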
diff --git a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala
index 467f16986aa..3968c48bff0 100644
--- a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala
+++ b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@ package com.nvidia.spark.rapids.delta.shims
import com.databricks.sql.expressions.JoinedProjection
import com.databricks.sql.transaction.tahoe.DeltaColumnMapping
-import com.databricks.sql.transaction.tahoe.stats.UsesMetadataFields
import com.databricks.sql.transaction.tahoe.util.JsonUtils
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
@@ -47,6 +46,4 @@ object ShimJoinedProjection {
object ShimJsonUtils {
def fromJson[T: Manifest](json: String): T = JsonUtils.fromJson[T](json)
-}
-
-trait ShimUsesMetadataFields extends UsesMetadataFields
+}
\ No newline at end of file
diff --git a/delta-lake/common/src/main/scala/org/apache/spark/sql/rapids/delta/DeltaShufflePartitionsUtil.scala b/delta-lake/common/src/main/scala/org/apache/spark/sql/rapids/delta/DeltaShufflePartitionsUtil.scala
index c573819f31e..395757f0a30 100644
--- a/delta-lake/common/src/main/scala/org/apache/spark/sql/rapids/delta/DeltaShufflePartitionsUtil.scala
+++ b/delta-lake/common/src/main/scala/org/apache/spark/sql/rapids/delta/DeltaShufflePartitionsUtil.scala
@@ -179,9 +179,8 @@ object DeltaShufflePartitionsUtil {
c.child
case _ => p
}
- case ShuffleExchangeExec(_, child, shuffleOrigin)
- if !shuffleOrigin.equals(ENSURE_REQUIREMENTS) =>
- child
+ case s: ShuffleExchangeExec if !s.shuffleOrigin.equals(ENSURE_REQUIREMENTS) =>
+ s.child
case CoalesceExec(_, child) =>
child
case _ =>
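The switch above from a constructor pattern to a typed pattern on `ShuffleExchangeExec` is the kind of change that keeps one shared source file compiling across shims whose case classes differ in arity, presumably because the `ShuffleExchangeExec` signature changes on Databricks 13.3. A short self-contained sketch, using a hypothetical `Exchange` class, illustrates the difference:

```scala
// Hypothetical stand-in: imagine a newer Spark version adds another constructor parameter.
case class Exchange(partitioning: String, child: String, shuffleOrigin: String)

def unwrapNonRequirementShuffle(plan: Any): Any = plan match {
  // A constructor pattern such as Exchange(_, child, origin) is tied to the exact arity and
  // stops compiling when the case class grows; matching on the type and reading fields
  // stays source-compatible across versions.
  case e: Exchange if e.shuffleOrigin != "ENSURE_REQUIREMENTS" => e.child
  case other => other
}

assert(unwrapNonRequirementShuffle(Exchange("hash", "scan", "REPARTITION_BY_NUM")) == "scan")
assert(unwrapNonRequirementShuffle("not an exchange") == "not an exchange")
```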
diff --git a/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala
new file mode 100644
index 00000000000..2194522ab82
--- /dev/null
+++ b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.delta
+
+/**
+ * Implements the Delta Probe interface for probing the Delta Lake provider on Databricks.
+ * @note This is instantiated via reflection from ShimLoader.
+ */
+class DeltaProbeImpl extends DeltaProbe {
+ // Delta Lake is built-in for Databricks instances, so no probing is necessary.
+ override def getDeltaProvider: DeltaProvider = DeltaSpark321DBProvider
+}
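As the scaladoc notes, `DeltaProbeImpl` is instantiated via reflection from `ShimLoader`. The following is a minimal sketch, assuming only that `DeltaProbe` and `DeltaProvider` are on the classpath, of what such a reflective lookup can look like; the real `ShimLoader` additionally resolves the class through the shim-specific classloader, which is omitted here:

```scala
import com.nvidia.spark.rapids.delta.{DeltaProbe, DeltaProvider}

// Load the shim-provided probe by class name and ask it for the Delta provider.
def loadDeltaProvider(
    probeClassName: String = "com.nvidia.spark.rapids.delta.DeltaProbeImpl"): DeltaProvider = {
  val probe = Class.forName(probeClassName)
    .getDeclaredConstructor()
    .newInstance()
    .asInstanceOf[DeltaProbe]
  probe.getDeltaProvider
}
```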
diff --git a/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark321DBProvider.scala b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark321DBProvider.scala
new file mode 100644
index 00000000000..44e5721bafc
--- /dev/null
+++ b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark321DBProvider.scala
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta
+
+import com.databricks.sql.transaction.tahoe.rapids.GpuDeltaCatalog
+import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec}
+
+import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
+import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec}
+
+object DeltaSpark321DBProvider extends DatabricksDeltaProviderBase {
+
+ override def convertToGpu(
+ cpuExec: AtomicCreateTableAsSelectExec,
+ meta: AtomicCreateTableAsSelectExecMeta): GpuExec = {
+ GpuAtomicCreateTableAsSelectExec(
+ cpuExec.output,
+ new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
+ cpuExec.ident,
+ cpuExec.partitioning,
+ cpuExec.plan,
+ meta.childPlans.head.convertIfNeeded(),
+ cpuExec.tableSpec,
+ cpuExec.writeOptions,
+ cpuExec.ifNotExists)
+ }
+
+ override def convertToGpu(
+ cpuExec: AtomicReplaceTableAsSelectExec,
+ meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
+ GpuAtomicReplaceTableAsSelectExec(
+ cpuExec.output,
+ new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
+ cpuExec.ident,
+ cpuExec.partitioning,
+ cpuExec.plan,
+ meta.childPlans.head.convertIfNeeded(),
+ cpuExec.tableSpec,
+ cpuExec.writeOptions,
+ cpuExec.orCreate,
+ cpuExec.invalidateCache)
+ }
+}
diff --git a/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala
new file mode 100644
index 00000000000..f722837778a
--- /dev/null
+++ b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.delta.shims
+
+import com.databricks.sql.transaction.tahoe.stats.UsesMetadataFields
+
+trait ShimUsesMetadataFields extends UsesMetadataFields
diff --git a/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala
new file mode 100644
index 00000000000..01a386c3769
--- /dev/null
+++ b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.delta
+
+/**
+ * Implements the Delta Probe interface for probing the Delta Lake provider on Databricks.
+ * @note This is instantiated via reflection from ShimLoader.
+ */
+class DeltaProbeImpl extends DeltaProbe {
+ // Delta Lake is built-in for Databricks instances, so no probing is necessary.
+ override def getDeltaProvider: DeltaProvider = DeltaSpark330DBProvider
+}
diff --git a/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark330DBProvider.scala b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark330DBProvider.scala
new file mode 100644
index 00000000000..c7a5cf1db27
--- /dev/null
+++ b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark330DBProvider.scala
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta
+
+import com.databricks.sql.transaction.tahoe.rapids.GpuDeltaCatalog
+import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec}
+
+import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
+import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec}
+
+object DeltaSpark330DBProvider extends DatabricksDeltaProviderBase {
+
+ override def convertToGpu(
+ cpuExec: AtomicCreateTableAsSelectExec,
+ meta: AtomicCreateTableAsSelectExecMeta): GpuExec = {
+ GpuAtomicCreateTableAsSelectExec(
+ cpuExec.output,
+ new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
+ cpuExec.ident,
+ cpuExec.partitioning,
+ cpuExec.plan,
+ meta.childPlans.head.convertIfNeeded(),
+ cpuExec.tableSpec,
+ cpuExec.writeOptions,
+ cpuExec.ifNotExists)
+ }
+
+ override def convertToGpu(
+ cpuExec: AtomicReplaceTableAsSelectExec,
+ meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
+ GpuAtomicReplaceTableAsSelectExec(
+ cpuExec.output,
+ new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
+ cpuExec.ident,
+ cpuExec.partitioning,
+ cpuExec.plan,
+ meta.childPlans.head.convertIfNeeded(),
+ cpuExec.tableSpec,
+ cpuExec.writeOptions,
+ cpuExec.orCreate,
+ cpuExec.invalidateCache)
+ }
+}
diff --git a/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala
new file mode 100644
index 00000000000..f722837778a
--- /dev/null
+++ b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.delta.shims
+
+import com.databricks.sql.transaction.tahoe.stats.UsesMetadataFields
+
+trait ShimUsesMetadataFields extends UsesMetadataFields
diff --git a/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala
new file mode 100644
index 00000000000..e18cbcff2d3
--- /dev/null
+++ b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.delta
+
+/**
+ * Implements the Delta Probe interface for probing the Delta Lake provider on Databricks.
+ * @note This is instantiated via reflection from ShimLoader.
+ */
+class DeltaProbeImpl extends DeltaProbe {
+ // Delta Lake is built-in for Databricks instances, so no probing is necessary.
+ override def getDeltaProvider: DeltaProvider = DeltaSpark332DBProvider
+}
diff --git a/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark332DBProvider.scala b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark332DBProvider.scala
new file mode 100644
index 00000000000..283ad44fb30
--- /dev/null
+++ b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark332DBProvider.scala
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta
+
+import com.databricks.sql.transaction.tahoe.rapids.GpuDeltaCatalog
+import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec}
+
+import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
+import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec}
+
+object DeltaSpark332DBProvider extends DatabricksDeltaProviderBase {
+
+ override def convertToGpu(
+ cpuExec: AtomicCreateTableAsSelectExec,
+ meta: AtomicCreateTableAsSelectExecMeta): GpuExec = {
+ GpuAtomicCreateTableAsSelectExec(
+ cpuExec.output,
+ new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
+ cpuExec.ident,
+ cpuExec.partitioning,
+ cpuExec.plan,
+ meta.childPlans.head.convertIfNeeded(),
+ cpuExec.tableSpec,
+ cpuExec.writeOptions,
+ cpuExec.ifNotExists)
+ }
+
+ override def convertToGpu(
+ cpuExec: AtomicReplaceTableAsSelectExec,
+ meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
+ GpuAtomicReplaceTableAsSelectExec(
+ cpuExec.output,
+ new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
+ cpuExec.ident,
+ cpuExec.partitioning,
+ cpuExec.plan,
+ meta.childPlans.head.convertIfNeeded(),
+ cpuExec.tableSpec,
+ cpuExec.writeOptions,
+ cpuExec.orCreate,
+ cpuExec.invalidateCache)
+ }
+}
diff --git a/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala
new file mode 100644
index 00000000000..f722837778a
--- /dev/null
+++ b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.delta.shims
+
+import com.databricks.sql.transaction.tahoe.stats.UsesMetadataFields
+
+trait ShimUsesMetadataFields extends UsesMetadataFields
diff --git a/delta-lake/delta-spark341db/pom.xml b/delta-lake/delta-spark341db/pom.xml
new file mode 100644
index 00000000000..64e920eb8f1
--- /dev/null
+++ b/delta-lake/delta-spark341db/pom.xml
@@ -0,0 +1,296 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  Copyright (c) 2023, NVIDIA CORPORATION.
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>com.nvidia</groupId>
+        <artifactId>rapids-4-spark-jdk-profiles_2.12</artifactId>
+        <version>23.12.0-SNAPSHOT</version>
+        <relativePath>../../jdk-profiles/pom.xml</relativePath>
+    </parent>
+
+    <artifactId>rapids-4-spark-delta-spark341db_2.12</artifactId>
+    <name>RAPIDS Accelerator for Apache Spark Databricks 13.3 Delta Lake Support</name>
+    <description>Databricks 13.3 Delta Lake support for the RAPIDS Accelerator for Apache Spark</description>
+    <version>23.12.0-SNAPSHOT</version>
+
+    <properties>
+        <rapids.compressed.artifact>false</rapids.compressed.artifact>
+        <rapids.default.jar.excludePattern>**/*</rapids.default.jar.excludePattern>
+        <rapids.shim.jar.phase>package</rapids.shim.jar.phase>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-sql_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <classifier>${spark.version.classifier}</classifier>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-core_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-annotation_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-network-common_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-network-shuffle_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-launcher_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.avro</groupId>
+            <artifactId>avro-mapred</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.avro</groupId>
+            <artifactId>avro</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hive</groupId>
+            <artifactId>hive-exec</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hive</groupId>
+            <artifactId>hive-serde</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-hive_${scala.binary.version}</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-core</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-annotations</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.json4s</groupId>
+            <artifactId>json4s-ast_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.json4s</groupId>
+            <artifactId>json4s-core_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-io</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.scala-lang</groupId>
+            <artifactId>scala-reflect</artifactId>
+            <version>${scala.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-lang3</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.esotericsoftware.kryo</groupId>
+            <artifactId>kryo-shaded-db</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.parquet</groupId>
+            <artifactId>parquet-hadoop</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.parquet</groupId>
+            <artifactId>parquet-common</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.parquet</groupId>
+            <artifactId>parquet-column</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.parquet</groupId>
+            <artifactId>parquet-format</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.arrow</groupId>
+            <artifactId>arrow-memory</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.arrow</groupId>
+            <artifactId>arrow-vector</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-client</artifactId>
+            <version>${hadoop.client.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.orc</groupId>
+            <artifactId>orc-core</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.orc</groupId>
+            <artifactId>orc-shims</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.orc</groupId>
+            <artifactId>orc-mapreduce</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hive</groupId>
+            <artifactId>hive-storage-api</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.google.protobuf</groupId>
+            <artifactId>protobuf-java</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-common-utils_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql-api_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.codehaus.mojo</groupId>
+                <artifactId>build-helper-maven-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <id>add-common-sources</id>
+                        <phase>generate-sources</phase>
+                        <goals>
+                            <goal>add-source</goal>
+                        </goals>
+                        <configuration>
+                            <sources>
+                                <!-- shared delta-lake common source directories are added here -->
+                            </sources>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>net.alchim31.maven</groupId>
+                <artifactId>scala-maven-plugin</artifactId>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.rat</groupId>
+                <artifactId>apache-rat-plugin</artifactId>
+            </plugin>
+        </plugins>
+    </build>
+</project>
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala
new file mode 100644
index 00000000000..320485eb1ee
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala
@@ -0,0 +1,464 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from CreateDeltaTableCommand.scala in the
+ * Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.sql.transaction.tahoe.rapids
+
+import com.databricks.sql.transaction.tahoe._
+import com.databricks.sql.transaction.tahoe.actions.Metadata
+import com.databricks.sql.transaction.tahoe.commands.{TableCreationModes, WriteIntoDelta}
+import com.databricks.sql.transaction.tahoe.metering.DeltaLogging
+import com.databricks.sql.transaction.tahoe.schema.SchemaUtils
+import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf
+import com.nvidia.spark.rapids.RapidsConf
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.execution.command.{LeafRunnableCommand, RunnableCommand}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Single entry point for all write or declaration operations for Delta tables accessed through
+ * the table name.
+ *
+ * @param table The table identifier for the Delta table
+ * @param existingTableOpt The existing table for the same identifier if exists
+ * @param mode The save mode when writing data. Relevant when the query is empty or set to Ignore
+ * with `CREATE TABLE IF NOT EXISTS`.
+ * @param query The query to commit into the Delta table if it exists. This can come from
+ * - CTAS
+ * - saveAsTable
+ */
+case class GpuCreateDeltaTableCommand(
+ table: CatalogTable,
+ existingTableOpt: Option[CatalogTable],
+ mode: SaveMode,
+ query: Option[LogicalPlan],
+ operation: TableCreationModes.CreationMode = TableCreationModes.Create,
+ tableByPath: Boolean = false,
+ override val output: Seq[Attribute] = Nil)(@transient rapidsConf: RapidsConf)
+ extends LeafRunnableCommand
+ with DeltaLogging {
+
+ override def otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf)
+
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val table = this.table
+
+ assert(table.tableType != CatalogTableType.VIEW)
+ assert(table.identifier.database.isDefined, "Database should've been fixed at analysis")
+ // There is a subtle race condition here, where the table can be created by someone else
+ // while this command is running. Nothing we can do about that though :(
+ val tableExists = existingTableOpt.isDefined
+ if (mode == SaveMode.Ignore && tableExists) {
+ // Early exit on ignore
+ return Nil
+ } else if (mode == SaveMode.ErrorIfExists && tableExists) {
+ throw DeltaErrors.tableAlreadyExists(table)
+ }
+
+ val tableWithLocation = if (tableExists) {
+ val existingTable = existingTableOpt.get
+ table.storage.locationUri match {
+ case Some(location) if location.getPath != existingTable.location.getPath =>
+ throw DeltaErrors.tableLocationMismatch(table, existingTable)
+ case _ =>
+ }
+ table.copy(
+ storage = existingTable.storage,
+ tableType = existingTable.tableType)
+ } else if (table.storage.locationUri.isEmpty) {
+ // We are defining a new managed table
+ assert(table.tableType == CatalogTableType.MANAGED)
+ val loc = sparkSession.sessionState.catalog.defaultTablePath(table.identifier)
+ table.copy(storage = table.storage.copy(locationUri = Some(loc)))
+ } else {
+ // 1. We are defining a new external table
+ // 2. It's a managed table which already has the location populated. This can happen in DSV2
+ // CTAS flow.
+ table
+ }
+
+ val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED
+ val tableLocation = new Path(tableWithLocation.location)
+ val gpuDeltaLog = GpuDeltaLog.forTable(sparkSession, tableLocation, rapidsConf)
+ val hadoopConf = gpuDeltaLog.deltaLog.newDeltaHadoopConf()
+ val fs = tableLocation.getFileSystem(hadoopConf)
+ val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf)
+ var result: Seq[Row] = Nil
+
+ recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.ddl.createTable") {
+ val txn = gpuDeltaLog.startTransaction()
+ val opStartTs = System.currentTimeMillis()
+ if (query.isDefined) {
+ // If the mode is Ignore or ErrorIfExists, the table must not exist, or we would return
+ // earlier. And the data should not exist either, to match the behavior of
+ // Ignore/ErrorIfExists mode. This means the table path should either not exist or be empty.
+ if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) {
+ assert(!tableExists)
+ // We may have failed a previous write. The retry should still succeed even if we have
+ // garbage data
+ if (txn.readVersion > -1 || !fs.exists(gpuDeltaLog.deltaLog.logPath)) {
+ assertPathEmpty(hadoopConf, tableWithLocation)
+ }
+ }
+ // We are either appending/overwriting with saveAsTable or creating a new table with CTAS or
+ // we are creating a table as part of a RunnableCommand
+ query.get match {
+ case writer: WriteIntoDelta =>
+ // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that
+ // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1.
+ if (!isV1Writer) {
+ replaceMetadataIfNecessary(
+ txn, tableWithLocation, options, writer.data.schema.asNullable)
+ }
+ val actions = writer.write(txn, sparkSession)
+ val op = getOperation(txn.metadata, isManagedTable, Some(options))
+ txn.commit(actions, op)
+ case cmd: RunnableCommand =>
+ result = cmd.run(sparkSession)
+ case other =>
+ // When using V1 APIs, the `other` plan is not yet optimized, therefore, it is safe
+ // to once again go through analysis
+ val data = Dataset.ofRows(sparkSession, other)
+
+ // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that
+ // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1.
+ if (!isV1Writer) {
+ replaceMetadataIfNecessary(
+ txn, tableWithLocation, options, other.schema.asNullable)
+ }
+
+ val actions = WriteIntoDelta(
+ deltaLog = gpuDeltaLog.deltaLog,
+ mode = mode,
+ options,
+ partitionColumns = table.partitionColumnNames,
+ configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull),
+ data = data).write(txn, sparkSession)
+
+ val op = getOperation(txn.metadata, isManagedTable, Some(options))
+ txn.commit(actions, op)
+ }
+ } else {
+ def createTransactionLogOrVerify(): Unit = {
+ if (isManagedTable) {
+ // When creating a managed table, the table path should either not exist or be empty;
+ // otherwise users would be surprised to see the data, or to see the data directory being dropped
+ // after the table is dropped.
+ assertPathEmpty(hadoopConf, tableWithLocation)
+ }
+
+ // This is either a new table, or we never defined the schema of the table. While it is
+ // unexpected for `txn.metadata.schema` to be empty when txn.readVersion >= 0, we still
+ // guard against it, in case of checkpoint corruption bugs.
+ val noExistingMetadata = txn.readVersion == -1 || txn.metadata.schema.isEmpty
+ if (noExistingMetadata) {
+ assertTableSchemaDefined(fs, tableLocation, tableWithLocation, txn, sparkSession)
+ assertPathEmpty(hadoopConf, tableWithLocation)
+ // This is a user provided schema.
+ // Doesn't come from a query, Follow nullability invariants.
+ val newMetadata = getProvidedMetadata(tableWithLocation, table.schema.json)
+ txn.updateMetadataForNewTable(newMetadata)
+
+ val op = getOperation(newMetadata, isManagedTable, None)
+ txn.commit(Nil, op)
+ } else {
+ verifyTableMetadata(txn, tableWithLocation)
+ }
+ }
+ // We are defining a table using the Create or Replace Table statements.
+ operation match {
+ case TableCreationModes.Create =>
+ require(!tableExists, "Can't recreate a table when it exists")
+ createTransactionLogOrVerify()
+
+ case TableCreationModes.CreateOrReplace if !tableExists =>
+ // If the table doesn't exist, CREATE OR REPLACE must provide a schema
+ if (tableWithLocation.schema.isEmpty) {
+ throw DeltaErrors.schemaNotProvidedException
+ }
+ createTransactionLogOrVerify()
+ case _ =>
+ // When the operation is a REPLACE or CREATE OR REPLACE, then the schema shouldn't be
+ // empty, since we'll use the entry to replace the schema
+ if (tableWithLocation.schema.isEmpty) {
+ throw DeltaErrors.schemaNotProvidedException
+ }
+ // We need to replace
+ replaceMetadataIfNecessary(txn, tableWithLocation, options, tableWithLocation.schema)
+ // Truncate the table
+ val operationTimestamp = System.currentTimeMillis()
+ val removes = txn.filterFiles().map(_.removeWithTimestamp(operationTimestamp))
+ val op = getOperation(txn.metadata, isManagedTable, None)
+ txn.commit(removes, op)
+ }
+ }
+
+ // We would have failed earlier on if we couldn't ignore the existence of the table
+ // In addition, we might just be using saveAsTable to append to the table, so ignore the creation
+ // if it already exists.
+ // Note that someone may have dropped and recreated the table in a separate location in the
+ // meantime... Unfortunately we can't do anything there at the moment, because Hive sucks.
+ logInfo(s"Table is path-based table: $tableByPath. Update catalog with mode: $operation")
+ updateCatalog(
+ sparkSession,
+ tableWithLocation,
+ gpuDeltaLog.deltaLog.update(checkIfUpdatedSinceTs = Some(opStartTs)),
+ txn)
+
+ result
+ }
+ }
+
+ private def getProvidedMetadata(table: CatalogTable, schemaString: String): Metadata = {
+ Metadata(
+ description = table.comment.orNull,
+ schemaString = schemaString,
+ partitionColumns = table.partitionColumnNames,
+ configuration = table.properties,
+ createdTime = Some(System.currentTimeMillis()))
+ }
+
+ private def assertPathEmpty(
+ hadoopConf: Configuration,
+ tableWithLocation: CatalogTable): Unit = {
+ val path = new Path(tableWithLocation.location)
+ val fs = path.getFileSystem(hadoopConf)
+ // Verify that the table location associated with CREATE TABLE doesn't have any data. Note that
+ // we intentionally diverge from the behavior of regular datasource tables (which silently
+ // overwrite any previous data).
+ if (fs.exists(path) && fs.listStatus(path).nonEmpty) {
+ throw DeltaErrors.createTableWithNonEmptyLocation(
+ tableWithLocation.identifier.toString,
+ tableWithLocation.location.toString)
+ }
+ }
+
+ private def assertTableSchemaDefined(
+ fs: FileSystem,
+ path: Path,
+ table: CatalogTable,
+ txn: OptimisticTransaction,
+ sparkSession: SparkSession): Unit = {
+ // If we allow creating an empty schema table and indeed the table is new, we just need to
+ // make sure:
+ // 1. txn.readVersion == -1 to read a new table
+ // 2. for external tables: the path must either not exist or be completely empty
+ val allowCreatingTableWithEmptySchema = sparkSession.sessionState
+ .conf.getConf(DeltaSQLConf.DELTA_ALLOW_CREATE_EMPTY_SCHEMA_TABLE) && txn.readVersion == -1
+
+ // Users did not specify the schema. We expect the schema exists in Delta.
+ if (table.schema.isEmpty) {
+ if (table.tableType == CatalogTableType.EXTERNAL) {
+ if (fs.exists(path) && fs.listStatus(path).nonEmpty) {
+ throw DeltaErrors.createExternalTableWithoutLogException(
+ path, table.identifier.quotedString, sparkSession)
+ } else {
+ if (allowCreatingTableWithEmptySchema) return
+ throw DeltaErrors.createExternalTableWithoutSchemaException(
+ path, table.identifier.quotedString, sparkSession)
+ }
+ } else {
+ if (allowCreatingTableWithEmptySchema) return
+ throw DeltaErrors.createManagedTableWithoutSchemaException(
+ table.identifier.quotedString, sparkSession)
+ }
+ }
+ }
+
+ /**
+ * Verify against our transaction metadata that the user specified the right metadata for the
+ * table.
+ */
+ private def verifyTableMetadata(
+ txn: OptimisticTransaction,
+ tableDesc: CatalogTable): Unit = {
+ val existingMetadata = txn.metadata
+ val path = new Path(tableDesc.location)
+
+ // The delta log already exists. If they give any configuration, we'll make sure it all matches.
+ // Otherwise we'll just go with the metadata already present in the log.
+ // The schema compatibility checks will be made in `WriteIntoDelta` for CreateTable
+ // with a query
+ if (txn.readVersion > -1) {
+ if (tableDesc.schema.nonEmpty) {
+ // We check exact alignment on create table if everything is provided
+ // However, if in column mapping mode, we can safely ignore the related metadata fields in
+ // existing metadata because new table desc will not have related metadata assigned yet
+ val differences = SchemaUtils.reportDifferences(
+ DeltaColumnMapping.dropColumnMappingMetadata(existingMetadata.schema),
+ tableDesc.schema)
+ if (differences.nonEmpty) {
+ throw DeltaErrors.createTableWithDifferentSchemaException(
+ path, tableDesc.schema, existingMetadata.schema, differences)
+ }
+ }
+
+ // If a schema is specified, we must make sure the partitioning matches, even if the partitioning
+ // is not specified.
+ if (tableDesc.schema.nonEmpty &&
+ tableDesc.partitionColumnNames != existingMetadata.partitionColumns) {
+ throw DeltaErrors.createTableWithDifferentPartitioningException(
+ path, tableDesc.partitionColumnNames, existingMetadata.partitionColumns)
+ }
+
+ if (tableDesc.properties.nonEmpty && tableDesc.properties != existingMetadata.configuration) {
+ throw DeltaErrors.createTableWithDifferentPropertiesException(
+ path, tableDesc.properties, existingMetadata.configuration)
+ }
+ }
+ }
+
+ /**
+ * Based on the table creation operation, and parameters, we can resolve to different operations.
+ * A lot of this is needed for legacy reasons in Databricks Runtime.
+ * @param metadata The table metadata, which we are creating or replacing
+ * @param isManagedTable Whether we are creating or replacing a managed table
+ * @param options Write options, if this was a CTAS/RTAS
+ */
+ private def getOperation(
+ metadata: Metadata,
+ isManagedTable: Boolean,
+ options: Option[DeltaOptions]): DeltaOperations.Operation = operation match {
+ // This is legacy saveAsTable behavior in Databricks Runtime
+ case TableCreationModes.Create if existingTableOpt.isDefined && query.isDefined =>
+ DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere,
+ options.flatMap(_.userMetadata))
+
+ // DataSourceV2 table creation
+ // CREATE TABLE (non-DataFrameWriter API) doesn't have options syntax
+ // (userMetadata uses SQLConf in this case)
+ case TableCreationModes.Create =>
+ DeltaOperations.CreateTable(metadata, isManagedTable, query.isDefined)
+
+ // DataSourceV2 table replace
+ // REPLACE TABLE (non-DataFrameWriter API) doesn't have options syntax
+ // (userMetadata uses SQLConf in this case)
+ case TableCreationModes.Replace =>
+ DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = false, query.isDefined)
+
+ // Legacy saveAsTable with Overwrite mode
+ case TableCreationModes.CreateOrReplace if options.exists(_.replaceWhere.isDefined) =>
+ DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere,
+ options.flatMap(_.userMetadata))
+
+ // New DataSourceV2 saveAsTable with overwrite mode behavior
+ case TableCreationModes.CreateOrReplace =>
+ DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = true, query.isDefined,
+ options.flatMap(_.userMetadata))
+ }
+
+ /**
+ * Similar to getOperation, here we disambiguate the catalog alterations we need to do based
+ * on the table operation, and whether we have reached here through legacy code or DataSourceV2
+ * code paths.
+ */
+ private def updateCatalog(
+ spark: SparkSession,
+ table: CatalogTable,
+ snapshot: Snapshot,
+ txn: OptimisticTransaction): Unit = {
+ val cleaned = cleanupTableDefinition(table, snapshot)
+ operation match {
+ case _ if tableByPath => // do nothing with the metastore if this is by path
+ case TableCreationModes.Create =>
+ spark.sessionState.catalog.createTable(
+ cleaned,
+ ignoreIfExists = existingTableOpt.isDefined,
+ validateLocation = false)
+ case TableCreationModes.Replace | TableCreationModes.CreateOrReplace
+ if existingTableOpt.isDefined =>
+ spark.sessionState.catalog.alterTable(table)
+ case TableCreationModes.Replace =>
+ val ident = Identifier.of(table.identifier.database.toArray, table.identifier.table)
+ throw DeltaErrors.cannotReplaceMissingTableException(ident)
+ case TableCreationModes.CreateOrReplace =>
+ spark.sessionState.catalog.createTable(
+ cleaned,
+ ignoreIfExists = false,
+ validateLocation = false)
+ }
+ }
+
+ /** Clean up the information we pass on to store in the catalog. */
+ private def cleanupTableDefinition(table: CatalogTable, snapshot: Snapshot): CatalogTable = {
+ // These actually have no effect on the usability of Delta, but we are feature-flagging the
+ // legacy behavior for now
+ val storageProps = if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) {
+ // Legacy behavior
+ table.storage
+ } else {
+ table.storage.copy(properties = Map.empty)
+ }
+
+ table.copy(
+ schema = new StructType(),
+ properties = Map.empty,
+ partitionColumnNames = Nil,
+ // Remove write specific options when updating the catalog
+ storage = storageProps,
+ tracksPartitionsInCatalog = true)
+ }
+
+ /**
+ * With DataFrameWriterV2, methods like `replace()` or `createOrReplace()` mean that the
+ * metadata of the table should be replaced. If overwriteSchema=false is provided with these
+ * methods, then we will verify that the metadata match exactly.
+ */
+ private def replaceMetadataIfNecessary(
+ txn: OptimisticTransaction,
+ tableDesc: CatalogTable,
+ options: DeltaOptions,
+ schema: StructType): Unit = {
+ val isReplace = (operation == TableCreationModes.CreateOrReplace ||
+ operation == TableCreationModes.Replace)
+ // If a user explicitly specifies not to overwrite the schema, during a replace, we should
+ // tell them that it's not supported
+ val dontOverwriteSchema = options.options.contains(DeltaOptions.OVERWRITE_SCHEMA_OPTION) &&
+ !options.canOverwriteSchema
+ if (isReplace && dontOverwriteSchema) {
+ throw DeltaErrors.illegalUsageException(DeltaOptions.OVERWRITE_SCHEMA_OPTION, "replacing")
+ }
+ if (txn.readVersion > -1L && isReplace && !dontOverwriteSchema) {
+ // When a table already exists, and we're using the DataFrameWriterV2 API to replace
+ // or createOrReplace a table, we blindly overwrite the metadata.
+ txn.updateMetadataForNewTable(getProvidedMetadata(table, schema.json))
+ }
+ }
+
+ /**
+ * Horrible hack to differentiate between DataFrameWriterV1 and V2 so that we can decide
+ * what to do with table metadata. In DataFrameWriterV1, mode("overwrite").saveAsTable,
+ * behaves as a CreateOrReplace table, but we have asked for "overwriteSchema" as an
+ * explicit option to overwrite partitioning or schema information. With DataFrameWriterV2,
+ * the behavior asked for by the user is clearer: .createOrReplace(), which means that we
+ * should overwrite schema and/or partitioning. Therefore we have this hack.
+ */
+ private def isV1Writer: Boolean = {
+ Thread.currentThread().getStackTrace.exists(_.toString.contains(
+ classOf[DataFrameWriter[_]].getCanonicalName + "."))
+ }
+}
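The `isV1Writer` heuristic at the end of the file exists because the two writer APIs express replacement intent differently, which is also why `replaceMetadataIfNecessary` is only applied on the non-V1 path. A short usage sketch of the two paths it distinguishes, assuming a `SparkSession` named `spark` and a hypothetical Delta table `db.tbl`:

```scala
import org.apache.spark.sql.SparkSession

def writeBothWays(spark: SparkSession): Unit = {
  val df = spark.range(10).toDF("id")

  // DataFrameWriterV1: overwrite mode behaves like create-or-replace for the data, but the
  // schema/partitioning are only replaced when overwriteSchema is requested explicitly.
  df.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable("db.tbl")

  // DataFrameWriterV2: createOrReplace() itself implies that the table metadata is replaced.
  df.writeTo("db.tbl").using("delta").createOrReplace()
}
```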
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeleteCommand.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeleteCommand.scala
new file mode 100644
index 00000000000..97ae9dc46ab
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeleteCommand.scala
@@ -0,0 +1,369 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from DeleteCommand.scala
+ * in the Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.sql.transaction.tahoe.rapids
+
+import com.databricks.sql.transaction.tahoe.{DeltaConfigs, DeltaLog, DeltaOperations, DeltaTableUtils, DeltaUDF, OptimisticTransaction}
+import com.databricks.sql.transaction.tahoe.DeltaCommitTag._
+import com.databricks.sql.transaction.tahoe.RowTracking
+import com.databricks.sql.transaction.tahoe.actions.{AddCDCFile, FileAction}
+import com.databricks.sql.transaction.tahoe.commands.{DeleteCommandMetrics, DeleteMetric, DeltaCommand, DMLUtils}
+import com.databricks.sql.transaction.tahoe.commands.MergeIntoCommandBase.totalBytesAndDistinctPartitionValues
+import com.databricks.sql.transaction.tahoe.files.TahoeBatchFileIndex
+import com.databricks.sql.transaction.tahoe.rapids.GpuDeleteCommand.{rewritingFilesMsg, FINDING_TOUCHED_FILES_MSG}
+import com.nvidia.spark.rapids.delta.GpuDeltaMetricUpdateUDF
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, Not}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.functions.input_file_name
+import org.apache.spark.sql.types.LongType
+
+/**
+ * GPU version of Delta Lake DeleteCommand.
+ *
+ * Performs a Delete based on the search condition
+ *
+ * Algorithm:
+ * 1) Scan all the files and determine which files have
+ * the rows that need to be deleted.
+ * 2) Traverse the affected files and rebuild the touched files.
+ * 3) Use the Delta protocol to atomically write the remaining rows to new files and remove
+ * the affected files that are identified in step 1.
+ */
+case class GpuDeleteCommand(
+ gpuDeltaLog: GpuDeltaLog,
+ target: LogicalPlan,
+ condition: Option[Expression])
+ extends LeafRunnableCommand with DeltaCommand with DeleteCommandMetrics {
+
+ override def innerChildren: Seq[QueryPlan[_]] = Seq(target)
+
+ override val output: Seq[Attribute] = Seq(AttributeReference("num_affected_rows", LongType)())
+
+ @transient private lazy val sc: SparkContext = SparkContext.getOrCreate()
+
+ // DeleteCommandMetrics does not include deletion vector metrics, so add them here because
+ // the commit command needs to collect these metrics for inclusion in the delta log event
+ override lazy val metrics = createMetrics ++ Map(
+ "numDeletionVectorsAdded" -> SQLMetrics.createMetric(sc, "number of deletion vectors added."),
+ "numDeletionVectorsRemoved" ->
+ SQLMetrics.createMetric(sc, "number of deletion vectors removed.")
+ )
+
+ final override def run(sparkSession: SparkSession): Seq[Row] = {
+ val deltaLog = gpuDeltaLog.deltaLog
+ recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.dml.delete") {
+ gpuDeltaLog.withNewTransaction { txn =>
+ DeltaLog.assertRemovable(txn.snapshot)
+ val deleteCommitTags = performDelete(sparkSession, deltaLog, txn)
+ val deleteActions = deleteCommitTags.actions
+ if (deleteActions.nonEmpty) {
+ txn.commitIfNeeded(deleteActions, DeltaOperations.Delete(condition.toSeq),
+ deleteCommitTags.stringTags)
+ }
+ }
+ // Re-cache all cached plans (including this relation itself, if it's cached) that refer to
+ // this data source relation.
+ sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target)
+ }
+
+ // Adjust for deletes at partition boundaries. A delete at partition boundaries is a metadata
+ // operation, so we don't actually have any information about how many rows were deleted.
+ // While this info may exist in the file statistics, it's not guaranteed that we have these
+ // statistics. To avoid any performance regressions, we currently just return -1 in such cases.
+ if (metrics("numRemovedFiles").value > 0 && metrics("numDeletedRows").value == 0) {
+ Seq(Row(-1L))
+ } else {
+ Seq(Row(metrics("numDeletedRows").value))
+ }
+ }
+
+ def performDelete(
+ sparkSession: SparkSession,
+ deltaLog: DeltaLog,
+ txn: OptimisticTransaction): DMLUtils.TaggedCommitData = {
+ import com.databricks.sql.transaction.tahoe.implicits._
+
+ var numRemovedFiles: Long = 0
+ var numAddedFiles: Long = 0
+ var numAddedChangeFiles: Long = 0
+ var scanTimeMs: Long = 0
+ var rewriteTimeMs: Long = 0
+ var numBytesAdded: Long = 0
+ var changeFileBytes: Long = 0
+ var numBytesRemoved: Long = 0
+ var numFilesBeforeSkipping: Long = 0
+ var numBytesBeforeSkipping: Long = 0
+ var numFilesAfterSkipping: Long = 0
+ var numBytesAfterSkipping: Long = 0
+ var numPartitionsAfterSkipping: Option[Long] = None
+ var numPartitionsRemovedFrom: Option[Long] = None
+ var numPartitionsAddedTo: Option[Long] = None
+ var numDeletedRows: Option[Long] = None
+ var numCopiedRows: Option[Long] = None
+
+ val startTime = System.nanoTime()
+ val numFilesTotal = txn.snapshot.numOfFiles
+
+ val deleteActions: Seq[FileAction] = condition match {
+ case None =>
+ // Case 1: Delete the whole table if the condition is true
+ val allFiles = txn.filterFiles(Nil)
+
+ numRemovedFiles = allFiles.size
+ scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+ val (numBytes, numPartitions) = totalBytesAndDistinctPartitionValues(allFiles)
+ numBytesRemoved = numBytes
+ numFilesBeforeSkipping = numRemovedFiles
+ numBytesBeforeSkipping = numBytes
+ numFilesAfterSkipping = numRemovedFiles
+ numBytesAfterSkipping = numBytes
+ if (txn.metadata.partitionColumns.nonEmpty) {
+ numPartitionsAfterSkipping = Some(numPartitions)
+ numPartitionsRemovedFrom = Some(numPartitions)
+ numPartitionsAddedTo = Some(0)
+ }
+ val operationTimestamp = System.currentTimeMillis()
+ allFiles.map(_.removeWithTimestamp(operationTimestamp))
+ case Some(cond) =>
+ val (metadataPredicates, otherPredicates) =
+ DeltaTableUtils.splitMetadataAndDataPredicates(
+ cond, txn.metadata.partitionColumns, sparkSession)
+
+ numFilesBeforeSkipping = txn.snapshot.numOfFiles
+ numBytesBeforeSkipping = txn.snapshot.sizeInBytes
+
+ if (otherPredicates.isEmpty) {
+ // Case 2: The condition can be evaluated using metadata only.
+ // Delete a set of files without the need of scanning any data files.
+ val operationTimestamp = System.currentTimeMillis()
+ val candidateFiles = txn.filterFiles(metadataPredicates)
+
+ scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+ numRemovedFiles = candidateFiles.size
+ numBytesRemoved = candidateFiles.map(_.size).sum
+ numFilesAfterSkipping = candidateFiles.size
+ val (numCandidateBytes, numCandidatePartitions) =
+ totalBytesAndDistinctPartitionValues(candidateFiles)
+ numBytesAfterSkipping = numCandidateBytes
+ if (txn.metadata.partitionColumns.nonEmpty) {
+ numPartitionsAfterSkipping = Some(numCandidatePartitions)
+ numPartitionsRemovedFrom = Some(numCandidatePartitions)
+ numPartitionsAddedTo = Some(0)
+ }
+ candidateFiles.map(_.removeWithTimestamp(operationTimestamp))
+ } else {
+ // Case 3: Delete the rows based on the condition.
+ val candidateFiles = txn.filterFiles(metadataPredicates ++ otherPredicates)
+
+ numFilesAfterSkipping = candidateFiles.size
+ val (numCandidateBytes, numCandidatePartitions) =
+ totalBytesAndDistinctPartitionValues(candidateFiles)
+ numBytesAfterSkipping = numCandidateBytes
+ if (txn.metadata.partitionColumns.nonEmpty) {
+ numPartitionsAfterSkipping = Some(numCandidatePartitions)
+ }
+
+ val nameToAddFileMap = generateCandidateFileMap(deltaLog.dataPath, candidateFiles)
+
+ val fileIndex = new TahoeBatchFileIndex(
+ sparkSession, "delete", candidateFiles, deltaLog, deltaLog.dataPath, txn.snapshot)
+ // Keep everything from the resolved target except a new TahoeFileIndex
+ // that only involves the affected files instead of all files.
+ val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex)
+ val data = Dataset.ofRows(sparkSession, newTarget)
+ val deletedRowCount = metrics("numDeletedRows")
+ val deletedRowUdf = DeltaUDF.boolean {
+ new GpuDeltaMetricUpdateUDF(deletedRowCount)
+ }.asNondeterministic()
+ val filesToRewrite =
+ withStatusCode("DELTA", FINDING_TOUCHED_FILES_MSG) {
+ if (candidateFiles.isEmpty) {
+ Array.empty[String]
+ } else {
+ data.filter(new Column(cond))
+ .select(input_file_name())
+ .filter(deletedRowUdf())
+ .distinct()
+ .as[String]
+ .collect()
+ }
+ }
+
+ numRemovedFiles = filesToRewrite.length
+ scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+ if (filesToRewrite.isEmpty) {
+ // Case 3.1: no row matches and no delete will be triggered
+ if (txn.metadata.partitionColumns.nonEmpty) {
+ numPartitionsRemovedFrom = Some(0)
+ numPartitionsAddedTo = Some(0)
+ }
+ Nil
+ } else {
+ // Case 3.2: some files need to be rewritten to remove the deleted rows
+ // Do the second pass and just read the affected files
+ val baseRelation = buildBaseRelation(
+ sparkSession, txn, "delete", deltaLog.dataPath, filesToRewrite, nameToAddFileMap)
+ // Keep everything from the resolved target except a new TahoeFileIndex
+ // that only involves the affected files instead of all files.
+ val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location)
+ val targetDF = Dataset.ofRows(sparkSession, newTarget)
+ val filterCond = Not(EqualNullSafe(cond, Literal.TrueLiteral))
+ val rewrittenActions = rewriteFiles(txn, targetDF, filterCond, filesToRewrite.length)
+ val (changeFiles, rewrittenFiles) = rewrittenActions
+ .partition(_.isInstanceOf[AddCDCFile])
+ numAddedFiles = rewrittenFiles.size
+ val removedFiles = filesToRewrite.map(f =>
+ getTouchedFile(deltaLog.dataPath, f, nameToAddFileMap))
+ val (removedBytes, removedPartitions) =
+ totalBytesAndDistinctPartitionValues(removedFiles)
+ numBytesRemoved = removedBytes
+ val (rewrittenBytes, rewrittenPartitions) =
+ totalBytesAndDistinctPartitionValues(rewrittenFiles)
+ numBytesAdded = rewrittenBytes
+ if (txn.metadata.partitionColumns.nonEmpty) {
+ numPartitionsRemovedFrom = Some(removedPartitions)
+ numPartitionsAddedTo = Some(rewrittenPartitions)
+ }
+ numAddedChangeFiles = changeFiles.size
+ changeFileBytes = changeFiles.collect { case f: AddCDCFile => f.size }.sum
+ rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs
+ numDeletedRows = Some(metrics("numDeletedRows").value)
+ numCopiedRows = Some(metrics("numTouchedRows").value - metrics("numDeletedRows").value)
+
+ val operationTimestamp = System.currentTimeMillis()
+ removeFilesFromPaths(deltaLog, nameToAddFileMap, filesToRewrite,
+ operationTimestamp) ++ rewrittenActions
+ }
+ }
+ }
+ metrics("numRemovedFiles").set(numRemovedFiles)
+ metrics("numAddedFiles").set(numAddedFiles)
+ val executionTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+ metrics("executionTimeMs").set(executionTimeMs)
+ metrics("scanTimeMs").set(scanTimeMs)
+ metrics("rewriteTimeMs").set(rewriteTimeMs)
+ metrics("numAddedChangeFiles").set(numAddedChangeFiles)
+ metrics("changeFileBytes").set(changeFileBytes)
+ metrics("numAddedBytes").set(numBytesAdded)
+ metrics("numRemovedBytes").set(numBytesRemoved)
+ metrics("numFilesBeforeSkipping").set(numFilesBeforeSkipping)
+ metrics("numBytesBeforeSkipping").set(numBytesBeforeSkipping)
+ metrics("numFilesAfterSkipping").set(numFilesAfterSkipping)
+ metrics("numBytesAfterSkipping").set(numBytesAfterSkipping)
+ numPartitionsAfterSkipping.foreach(metrics("numPartitionsAfterSkipping").set)
+ numPartitionsAddedTo.foreach(metrics("numPartitionsAddedTo").set)
+ numPartitionsRemovedFrom.foreach(metrics("numPartitionsRemovedFrom").set)
+ numCopiedRows.foreach(metrics("numCopiedRows").set)
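+ // The GPU delete path rewrites the affected files rather than producing deletion vectors,
+ // so these deletion-vector metrics are always zero here.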
+ metrics("numDeletionVectorsAdded").set(0)
+ metrics("numDeletionVectorsRemoved").set(0)
+ txn.registerSQLMetrics(sparkSession, metrics)
+ // This is needed to make the SQL metrics visible in the Spark UI
+ val executionId = sparkSession.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(
+ sparkSession.sparkContext, executionId, metrics.values.toSeq)
+
+ recordDeltaEvent(
+ deltaLog,
+ "delta.dml.delete.stats",
+ data = DeleteMetric(
+ condition = condition.map(_.sql).getOrElse("true"),
+ numFilesTotal,
+ numFilesAfterSkipping,
+ numAddedFiles,
+ numRemovedFiles,
+ numAddedFiles,
+ numAddedChangeFiles = numAddedChangeFiles,
+ numFilesBeforeSkipping,
+ numBytesBeforeSkipping,
+ numFilesAfterSkipping,
+ numBytesAfterSkipping,
+ numPartitionsAfterSkipping,
+ numPartitionsAddedTo,
+ numPartitionsRemovedFrom,
+ numCopiedRows,
+ numDeletedRows,
+ numBytesAdded,
+ numBytesRemoved,
+ changeFileBytes = changeFileBytes,
+ scanTimeMs,
+ rewriteTimeMs)
+ )
+
+ DMLUtils.TaggedCommitData(deleteActions)
+ .withTag(PreservedRowTrackingTag, RowTracking.isEnabled(txn.protocol, txn.metadata))
+ .withTag(NoRowsCopiedTag, metrics("numCopiedRows").value == 0)
+ }
+
+ /**
+ * Returns the list of `AddFile`s and `AddCDCFile`s that have been re-written.
+ */
+ private def rewriteFiles(
+ txn: OptimisticTransaction,
+ baseData: DataFrame,
+ filterCondition: Expression,
+ numFilesToRewrite: Long): Seq[FileAction] = {
+ val shouldWriteCdc = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(txn.metadata)
+
+ // Total number of rows we have seen, i.e. rows that are either copied or deleted (sum of both).
+ val numTouchedRows = metrics("numTouchedRows")
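+ // The metric-updating UDF is marked non-deterministic so that Catalyst does not optimize it
+ // away or reorder it, which could skip or repeat its side effect of updating the metric.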
+ val numTouchedRowsUdf = DeltaUDF.boolean {
+ new GpuDeltaMetricUpdateUDF(numTouchedRows)
+ }.asNondeterministic()
+
+ withStatusCode(
+ "DELTA", rewritingFilesMsg(numFilesToRewrite)) {
+ val dfToWrite = if (shouldWriteCdc) {
+ import com.databricks.sql.transaction.tahoe.commands.cdc.CDCReader._
+ // The logic here ends up being surprisingly elegant, with all source rows ending up in
+ // the output. Recall that we flipped the user-provided delete condition earlier, before the
+ // call to `rewriteFiles`. All rows which match this latest `filterCondition` are retained
+ // as table data, while all rows which don't match are removed from the rewritten table data
+ // but do get included in the output as CDC events.
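+ // Illustration (hypothetical column name): for DELETE WHERE id = 5 the flipped condition is
+ // NOT(id = 5 <=> TRUE), so rows with id = 5 fail it and are emitted only as CDC delete events,
+ // while every other row passes and is kept as regular table data.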
+ baseData
+ .filter(numTouchedRowsUdf())
+ .withColumn(
+ CDC_TYPE_COLUMN_NAME,
+ new Column(If(filterCondition, CDC_TYPE_NOT_CDC, CDC_TYPE_DELETE))
+ )
+ } else {
+ baseData
+ .filter(numTouchedRowsUdf())
+ .filter(new Column(filterCondition))
+ }
+
+ txn.writeFiles(dfToWrite)
+ }
+ }
+}
+
+object GpuDeleteCommand {
+ val FINDING_TOUCHED_FILES_MSG: String = "Finding files to rewrite for DELETE operation"
+
+ def rewritingFilesMsg(numFilesToRewrite: Long): String =
+ s"Rewriting $numFilesToRewrite files for DELETE operation"
+}
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala
new file mode 100644
index 00000000000..8c62f4f3fd7
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala
@@ -0,0 +1,218 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from DeltaDataSource.scala in the
+ * Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.sql.transaction.tahoe.rapids
+
+import java.util
+
+import com.databricks.sql.transaction.tahoe.{DeltaConfigs, DeltaErrors}
+import com.databricks.sql.transaction.tahoe.commands.TableCreationModes
+import com.databricks.sql.transaction.tahoe.metering.DeltaLogging
+import com.databricks.sql.transaction.tahoe.sources.DeltaSourceUtils
+import com.nvidia.spark.rapids.RapidsConf
+
+import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, Table}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
+import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.types.StructType
+
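+/**
+ * GPU-aware Delta catalog for Databricks 13.3. Table creation (including CTAS/RTAS) is routed
+ * through GPU-backed commands such as GpuCreateDeltaTableCommand, while time-travel loads and
+ * other catalog lookups are delegated to the wrapped CPU catalog (`cpuCatalog`).
+ */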
+class GpuDeltaCatalog(
+ override val cpuCatalog: StagingTableCatalog,
+ override val rapidsConf: RapidsConf)
+ extends GpuDeltaCatalogBase with SupportsPathIdentifier with DeltaLogging {
+
+ override protected def buildGpuCreateDeltaTableCommand(
+ rapidsConf: RapidsConf,
+ table: CatalogTable,
+ existingTableOpt: Option[CatalogTable],
+ mode: SaveMode,
+ query: Option[LogicalPlan],
+ operation: TableCreationModes.CreationMode,
+ tableByPath: Boolean): LeafRunnableCommand = {
+ GpuCreateDeltaTableCommand(
+ table,
+ existingTableOpt,
+ mode,
+ query,
+ operation,
+ tableByPath = tableByPath
+ )(rapidsConf)
+ }
+
+ override protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = {
+ // If this is a path identifier, we cannot return an existing CatalogTable. The Create command
+ // will check the file system itself
+ if (isPathIdentifier(table)) return None
+ val tableExists = catalog.tableExists(table)
+ if (tableExists) {
+ val oldTable = catalog.getTableMetadata(table)
+ if (oldTable.tableType == CatalogTableType.VIEW) {
+ throw new AnalysisException(
+ s"$table is a view. You may not write data into a view.")
+ }
+ if (!DeltaSourceUtils.isDeltaTable(oldTable.provider)) {
+ throw new AnalysisException(s"$table is not a Delta table. Please drop this " +
+ "table first if you would like to recreate it with Delta Lake.")
+ }
+ Some(oldTable)
+ } else {
+ None
+ }
+ }
+
+ override protected def verifyTableAndSolidify(
+ tableDesc: CatalogTable,
+ query: Option[LogicalPlan]): CatalogTable = {
+
+ if (tableDesc.bucketSpec.isDefined) {
+ throw DeltaErrors.operationNotSupportedException("Bucketing", tableDesc.identifier)
+ }
+
+ val schema = query.map { plan =>
+ assert(tableDesc.schema.isEmpty, "Can't specify table schema in CTAS.")
+ plan.schema.asNullable
+ }.getOrElse(tableDesc.schema)
+
+ PartitioningUtils.validatePartitionColumn(
+ schema,
+ tableDesc.partitionColumnNames,
+ caseSensitive = false) // Delta is case insensitive
+
+ val validatedConfigurations = DeltaConfigs.validateConfigurations(tableDesc.properties)
+
+ val db = tableDesc.identifier.database.getOrElse(catalog.getCurrentDatabase)
+ val tableIdentWithDB = tableDesc.identifier.copy(database = Some(db))
+ tableDesc.copy(
+ identifier = tableIdentWithDB,
+ schema = schema,
+ properties = validatedConfigurations)
+ }
+
+ override protected def createGpuStagedDeltaTableV2(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String],
+ operation: TableCreationModes.CreationMode): StagedTable = {
+ new GpuStagedDeltaTableV2WithLogging(ident, schema, partitions, properties, operation)
+ }
+
+ override def loadTable(ident: Identifier, timestamp: Long): Table = {
+ cpuCatalog.loadTable(ident, timestamp)
+ }
+
+ override def loadTable(ident: Identifier, version: String): Table = {
+ cpuCatalog.loadTable(ident, version)
+ }
+
+ /**
+ * Creates a Delta table using GPU for writing the data
+ *
+ * @param ident The identifier of the table
+ * @param schema The schema of the table
+ * @param partitions The partition transforms for the table
+ * @param allTableProperties The table properties that configure the behavior of the table or
+ * provide information about the table
+ * @param writeOptions Options specific to the write during table creation or replacement
+ * @param sourceQuery A query if this CREATE request came from a CTAS or RTAS
+ * @param operation The specific table creation mode: Create, Replace, or
+ * Create or Replace
+ */
+ override def createDeltaTable(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ allTableProperties: util.Map[String, String],
+ writeOptions: Map[String, String],
+ sourceQuery: Option[DataFrame],
+ operation: TableCreationModes.CreationMode
+ ): Table = recordFrameProfile(
+ "DeltaCatalog", "createDeltaTable") {
+ super.createDeltaTable(
+ ident,
+ schema,
+ partitions,
+ allTableProperties,
+ writeOptions,
+ sourceQuery,
+ operation)
+ }
+
+ override def createTable(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table =
+ recordFrameProfile("DeltaCatalog", "createTable") {
+ super.createTable(ident, schema, partitions, properties)
+ }
+
+ override def stageReplace(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): StagedTable =
+ recordFrameProfile("DeltaCatalog", "stageReplace") {
+ super.stageReplace(ident, schema, partitions, properties)
+ }
+
+ override def stageCreateOrReplace(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): StagedTable =
+ recordFrameProfile("DeltaCatalog", "stageCreateOrReplace") {
+ super.stageCreateOrReplace(ident, schema, partitions, properties)
+ }
+
+ override def stageCreate(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): StagedTable =
+ recordFrameProfile("DeltaCatalog", "stageCreate") {
+ super.stageCreate(ident, schema, partitions, properties)
+ }
+
+ /**
+ * A staged Delta table, which creates a HiveMetaStore entry and appends data if this was a
+ * CTAS/RTAS command. We have an ugly way of using this API right now, but it's the best way to
+ * maintain old behavior compatibility between Databricks Runtime and OSS Delta Lake.
+ */
+ protected class GpuStagedDeltaTableV2WithLogging(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String],
+ operation: TableCreationModes.CreationMode)
+ extends GpuStagedDeltaTableV2(ident, schema, partitions, properties, operation) {
+
+ override def commitStagedChanges(): Unit = recordFrameProfile(
+ "DeltaCatalog", "commitStagedChanges") {
+ super.commitStagedChanges()
+ }
+ }
+}
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDoAutoCompaction.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDoAutoCompaction.scala
new file mode 100644
index 00000000000..9726511ad44
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDoAutoCompaction.scala
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from DoAutoCompaction.scala
+ * from https://github.com/delta-io/delta/pull/1156
+ * in the Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.sql.transaction.tahoe.rapids
+
+import com.databricks.sql.transaction.tahoe._
+import com.databricks.sql.transaction.tahoe.actions.Action
+import com.databricks.sql.transaction.tahoe.hooks.PostCommitHook
+import com.databricks.sql.transaction.tahoe.metering.DeltaLogging
+
+import org.apache.spark.sql.SparkSession
+
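+/**
+ * Post-commit hook that triggers compaction after a GPU write commits. It starts a fresh
+ * transaction on the same DeltaLog and runs GpuOptimizeExecutor so that the compaction writes
+ * also go through the GPU (see the note in `run` on why the Databricks AutoCompact hook cannot
+ * be reused).
+ */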
+object GpuDoAutoCompaction extends PostCommitHook
+ with DeltaLogging
+ with Serializable {
+ override val name: String = "Triggers compaction if necessary"
+
+ override def run(spark: SparkSession,
+ txn: OptimisticTransactionImpl,
+ committedVersion: Long,
+ postCommitSnapshot: Snapshot,
+ committedActions: Seq[Action]): Unit = {
+ val gpuTxn = txn.asInstanceOf[GpuOptimisticTransaction]
+ val newTxn = new GpuDeltaLog(gpuTxn.deltaLog, gpuTxn.rapidsConf).startTransaction()
+ // Note: The Databricks AutoCompact PostCommitHook cannot be used here
+ // (with a GpuOptimisticTransaction). It appears that AutoCompact creates a new transaction,
+ // thereby circumventing GpuOptimisticTransaction (which intercepts Parquet writes
+ // to go through the GPU).
+ new GpuOptimizeExecutor(spark, newTxn, Seq.empty, Seq.empty, committedActions).optimize()
+ }
+
+ override def handleError(error: Throwable, version: Long): Unit =
+ throw DeltaErrors.postCommitHookFailedException(this, version, name, error)
+}
\ No newline at end of file
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuMergeIntoCommand.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuMergeIntoCommand.scala
new file mode 100644
index 00000000000..a0a4e4263c3
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuMergeIntoCommand.scala
@@ -0,0 +1,1187 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from MergeIntoCommand.scala
+ * in the Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.sql.transaction.tahoe.rapids
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import com.databricks.sql.transaction.tahoe._
+import com.databricks.sql.transaction.tahoe.DeltaOperations.MergePredicate
+import com.databricks.sql.transaction.tahoe.actions.{AddCDCFile, AddFile, FileAction}
+import com.databricks.sql.transaction.tahoe.commands.DeltaCommand
+import com.databricks.sql.transaction.tahoe.schema.ImplicitMetadataOperation
+import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf
+import com.databricks.sql.transaction.tahoe.util.{AnalysisHelper, SetAccumulator}
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize
+import com.nvidia.spark.rapids.{BaseExprMeta, GpuOverrides, RapidsConf}
+import com.nvidia.spark.rapids.delta._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BasePredicate, Expression, Literal, NamedExpression, PredicateHelper, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+import org.apache.spark.sql.catalyst.plans.logical.{DeltaMergeIntoClause, DeltaMergeIntoMatchedClause, DeltaMergeIntoMatchedDeleteClause, DeltaMergeIntoMatchedUpdateClause, DeltaMergeIntoNotMatchedBySourceClause, DeltaMergeIntoNotMatchedClause, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DataTypes, LongType, StringType, StructType}
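+
+/** Row, file, byte and partition counts recorded at various stages of a merge, for statistics. */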
+case class GpuMergeDataSizes(
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ rows: Option[Long] = None,
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ files: Option[Long] = None,
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ bytes: Option[Long] = None,
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ partitions: Option[Long] = None)
+
+/**
+ * Represents the state of a single merge clause:
+ * - merge clause's (optional) predicate
+ * - action type (insert, update, delete)
+ * - action's expressions
+ */
+case class GpuMergeClauseStats(
+ condition: Option[String],
+ actionType: String,
+ actionExpr: Seq[String])
+
+object GpuMergeClauseStats {
+ def apply(mergeClause: DeltaMergeIntoClause): GpuMergeClauseStats = {
+ GpuMergeClauseStats(
+ condition = mergeClause.condition.map(_.sql),
+ mergeClause.clauseType.toLowerCase(),
+ actionExpr = mergeClause.actions.map(_.sql))
+ }
+}
+
+/** State for a GPU merge operation */
+case class GpuMergeStats(
+ // Merge condition expression
+ conditionExpr: String,
+
+ // Expressions used in old MERGE stats, now always Null
+ updateConditionExpr: String,
+ updateExprs: Seq[String],
+ insertConditionExpr: String,
+ insertExprs: Seq[String],
+ deleteConditionExpr: String,
+
+ // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED
+ matchedStats: Seq[GpuMergeClauseStats],
+ notMatchedStats: Seq[GpuMergeClauseStats],
+
+ // Data sizes of source and target at different stages of processing
+ source: GpuMergeDataSizes,
+ targetBeforeSkipping: GpuMergeDataSizes,
+ targetAfterSkipping: GpuMergeDataSizes,
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ sourceRowsInSecondScan: Option[Long],
+
+ // Data change sizes
+ targetFilesRemoved: Long,
+ targetFilesAdded: Long,
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetChangeFilesAdded: Option[Long],
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetChangeFileBytes: Option[Long],
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetBytesRemoved: Option[Long],
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetBytesAdded: Option[Long],
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetPartitionsRemovedFrom: Option[Long],
+ @JsonDeserialize(contentAs = classOf[java.lang.Long])
+ targetPartitionsAddedTo: Option[Long],
+ targetRowsCopied: Long,
+ targetRowsUpdated: Long,
+ targetRowsInserted: Long,
+ targetRowsDeleted: Long
+)
+
+object GpuMergeStats {
+
+ def fromMergeSQLMetrics(
+ metrics: Map[String, SQLMetric],
+ condition: Expression,
+ matchedClauses: Seq[DeltaMergeIntoMatchedClause],
+ notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause],
+ isPartitioned: Boolean): GpuMergeStats = {
+
+ def metricValueIfPartitioned(metricName: String): Option[Long] = {
+ if (isPartitioned) Some(metrics(metricName).value) else None
+ }
+
+ GpuMergeStats(
+ // Merge condition expression
+ conditionExpr = condition.sql,
+
+ // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED
+ matchedStats = matchedClauses.map(GpuMergeClauseStats(_)),
+ notMatchedStats = notMatchedClauses.map(GpuMergeClauseStats(_)),
+
+ // Data sizes of source and target at different stages of processing
+ source = GpuMergeDataSizes(rows = Some(metrics("numSourceRows").value)),
+ targetBeforeSkipping =
+ GpuMergeDataSizes(
+ files = Some(metrics("numTargetFilesBeforeSkipping").value),
+ bytes = Some(metrics("numTargetBytesBeforeSkipping").value)),
+ targetAfterSkipping =
+ GpuMergeDataSizes(
+ files = Some(metrics("numTargetFilesAfterSkipping").value),
+ bytes = Some(metrics("numTargetBytesAfterSkipping").value),
+ partitions = metricValueIfPartitioned("numTargetPartitionsAfterSkipping")),
+ sourceRowsInSecondScan =
+ metrics.get("numSourceRowsInSecondScan").map(_.value).filter(_ >= 0),
+
+ // Data change sizes
+ targetFilesAdded = metrics("numTargetFilesAdded").value,
+ targetChangeFilesAdded = metrics.get("numTargetChangeFilesAdded").map(_.value),
+ targetChangeFileBytes = metrics.get("numTargetChangeFileBytes").map(_.value),
+ targetFilesRemoved = metrics("numTargetFilesRemoved").value,
+ targetBytesAdded = Some(metrics("numTargetBytesAdded").value),
+ targetBytesRemoved = Some(metrics("numTargetBytesRemoved").value),
+ targetPartitionsRemovedFrom = metricValueIfPartitioned("numTargetPartitionsRemovedFrom"),
+ targetPartitionsAddedTo = metricValueIfPartitioned("numTargetPartitionsAddedTo"),
+ targetRowsCopied = metrics("numTargetRowsCopied").value,
+ targetRowsUpdated = metrics("numTargetRowsUpdated").value,
+ targetRowsInserted = metrics("numTargetRowsInserted").value,
+ targetRowsDeleted = metrics("numTargetRowsDeleted").value,
+
+ // Deprecated fields
+ updateConditionExpr = null,
+ updateExprs = null,
+ insertConditionExpr = null,
+ insertExprs = null,
+ deleteConditionExpr = null)
+ }
+}
+
+/**
+ * GPU version of Delta Lake's MergeIntoCommand.
+ *
+ * Performs a merge of a source query/table into a Delta table.
+ *
+ * Issues an error message when the ON search_condition of the MERGE statement can match
+ * a single row from the target table with multiple rows of the source table-reference.
+ *
+ * Algorithm:
+ *
+ * Phase 1: Find the input files in target that are touched by the rows that satisfy
+ * the condition and verify that no two source rows match with the same target row.
+ * This is implemented as an inner-join using the given condition. See [[findTouchedFiles]]
+ * for more details.
+ *
+ * Phase 2: Read the touched files again and write new files with updated and/or inserted rows.
+ *
+ * Phase 3: Use the Delta protocol to atomically remove the touched files and add the new files.
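+ *
+ * As an illustrative example (the names `spark`, `sourceDf`, the table name and the `id` column
+ * are hypothetical, not taken from these sources), a merge issued through the Delta Lake Scala
+ * API such as
+ * {{{
+ *   io.delta.tables.DeltaTable.forName(spark, "target").as("t")
+ *     .merge(sourceDf.as("s"), "t.id = s.id")
+ *     .whenMatched().updateAll()
+ *     .whenNotMatched().insertAll()
+ *     .execute()
+ * }}}
+ * carries one matched (update) clause and one not-matched (insert) clause of the kind this
+ * command processes.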
+ *
+ * @param source Source data to merge from
+ * @param target Target table to merge into
+ * @param gpuDeltaLog Delta log to use
+ * @param condition Condition for a source row to match with a target row
+ * @param matchedClauses All info related to matched clauses.
+ * @param notMatchedClauses All info related to not matched clauses.
+ * @param notMatchedBySourceClauses All info related to not matched by source clauses.
+ * @param migratedSchema The final schema of the target - may be changed by schema evolution.
+ */
+case class GpuMergeIntoCommand(
+ @transient source: LogicalPlan,
+ @transient target: LogicalPlan,
+ @transient gpuDeltaLog: GpuDeltaLog,
+ condition: Expression,
+ matchedClauses: Seq[DeltaMergeIntoMatchedClause],
+ notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause],
+ notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause],
+ migratedSchema: Option[StructType])(
+ @transient val rapidsConf: RapidsConf)
+ extends LeafRunnableCommand
+ with DeltaCommand with PredicateHelper with AnalysisHelper with ImplicitMetadataOperation {
+
+ import GpuMergeIntoCommand._
+
+ import SQLMetrics._
+ import com.databricks.sql.transaction.tahoe.commands.cdc.CDCReader._
+
+ override val otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf)
+
+ override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE)
+ override val canOverwriteSchema: Boolean = false
+
+ override val output: Seq[Attribute] = Seq(
+ AttributeReference("num_affected_rows", LongType)(),
+ AttributeReference("num_updated_rows", LongType)(),
+ AttributeReference("num_deleted_rows", LongType)(),
+ AttributeReference("num_inserted_rows", LongType)())
+
+ @transient private lazy val sc: SparkContext = SparkContext.getOrCreate()
+ @transient private lazy val targetDeltaLog: DeltaLog = gpuDeltaLog.deltaLog
+ /**
+ * Map to get target output attributes by name.
+ * The case sensitivity of the map is set accordingly to Spark configuration.
+ */
+ @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = {
+ val attrMap: Map[String, Attribute] = target
+ .outputSet.view
+ .map(attr => attr.name -> attr).toMap
+ if (conf.caseSensitiveAnalysis) {
+ attrMap
+ } else {
+ CaseInsensitiveMap(attrMap)
+ }
+ }
+
+ /** Whether this merge statement has only a single insert (NOT MATCHED) clause. */
+ private def isSingleInsertOnly: Boolean = matchedClauses.isEmpty && notMatchedClauses.length == 1
+ /** Whether this merge statement has only MATCHED clauses. */
+ private def isMatchedOnly: Boolean = notMatchedClauses.isEmpty && matchedClauses.nonEmpty
+
+ // We over-count numTargetRowsDeleted when there are multiple matches;
+ // this is the amount of the overcount, so we can subtract it to get a correct final metric.
+ private var multipleMatchDeleteOnlyOvercount: Option[Long] = None
+
+ override lazy val metrics = Map[String, SQLMetric](
+ "numSourceRows" -> createMetric(sc, "number of source rows"),
+ "numSourceRowsInSecondScan" ->
+ createMetric(sc, "number of source rows (during repeated scan)"),
+ "numTargetRowsCopied" -> createMetric(sc, "number of target rows rewritten unmodified"),
+ "numTargetRowsInserted" -> createMetric(sc, "number of inserted rows"),
+ "numTargetRowsUpdated" -> createMetric(sc, "number of updated rows"),
+ "numTargetRowsDeleted" -> createMetric(sc, "number of deleted rows"),
+ "numTargetFilesBeforeSkipping" -> createMetric(sc, "number of target files before skipping"),
+ "numTargetFilesAfterSkipping" -> createMetric(sc, "number of target files after skipping"),
+ "numTargetFilesRemoved" -> createMetric(sc, "number of files removed to target"),
+ "numTargetFilesAdded" -> createMetric(sc, "number of files added to target"),
+ "numTargetChangeFilesAdded" ->
+ createMetric(sc, "number of change data capture files generated"),
+ "numTargetChangeFileBytes" ->
+ createMetric(sc, "total size of change data capture files generated"),
+ "numTargetBytesBeforeSkipping" -> createMetric(sc, "number of target bytes before skipping"),
+ "numTargetBytesAfterSkipping" -> createMetric(sc, "number of target bytes after skipping"),
+ "numTargetBytesRemoved" -> createMetric(sc, "number of target bytes removed"),
+ "numTargetBytesAdded" -> createMetric(sc, "number of target bytes added"),
+ "numTargetPartitionsAfterSkipping" ->
+ createMetric(sc, "number of target partitions after skipping"),
+ "numTargetPartitionsRemovedFrom" ->
+ createMetric(sc, "number of target partitions from which files were removed"),
+ "numTargetPartitionsAddedTo" ->
+ createMetric(sc, "number of target partitions to which files were added"),
+ "executionTimeMs" ->
+ createMetric(sc, "time taken to execute the entire operation"),
+ "scanTimeMs" ->
+ createMetric(sc, "time taken to scan the files for matches"),
+ "rewriteTimeMs" ->
+ createMetric(sc, "time taken to rewrite the matched files"))
+
+ override def run(spark: SparkSession): Seq[Row] = {
+ recordDeltaOperation(targetDeltaLog, "delta.dml.merge") {
+ val startTime = System.nanoTime()
+ gpuDeltaLog.withNewTransaction { deltaTxn =>
+ if (target.schema.size != deltaTxn.metadata.schema.size) {
+ throw DeltaErrors.schemaChangedSinceAnalysis(
+ atAnalysis = target.schema, latestSchema = deltaTxn.metadata.schema)
+ }
+
+ if (canMergeSchema) {
+ updateMetadata(
+ spark, deltaTxn, migratedSchema.getOrElse(target.schema),
+ deltaTxn.metadata.partitionColumns, deltaTxn.metadata.configuration,
+ isOverwriteMode = false, rearrangeOnly = false)
+ }
+
+ val deltaActions = {
+ if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) {
+ writeInsertsOnlyWhenNoMatchedClauses(spark, deltaTxn)
+ } else {
+ val filesToRewrite = findTouchedFiles(spark, deltaTxn)
+ val newWrittenFiles = withStatusCode("DELTA", "Writing merged data") {
+ writeAllChanges(spark, deltaTxn, filesToRewrite)
+ }
+ filesToRewrite.map(_.remove) ++ newWrittenFiles
+ }
+ }
+
+ // Metrics should be recorded before commit (where they are written to delta logs).
+ metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000)
+ deltaTxn.registerSQLMetrics(spark, metrics)
+
+ // This is a best-effort sanity check.
+ if (metrics("numSourceRowsInSecondScan").value >= 0 &&
+ metrics("numSourceRows").value != metrics("numSourceRowsInSecondScan").value) {
+ log.warn(s"Merge source has ${metrics("numSourceRows").value} rows in initial scan but " +
+ s"${metrics("numSourceRowsInSecondScan").value} rows in second scan")
+ if (conf.getConf(DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED)) {
+ throw DeltaErrors.sourceNotDeterministicInMergeException(spark)
+ }
+ }
+
+ deltaTxn.commit(
+ deltaActions,
+ DeltaOperations.Merge(
+ Option(condition),
+ matchedClauses.map(DeltaOperations.MergePredicate(_)),
+ notMatchedClauses.map(DeltaOperations.MergePredicate(_)),
+ // We do not support notMatchedBySourcePredicates yet and fall back to CPU
+ // See https://github.com/NVIDIA/spark-rapids/issues/8415
+ notMatchedBySourcePredicates = Seq.empty[MergePredicate]
+ ))
+
+ // Record metrics
+ val stats = GpuMergeStats.fromMergeSQLMetrics(
+ metrics, condition, matchedClauses, notMatchedClauses,
+ deltaTxn.metadata.partitionColumns.nonEmpty)
+ recordDeltaEvent(targetDeltaLog, "delta.dml.merge.stats", data = stats)
+
+ }
+ spark.sharedState.cacheManager.recacheByPlan(spark, target)
+ }
+ // This is needed to make the SQL metrics visible in the Spark UI. Also this needs
+ // to be outside the recordMergeOperation because this method will update some metric.
+ val executionId = spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(spark.sparkContext, executionId, metrics.values.toSeq)
+ Seq(Row(metrics("numTargetRowsUpdated").value + metrics("numTargetRowsDeleted").value +
+ metrics("numTargetRowsInserted").value, metrics("numTargetRowsUpdated").value,
+ metrics("numTargetRowsDeleted").value, metrics("numTargetRowsInserted").value))
+ }
+
+ /**
+ * Find the target table files that contain the rows that satisfy the merge condition. This is
+ * implemented as an inner-join between the source query/table and the target table using
+ * the merge condition.
+ */
+ private def findTouchedFiles(
+ spark: SparkSession,
+ deltaTxn: OptimisticTransaction
+ ): Seq[AddFile] = recordMergeOperation(sqlMetricName = "scanTimeMs") {
+
+ // Accumulator to collect all the distinct touched files
+ val touchedFilesAccum = new SetAccumulator[String]()
+ spark.sparkContext.register(touchedFilesAccum, TOUCHED_FILES_ACCUM_NAME)
+
+ // UDF to record the touched file names and add them to the accumulator
+ val recordTouchedFileName = udf(new GpuDeltaRecordTouchedFileNameUDF(touchedFilesAccum))
+ .asNondeterministic()
+
+ // Skip data based on the merge condition
+ val targetOnlyPredicates =
+ splitConjunctivePredicates(condition).filter(_.references.subsetOf(target.outputSet))
+ val dataSkippedFiles = deltaTxn.filterFiles(targetOnlyPredicates)
+
+ // UDF to increment metrics
+ val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows")
+ val sourceDF = Dataset.ofRows(spark, source)
+ .filter(new Column(incrSourceRowCountExpr))
+
+ // Apply an inner join between source and target using the merge condition to find matches.
+ // In addition, we attach two columns
+ // - a monotonically increasing row id for target rows to later identify whether the same
+ // target row is matched by multiple source rows or not
+ // - the target file name the row is from to later identify the files touched by matched rows
+ val targetDF = Dataset.ofRows(spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles))
+ .withColumn(ROW_ID_COL, monotonically_increasing_id())
+ .withColumn(FILE_NAME_COL, input_file_name())
+ val joinToFindTouchedFiles = sourceDF.join(targetDF, new Column(condition), "inner")
+
+ // Process the matches from the inner join to record touched files and find multiple matches
+ val collectTouchedFiles = joinToFindTouchedFiles
+ .select(col(ROW_ID_COL), recordTouchedFileName(col(FILE_NAME_COL)).as("one"))
+
+ // Calculate frequency of matches per source row
+ val matchedRowCounts = collectTouchedFiles.groupBy(ROW_ID_COL).agg(sum("one").as("count"))
+
+ // Get multiple matches and simultaneously collect (using touchedFilesAccum) the file names
+ // multipleMatchCount = # of target rows with more than 1 matching source row (duplicate match)
+ // multipleMatchSum = total # of duplicate matched rows
+ import spark.implicits._
+ val (multipleMatchCount, multipleMatchSum) = matchedRowCounts
+ .filter("count > 1")
+ .select(coalesce(count("*"), lit(0)), coalesce(sum("count"), lit(0)))
+ .as[(Long, Long)]
+ .collect()
+ .head
+
+ val hasMultipleMatches = multipleMatchCount > 0
+
+ // Throw error if multiple matches are ambiguous or cannot be computed correctly.
+ val canBeComputedUnambiguously = {
+ // Multiple matches are not ambiguous when there is only one unconditional delete as
+ // all the matched row pairs in the 2nd join in `writeAllChanges` will get deleted.
+ val isUnconditionalDelete = matchedClauses.headOption match {
+ case Some(DeltaMergeIntoMatchedDeleteClause(None)) => true
+ case _ => false
+ }
+ matchedClauses.size == 1 && isUnconditionalDelete
+ }
+
+ if (hasMultipleMatches && !canBeComputedUnambiguously) {
+ throw DeltaErrors.multipleSourceRowMatchingTargetRowInMergeException(spark)
+ }
+
+ if (hasMultipleMatches) {
+ // This is only allowed for delete-only queries.
+ // This query will count the duplicates for numTargetRowsDeleted in Job 2,
+ // because we count matches after the join and not just the target rows.
+ // We have to compensate for this by subtracting the duplicates later,
+ // so we need to record them here.
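+ // A target row with k matching source rows increments numTargetRowsDeleted k times but is
+ // actually deleted only once, so the overcount is sum(k) minus the number of such target rows.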
+ val duplicateCount = multipleMatchSum - multipleMatchCount
+ multipleMatchDeleteOnlyOvercount = Some(duplicateCount)
+ }
+
+ // Get the AddFiles using the touched file names.
+ val touchedFileNames = touchedFilesAccum.value.iterator().asScala.toSeq
+ logTrace(s"findTouchedFiles: matched files:\n\t${touchedFileNames.mkString("\n\t")}")
+
+ val nameToAddFileMap = generateCandidateFileMap(targetDeltaLog.dataPath, dataSkippedFiles)
+ val touchedAddFiles = touchedFileNames.map(f =>
+ getTouchedFile(targetDeltaLog.dataPath, f, nameToAddFileMap))
+
+ // When the target table is empty and the optimizer has optimized away the join entirely,
+ // numSourceRows will incorrectly be 0. We need to scan the source table once to get the correct
+ // metric here.
+ if (metrics("numSourceRows").value == 0 &&
+ (dataSkippedFiles.isEmpty || targetDF.take(1).isEmpty)) {
+ val numSourceRows = sourceDF.count()
+ metrics("numSourceRows").set(numSourceRows)
+ }
+
+ // Update metrics
+ metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles
+ metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes
+ val (afterSkippingBytes, afterSkippingPartitions) =
+ totalBytesAndDistinctPartitionValues(dataSkippedFiles)
+ metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size
+ metrics("numTargetBytesAfterSkipping") += afterSkippingBytes
+ metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions
+ val (removedBytes, removedPartitions) = totalBytesAndDistinctPartitionValues(touchedAddFiles)
+ metrics("numTargetFilesRemoved") += touchedAddFiles.size
+ metrics("numTargetBytesRemoved") += removedBytes
+ metrics("numTargetPartitionsRemovedFrom") += removedPartitions
+ touchedAddFiles
+ }
+
+ /**
+ * This is an optimization for the case when the merge has no matched clauses (insert-only).
+ * We perform a left anti join on the source data to find the rows to be inserted.
+ *
+ * This will currently only optimize for the case when there is a _single_ notMatchedClause.
+ */
+ private def writeInsertsOnlyWhenNoMatchedClauses(
+ spark: SparkSession,
+ deltaTxn: OptimisticTransaction
+ ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") {
+
+ // UDFs to update metrics
+ val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows")
+ val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted")
+
+ val outputColNames = getTargetOutputCols(deltaTxn).map(_.name)
+ // we use head here since we know there is only a single notMatchedClause
+ val outputExprs = notMatchedClauses.head.resolvedActions.map(_.expr)
+ val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) =>
+ new Column(Alias(expr, name)())
+ }
+
+ // source DataFrame
+ val sourceDF = Dataset.ofRows(spark, source)
+ .filter(new Column(incrSourceRowCountExpr))
+ .filter(new Column(notMatchedClauses.head.condition.getOrElse(Literal.TrueLiteral)))
+
+ // Skip data based on the merge condition
+ val conjunctivePredicates = splitConjunctivePredicates(condition)
+ val targetOnlyPredicates =
+ conjunctivePredicates.filter(_.references.subsetOf(target.outputSet))
+ val dataSkippedFiles = deltaTxn.filterFiles(targetOnlyPredicates)
+
+ // target DataFrame
+ val targetDF = Dataset.ofRows(
+ spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles))
+
+ val insertDf = sourceDF.join(targetDF, new Column(condition), "leftanti")
+ .select(outputCols: _*)
+ .filter(new Column(incrInsertedCountExpr))
+
+ val newFiles = deltaTxn
+ .writeFiles(repartitionIfNeeded(spark, insertDf, deltaTxn.metadata.partitionColumns))
+
+ // Update metrics
+ metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles
+ metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes
+ val (afterSkippingBytes, afterSkippingPartitions) =
+ totalBytesAndDistinctPartitionValues(dataSkippedFiles)
+ metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size
+ metrics("numTargetBytesAfterSkipping") += afterSkippingBytes
+ metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions
+ metrics("numTargetFilesRemoved") += 0
+ metrics("numTargetBytesRemoved") += 0
+ metrics("numTargetPartitionsRemovedFrom") += 0
+ val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles)
+ metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile])
+ metrics("numTargetBytesAdded") += addedBytes
+ metrics("numTargetPartitionsAddedTo") += addedPartitions
+ newFiles
+ }
+
+ /**
+ * Write new files by reading the touched files and updating/inserting data using the source
+ * query/table. This is implemented using a full|right-outer-join using the merge condition.
+ *
+ * Note that unlike the insert-only code paths with just one control column INCR_ROW_COUNT_COL,
+ * this method has two additional control columns ROW_DROPPED_COL for dropping deleted rows and
+ * CDC_TYPE_COL_NAME used for handling CDC when enabled.
+ */
+ private def writeAllChanges(
+ spark: SparkSession,
+ deltaTxn: OptimisticTransaction,
+ filesToRewrite: Seq[AddFile]
+ ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") {
+ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+
+ val cdcEnabled = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(deltaTxn.metadata)
+
+ var targetOutputCols = getTargetOutputCols(deltaTxn)
+ var outputRowSchema = deltaTxn.metadata.schema
+
+ // When we have duplicate matches (only allowed when the whenMatchedCondition is a delete with
+ // no match condition) we will incorrectly generate duplicate CDC rows.
+ // Duplicate matches can be due to:
+ // - Duplicate rows in the source w.r.t. the merge condition
+ // - A target-only or source-only merge condition, which essentially turns our join into a cross
+ // join with the target/source satisfying the merge condition.
+ // These duplicate matches are dropped from the main data output since this is a delete
+ // operation, but the duplicate CDC rows are not removed by default.
+ // See https://github.com/delta-io/delta/issues/1274
+
+ // We address this specific scenario by adding row ids to the target before performing our join.
+ // There should only be one CDC delete row per target row so we can use these row ids to dedupe
+ // the duplicate CDC delete rows.
+
+ // We also need to address the scenario when there are duplicate matches with delete and we
+ // insert duplicate rows. Here we need to additionally add row ids to the source before the
+ // join to avoid dropping these valid duplicate inserted rows and their corresponding cdc rows.
+
+ // When there is an insert clause, we set SOURCE_ROW_ID_COL=null for all delete rows because we
+ // need to drop the duplicate matches.
+ val isDeleteWithDuplicateMatchesAndCdc = multipleMatchDeleteOnlyOvercount.nonEmpty && cdcEnabled
+
+ // Generate a new logical plan that has same output attributes exprIds as the target plan.
+ // This allows us to apply the existing resolved update/insert expressions.
+ val newTarget = buildTargetPlanWithFiles(deltaTxn, filesToRewrite)
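+ // A right outer join is sufficient when every clause is a MATCHED clause: it keeps all target
+ // rows (needed for the copy/update/delete outputs) and drops unmatched source rows, which only
+ // an insert (NOT MATCHED) clause would need.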
+ val joinType = if (isMatchedOnly &&
+ spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED)) {
+ "rightOuter"
+ } else {
+ "fullOuter"
+ }
+
+ logDebug(s"""writeAllChanges using $joinType join:
+ | source.output: ${source.outputSet}
+ | target.output: ${target.outputSet}
+ | condition: $condition
+ | newTarget.output: ${newTarget.outputSet}
+ """.stripMargin)
+
+ // UDFs to update metrics
+ // Make UDFs that appear in the custom join processor node deterministic, as they always
+ // return true and update a metric. Catalyst only allows non-deterministic UDFs inside a very
+ // specific set of nodes (Project, Filter, Window, Aggregate), and the custom join processor
+ // node is not one of them.
+ val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRowsInSecondScan")
+ val incrUpdatedCountExpr = makeMetricUpdateUDF("numTargetRowsUpdated", deterministic = true)
+ val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted", deterministic = true)
+ val incrNoopCountExpr = makeMetricUpdateUDF("numTargetRowsCopied", deterministic = true)
+ val incrDeletedCountExpr = makeMetricUpdateUDF("numTargetRowsDeleted", deterministic = true)
+
+ // Apply an outer join to find both matches and non-matches. We are adding two boolean fields
+ // with value `true`, one to each side of the join. Whether this field is null or not after
+ // the outer join will allow us to identify whether the resultant joined row was a
+ // matched inner result or an unmatched result with null on one side.
+ // We add row IDs to the targetDF if we have a delete-when-matched clause with duplicate
+ // matches and CDC is enabled, and additionally add row IDs to the source if we also have an
+ // insert clause. See above at isDeleteWithDuplicateMatchesAndCdc definition for more details.
+ var sourceDF = Dataset.ofRows(spark, source)
+ .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr))
+ var targetDF = Dataset.ofRows(spark, newTarget)
+ .withColumn(TARGET_ROW_PRESENT_COL, lit(true))
+ if (isDeleteWithDuplicateMatchesAndCdc) {
+ targetDF = targetDF.withColumn(TARGET_ROW_ID_COL, monotonically_increasing_id())
+ if (notMatchedClauses.nonEmpty) { // insert clause
+ sourceDF = sourceDF.withColumn(SOURCE_ROW_ID_COL, monotonically_increasing_id())
+ }
+ }
+ val joinedDF = sourceDF.join(targetDF, new Column(condition), joinType)
+ val joinedPlan = joinedDF.queryExecution.analyzed
+
+ def resolveOnJoinedPlan(exprs: Seq[Expression]): Seq[Expression] = {
+ tryResolveReferencesForExpressions(spark, exprs, joinedPlan)
+ }
+
+ // ==== Generate the expressions to process full-outer join output and generate target rows ====
+ // If there are N columns in the target table, there will be N + 3 columns after processing
+ // - N columns for target table
+ // - ROW_DROPPED_COL to define whether the generated row should be dropped or written
+ // - INCR_ROW_COUNT_COL containing a UDF to update the output row counter
+ // - CDC_TYPE_COLUMN_NAME containing the type of change being performed in a particular row
+
+ // To generate these N + 3 columns, we will generate N + 3 expressions and apply them to the
+ // rows in the joinedDF. The CDC column will be either used for CDC generation or dropped before
+ // performing the final write, and the other two will always be dropped after executing the
+ // metrics UDF and filtering on ROW_DROPPED_COL.
+
+ // We produce rows for both the main table data (with CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC),
+ // and rows for the CDC data which will be output to CDCReader.CDC_LOCATION.
+ // See [[CDCReader]] for general details on how partitioning on the CDC type column works.
+
+ // In the following two functions `matchedClauseOutput` and `notMatchedClauseOutput`, we
+ // produce a Seq[Expression] for each intended output row.
+ // Depending on the clause and whether CDC is enabled, we output between 0 and 3 rows, as a
+ // Seq[Seq[Expression]]
+
+ // There is one corner case outlined above at isDeleteWithDuplicateMatchesAndCdc definition.
+ // When we have a delete-ONLY merge with duplicate matches we have N + 4 columns:
+ // N target cols, TARGET_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL, CDC_TYPE_COLUMN_NAME
+ // When we have a delete-when-matched merge with duplicate matches + an insert clause, we have
+ // N + 5 columns:
+ // N target cols, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL,
+ // CDC_TYPE_COLUMN_NAME
+ // These ROW_ID_COL will always be dropped before the final write.
+
+ if (isDeleteWithDuplicateMatchesAndCdc) {
+ targetOutputCols = targetOutputCols :+ UnresolvedAttribute(TARGET_ROW_ID_COL)
+ outputRowSchema = outputRowSchema.add(TARGET_ROW_ID_COL, DataTypes.LongType)
+ if (notMatchedClauses.nonEmpty) { // there is an insert clause, make SRC_ROW_ID_COL=null
+ targetOutputCols = targetOutputCols :+ Alias(Literal(null), SOURCE_ROW_ID_COL)()
+ outputRowSchema = outputRowSchema.add(SOURCE_ROW_ID_COL, DataTypes.LongType)
+ }
+ }
+
+ if (cdcEnabled) {
+ outputRowSchema = outputRowSchema
+ .add(ROW_DROPPED_COL, DataTypes.BooleanType)
+ .add(INCR_ROW_COUNT_COL, DataTypes.BooleanType)
+ .add(CDC_TYPE_COLUMN_NAME, DataTypes.StringType)
+ }
+
+ def matchedClauseOutput(clause: DeltaMergeIntoMatchedClause): Seq[Seq[Expression]] = {
+ val exprs = clause match {
+ case u: DeltaMergeIntoMatchedUpdateClause =>
+ // Generate update expressions and set ROW_DELETED_COL = false and
+ // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC
+ val mainDataOutput = u.resolvedActions.map(_.expr) :+ FalseLiteral :+
+ incrUpdatedCountExpr :+ CDC_TYPE_NOT_CDC_LITERAL
+ if (cdcEnabled) {
+ // For update preimage, we do a no-op copy with ROW_DELETED_COL = false and
+ // CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_PREIMAGE and INCR_ROW_COUNT_COL as a no-op
+ // (because the metric will be incremented in `mainDataOutput`)
+ val preImageOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+
+ Literal(CDC_TYPE_UPDATE_PREIMAGE)
+ // For update postimage, we have the same expressions as for mainDataOutput but with
+ // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in
+ // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_POSTIMAGE
+ val postImageOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+
+ Literal(CDC_TYPE_UPDATE_POSTIMAGE)
+ Seq(mainDataOutput, preImageOutput, postImageOutput)
+ } else {
+ Seq(mainDataOutput)
+ }
+ case _: DeltaMergeIntoMatchedDeleteClause =>
+ // Generate expressions to set the ROW_DELETED_COL = true and CDC_TYPE_COLUMN_NAME =
+ // CDC_TYPE_NOT_CDC
+ val mainDataOutput = targetOutputCols :+ TrueLiteral :+ incrDeletedCountExpr :+
+ CDC_TYPE_NOT_CDC_LITERAL
+ if (cdcEnabled) {
+ // For delete we do a no-op copy with ROW_DELETED_COL = false, INCR_ROW_COUNT_COL as a
+ // no-op (because the metric will be incremented in `mainDataOutput`) and
+ // CDC_TYPE_COLUMN_NAME = CDC_TYPE_DELETE
+ val deleteCdcOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+
+ Literal(CDC_TYPE_DELETE)
+ Seq(mainDataOutput, deleteCdcOutput)
+ } else {
+ Seq(mainDataOutput)
+ }
+ }
+ exprs.map(resolveOnJoinedPlan)
+ }
+
+ def notMatchedClauseOutput(clause: DeltaMergeIntoNotMatchedClause): Seq[Seq[Expression]] = {
+ // Generate insert expressions and set ROW_DELETED_COL = false and
+ // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC
+ val insertExprs = clause.resolvedActions.map(_.expr)
+ val mainDataOutput = resolveOnJoinedPlan(
+ if (isDeleteWithDuplicateMatchesAndCdc) {
+ // Must be delete-when-matched merge with duplicate matches + insert clause
+ // Therefore we must keep the target row id and source row id. Since this is a not-matched
+ // clause we know the target row-id will be null. See above at
+ // isDeleteWithDuplicateMatchesAndCdc definition for more details.
+ insertExprs :+
+ Alias(Literal(null), TARGET_ROW_ID_COL)() :+ UnresolvedAttribute(SOURCE_ROW_ID_COL) :+
+ FalseLiteral :+ incrInsertedCountExpr :+ CDC_TYPE_NOT_CDC_LITERAL
+ } else {
+ insertExprs :+ FalseLiteral :+ incrInsertedCountExpr :+ CDC_TYPE_NOT_CDC_LITERAL
+ }
+ )
+ if (cdcEnabled) {
+ // For insert we have the same expressions as for mainDataOutput, but with
+ // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in
+ // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_INSERT
+ val insertCdcOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+ Literal(CDC_TYPE_INSERT)
+ Seq(mainDataOutput, insertCdcOutput)
+ } else {
+ Seq(mainDataOutput)
+ }
+ }
+
+ def clauseCondition(clause: DeltaMergeIntoClause): Expression = {
+ // if condition is None, then expression always evaluates to true
+ val condExpr = clause.condition.getOrElse(TrueLiteral)
+ resolveOnJoinedPlan(Seq(condExpr)).head
+ }
+
+ val targetRowHasNoMatch = resolveOnJoinedPlan(Seq(col(SOURCE_ROW_PRESENT_COL).isNull.expr)).head
+ val sourceRowHasNoMatch = resolveOnJoinedPlan(Seq(col(TARGET_ROW_PRESENT_COL).isNull.expr)).head
+ val matchedConditions = matchedClauses.map(clauseCondition)
+ val matchedOutputs = matchedClauses.map(matchedClauseOutput)
+ val notMatchedConditions = notMatchedClauses.map(clauseCondition)
+ val notMatchedOutputs = notMatchedClauses.map(notMatchedClauseOutput)
+ // TODO support notMatchedBySourceClauses which is new in DBR 12.2
+ // https://github.com/NVIDIA/spark-rapids/issues/8415
+ val notMatchedBySourceConditions = Seq.empty
+ val notMatchedBySourceOutputs = Seq.empty
+ val noopCopyOutput =
+ resolveOnJoinedPlan(targetOutputCols :+ FalseLiteral :+ incrNoopCountExpr :+
+ CDC_TYPE_NOT_CDC_LITERAL)
+ val deleteRowOutput =
+ resolveOnJoinedPlan(targetOutputCols :+ TrueLiteral :+ TrueLiteral :+
+ CDC_TYPE_NOT_CDC_LITERAL)
+ var outputDF = addMergeJoinProcessor(spark, joinedPlan, outputRowSchema,
+ targetRowHasNoMatch = targetRowHasNoMatch,
+ sourceRowHasNoMatch = sourceRowHasNoMatch,
+ matchedConditions = matchedConditions,
+ matchedOutputs = matchedOutputs,
+ notMatchedConditions = notMatchedConditions,
+ notMatchedOutputs = notMatchedOutputs,
+ notMatchedBySourceConditions = notMatchedBySourceConditions,
+ notMatchedBySourceOutputs = notMatchedBySourceOutputs,
+ noopCopyOutput = noopCopyOutput,
+ deleteRowOutput = deleteRowOutput)
+
+ if (isDeleteWithDuplicateMatchesAndCdc) {
+ // When we have a delete when matched clause with duplicate matches we have to remove
+ // duplicate CDC rows. This scenario is further explained at
+ // isDeleteWithDuplicateMatchesAndCdc definition.
+
+ // To remove duplicate CDC rows generated by the duplicate matches we dedupe by
+ // TARGET_ROW_ID_COL since there should only be one CDC delete row per target row.
+ // When there is an insert clause in addition to the delete clause we additionally dedupe by
+ // SOURCE_ROW_ID_COL and CDC_TYPE_COLUMN_NAME to avoid dropping valid duplicate inserted rows
+ // and their corresponding CDC rows.
+ val columnsToDedupeBy = if (notMatchedClauses.nonEmpty) { // insert clause
+ Seq(TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, CDC_TYPE_COLUMN_NAME)
+ } else {
+ Seq(TARGET_ROW_ID_COL)
+ }
+ outputDF = outputDF
+ .dropDuplicates(columnsToDedupeBy)
+ .drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL)
+ } else {
+ outputDF = outputDF.drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL)
+ }
+
+ logDebug("writeAllChanges: join output plan:\n" + outputDF.queryExecution)
+
+ // Write to Delta
+ val newFiles = deltaTxn
+ .writeFiles(repartitionIfNeeded(spark, outputDF, deltaTxn.metadata.partitionColumns))
+
+ // Update metrics
+ val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles)
+ metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile])
+ metrics("numTargetChangeFilesAdded") += newFiles.count(_.isInstanceOf[AddCDCFile])
+ metrics("numTargetChangeFileBytes") += newFiles.collect{ case f: AddCDCFile => f.size }.sum
+ metrics("numTargetBytesAdded") += addedBytes
+ metrics("numTargetPartitionsAddedTo") += addedPartitions
+ if (multipleMatchDeleteOnlyOvercount.isDefined) {
+ // Compensate for counting duplicates during the query.
+ val actualRowsDeleted =
+ metrics("numTargetRowsDeleted").value - multipleMatchDeleteOnlyOvercount.get
+ assert(actualRowsDeleted >= 0)
+ metrics("numTargetRowsDeleted").set(actualRowsDeleted)
+ }
+
+ newFiles
+ }
+
+ private def addMergeJoinProcessor(
+ spark: SparkSession,
+ joinedPlan: LogicalPlan,
+ outputRowSchema: StructType,
+ targetRowHasNoMatch: Expression,
+ sourceRowHasNoMatch: Expression,
+ matchedConditions: Seq[Expression],
+ matchedOutputs: Seq[Seq[Seq[Expression]]],
+ notMatchedConditions: Seq[Expression],
+ notMatchedOutputs: Seq[Seq[Seq[Expression]]],
+ notMatchedBySourceConditions: Seq[Expression],
+ notMatchedBySourceOutputs: Seq[Seq[Seq[Expression]]],
+ noopCopyOutput: Seq[Expression],
+ deleteRowOutput: Seq[Expression]): Dataset[Row] = {
+ def wrap(e: Expression): BaseExprMeta[Expression] = {
+ GpuOverrides.wrapExpr(e, rapidsConf, None)
+ }
+
+ val targetRowHasNoMatchMeta = wrap(targetRowHasNoMatch)
+ val sourceRowHasNoMatchMeta = wrap(sourceRowHasNoMatch)
+ val matchedConditionsMetas = matchedConditions.map(wrap)
+ val matchedOutputsMetas = matchedOutputs.map(_.map(_.map(wrap)))
+ val notMatchedConditionsMetas = notMatchedConditions.map(wrap)
+ val notMatchedOutputsMetas = notMatchedOutputs.map(_.map(_.map(wrap)))
+ val notMatchedBySourceConditionsMetas = notMatchedBySourceConditions.map(wrap)
+ val notMatchedBySourceOutputsMetas = notMatchedBySourceOutputs.map(_.map(_.map(wrap)))
+ val noopCopyOutputMetas = noopCopyOutput.map(wrap)
+ val deleteRowOutputMetas = deleteRowOutput.map(wrap)
+ val allMetas = Seq(targetRowHasNoMatchMeta, sourceRowHasNoMatchMeta) ++
+ matchedConditionsMetas ++ matchedOutputsMetas.flatten.flatten ++
+ notMatchedConditionsMetas ++ notMatchedOutputsMetas.flatten.flatten ++
+ notMatchedBySourceConditionsMetas ++ notMatchedBySourceOutputsMetas.flatten.flatten ++
+ noopCopyOutputMetas ++ deleteRowOutputMetas
+ allMetas.foreach(_.tagForGpu())
+ val canReplace = allMetas.forall(_.canExprTreeBeReplaced) && rapidsConf.isOperatorEnabled(
+ "spark.rapids.sql.exec.RapidsProcessDeltaMergeJoinExec", false, false)
+ if (rapidsConf.shouldExplainAll || (rapidsConf.shouldExplain && !canReplace)) {
+ val exprExplains = allMetas.map(_.explain(rapidsConf.shouldExplainAll))
+ val execWorkInfo = if (canReplace) {
+ "will run on GPU"
+ } else {
+ "cannot run on GPU because not all merge processing expressions can be replaced"
+ }
+ logWarning(s" $execWorkInfo:\n" +
+ s" ${exprExplains.mkString(" ")}")
+ }
+
+ if (canReplace) {
+ val processedJoinPlan = RapidsProcessDeltaMergeJoin(
+ joinedPlan,
+ outputRowSchema.toAttributes,
+ targetRowHasNoMatch = targetRowHasNoMatch,
+ sourceRowHasNoMatch = sourceRowHasNoMatch,
+ matchedConditions = matchedConditions,
+ matchedOutputs = matchedOutputs,
+ notMatchedConditions = notMatchedConditions,
+ notMatchedOutputs = notMatchedOutputs,
+ notMatchedBySourceConditions = notMatchedBySourceConditions,
+ notMatchedBySourceOutputs = notMatchedBySourceOutputs,
+ noopCopyOutput = noopCopyOutput,
+ deleteRowOutput = deleteRowOutput)
+ Dataset.ofRows(spark, processedJoinPlan)
+ } else {
+ val joinedRowEncoder = RowEncoder(joinedPlan.schema)
+ val outputRowEncoder = RowEncoder(outputRowSchema).resolveAndBind()
+
+ val processor = new JoinedRowProcessor(
+ targetRowHasNoMatch = targetRowHasNoMatch,
+ sourceRowHasNoMatch = sourceRowHasNoMatch,
+ matchedConditions = matchedConditions,
+ matchedOutputs = matchedOutputs,
+ notMatchedConditions = notMatchedConditions,
+ notMatchedOutputs = notMatchedOutputs,
+ noopCopyOutput = noopCopyOutput,
+ deleteRowOutput = deleteRowOutput,
+ joinedAttributes = joinedPlan.output,
+ joinedRowEncoder = joinedRowEncoder,
+ outputRowEncoder = outputRowEncoder)
+
+ Dataset.ofRows(spark, joinedPlan).mapPartitions(processor.processPartition)(outputRowEncoder)
+ }
+ }
+
+ /**
+ * Build a new logical plan using the given `files` that has the same output columns (exprIds)
+ * as the `target` logical plan, so that existing update/insert expressions can be applied
+ * on this new plan.
+ */
+ private def buildTargetPlanWithFiles(
+ deltaTxn: OptimisticTransaction,
+ files: Seq[AddFile]): LogicalPlan = {
+ val targetOutputCols = getTargetOutputCols(deltaTxn)
+ val targetOutputColsMap = {
+ val colsMap: Map[String, NamedExpression] = targetOutputCols.view
+ .map(col => col.name -> col).toMap
+ if (conf.caseSensitiveAnalysis) {
+ colsMap
+ } else {
+ CaseInsensitiveMap(colsMap)
+ }
+ }
+
+ val plan = {
+ // We have to do surgery to use the attributes from `targetOutputCols` to scan the table.
+ // In cases of schema evolution, they may not be the same type as the original attributes.
+ val original =
+ deltaTxn.deltaLog.createDataFrame(deltaTxn.snapshot, files).queryExecution.analyzed
+ val transformed = original.transform {
+ case LogicalRelation(base, _, catalogTbl, isStreaming) =>
+ LogicalRelation(
+ base,
+ // We can ignore the new columns which aren't yet AttributeReferences.
+ targetOutputCols.collect { case a: AttributeReference => a },
+ catalogTbl,
+ isStreaming)
+ }
+
+ // In case of schema evolution & column mapping, we would also need to rebuild the file format
+ // because under column mapping, the reference schema within DeltaParquetFileFormat
+ // that is used to populate metadata needs to be updated
+ if (deltaTxn.metadata.columnMappingMode != NoMapping) {
+ val updatedFileFormat = deltaTxn.deltaLog.fileFormat(
+ deltaTxn.deltaLog.unsafeVolatileSnapshot.protocol, deltaTxn.metadata)
+ DeltaTableUtils.replaceFileFormat(transformed, updatedFileFormat)
+ } else {
+ transformed
+ }
+ }
+
+ // For each plan output column, find the corresponding target output column (by name) and
+ // create an alias
+ val aliases = plan.output.map {
+ case newAttrib: AttributeReference =>
+ val existingTargetAttrib = targetOutputColsMap.get(newAttrib.name)
+ .getOrElse {
+ throw new AnalysisException(
+ s"Could not find ${newAttrib.name} among the existing target output " +
+ targetOutputCols.mkString(","))
+ }.asInstanceOf[AttributeReference]
+
+ if (existingTargetAttrib.exprId == newAttrib.exprId) {
+ // It's not valid to alias an expression to its own exprId (this is considered a
+ // non-unique exprId by the analyzer), so we just use the attribute directly.
+ newAttrib
+ } else {
+ Alias(newAttrib, existingTargetAttrib.name)(exprId = existingTargetAttrib.exprId)
+ }
+ }
+
+ Project(aliases, plan)
+ }
+
+ /** Expressions to increment SQL metrics */
+ private def makeMetricUpdateUDF(name: String, deterministic: Boolean = false): Expression = {
+ // only capture the needed metric in a local variable
+ val metric = metrics(name)
+ var u = udf(new GpuDeltaMetricUpdateUDF(metric))
+ if (!deterministic) {
+ u = u.asNondeterministic()
+ }
+ u.apply().expr
+ }
+
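+  /**
+   * Map the transaction's table schema to the target plan's output attributes, reusing the
+   * existing exprIds where a matching attribute exists and substituting a null literal alias
+   * for columns that are missing from the target output.
+   */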
+ private def getTargetOutputCols(txn: OptimisticTransaction): Seq[NamedExpression] = {
+ txn.metadata.schema.map { col =>
+ targetOutputAttributesMap
+ .get(col.name)
+ .map { a =>
+ AttributeReference(col.name, col.dataType, col.nullable)(a.exprId)
+ }
+ .getOrElse(Alias(Literal(null), col.name)()
+ )
+ }
+ }
+
+ /**
+ * Repartitions the output DataFrame by the partition columns if table is partitioned
+ * and `merge.repartitionBeforeWrite.enabled` is set to true.
+ */
+ protected def repartitionIfNeeded(
+ spark: SparkSession,
+ df: DataFrame,
+ partitionColumns: Seq[String]): DataFrame = {
+ if (partitionColumns.nonEmpty && spark.conf.get(DeltaSQLConf.MERGE_REPARTITION_BEFORE_WRITE)) {
+ df.repartition(partitionColumns.map(col): _*)
+ } else {
+ df
+ }
+ }
+
+ /**
+ * Execute the given `thunk` and return its result while recording the time taken to do it.
+ *
+ * @param sqlMetricName name of SQL metric to update with the time taken by the thunk
+ * @param thunk the code to execute
+ */
+ private def recordMergeOperation[A](sqlMetricName: String)(thunk: => A): A = {
+ val startTimeNs = System.nanoTime()
+ val r = thunk
+ val timeTakenMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)
+ if (sqlMetricName != null && timeTakenMs > 0) {
+ metrics(sqlMetricName) += timeTakenMs
+ }
+ r
+ }
+}
+
+object GpuMergeIntoCommand {
+ /**
+ * Spark UI will track all normal accumulators along with Spark tasks to show them on Web UI.
+ * However, the accumulator used by `MergeIntoCommand` can store a very large value since it
+ * tracks all files that need to be rewritten. We should ask Spark UI to not remember it,
+ * otherwise, the UI data may consume lots of memory. Hence, we use the prefix `internal.metrics.`
+ * to make this accumulator become an internal accumulator, so that it will not be tracked by
+ * Spark UI.
+ */
+ val TOUCHED_FILES_ACCUM_NAME = "internal.metrics.MergeIntoDelta.touchedFiles"
+
+ val ROW_ID_COL = "_row_id_"
+ val TARGET_ROW_ID_COL = "_target_row_id_"
+ val SOURCE_ROW_ID_COL = "_source_row_id_"
+ val FILE_NAME_COL = "_file_name_"
+ val SOURCE_ROW_PRESENT_COL = "_source_row_present_"
+ val TARGET_ROW_PRESENT_COL = "_target_row_present_"
+ val ROW_DROPPED_COL = GpuDeltaMergeConstants.ROW_DROPPED_COL
+ val INCR_ROW_COUNT_COL = "_incr_row_count_"
+
+ // Some Delta versions use Literal(null) which translates to a literal of NullType instead
+ // of the Literal(null, StringType) which is needed, so using a fixed version here
+ // rather than the version from Delta Lake.
+ val CDC_TYPE_NOT_CDC_LITERAL = Literal(null, StringType)
+
+ /**
+ * @param targetRowHasNoMatch whether a joined row is a target row with no match in the source
+ * table
+ * @param sourceRowHasNoMatch whether a joined row is a source row with no match in the target
+ * table
+ * @param matchedConditions condition for each match clause
+ * @param matchedOutputs corresponding output for each match clause. for each clause, we
+ * have 1-3 output rows, each of which is a sequence of expressions
+ * to apply to the joined row
+ * @param notMatchedConditions condition for each not-matched clause
+ * @param notMatchedOutputs corresponding output for each not-matched clause. for each clause,
+ * we have 1-2 output rows, each of which is a sequence of
+ * expressions to apply to the joined row
+ * @param noopCopyOutput no-op expression to copy a target row to the output
+ * @param deleteRowOutput expression to drop a row from the final output. this is used for
+ * source rows that don't match any not-matched clauses
+ * @param joinedAttributes schema of our outer-joined dataframe
+ * @param joinedRowEncoder joinedDF row encoder
+ * @param outputRowEncoder final output row encoder
+ */
+ class JoinedRowProcessor(
+ targetRowHasNoMatch: Expression,
+ sourceRowHasNoMatch: Expression,
+ matchedConditions: Seq[Expression],
+ matchedOutputs: Seq[Seq[Seq[Expression]]],
+ notMatchedConditions: Seq[Expression],
+ notMatchedOutputs: Seq[Seq[Seq[Expression]]],
+ noopCopyOutput: Seq[Expression],
+ deleteRowOutput: Seq[Expression],
+ joinedAttributes: Seq[Attribute],
+ joinedRowEncoder: ExpressionEncoder[Row],
+ outputRowEncoder: ExpressionEncoder[Row]) extends Serializable {
+
+ private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = {
+ UnsafeProjection.create(exprs, joinedAttributes)
+ }
+
+ private def generatePredicate(expr: Expression): BasePredicate = {
+ GeneratePredicate.generate(expr, joinedAttributes)
+ }
+
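+    /**
+     * Process the joined rows of one partition on the CPU fallback path: evaluate the
+     * matched / not-matched predicates per row, apply the corresponding projections, and
+     * filter out rows whose ROW_DROPPED_COL is set before emitting rows in the output schema.
+     */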
+ def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = {
+
+ val targetRowHasNoMatchPred = generatePredicate(targetRowHasNoMatch)
+ val sourceRowHasNoMatchPred = generatePredicate(sourceRowHasNoMatch)
+ val matchedPreds = matchedConditions.map(generatePredicate)
+ val matchedProjs = matchedOutputs.map(_.map(generateProjection))
+ val notMatchedPreds = notMatchedConditions.map(generatePredicate)
+ val notMatchedProjs = notMatchedOutputs.map(_.map(generateProjection))
+ val noopCopyProj = generateProjection(noopCopyOutput)
+ val deleteRowProj = generateProjection(deleteRowOutput)
+ val outputProj = UnsafeProjection.create(outputRowEncoder.schema)
+
+      // This accesses ROW_DROPPED_COL. If ROW_DROPPED_COL is not in outputRowEncoder.schema,
+      // then CDC must be disabled and it is the column right after our output columns.
+ def shouldDeleteRow(row: InternalRow): Boolean = {
+ row.getBoolean(
+ outputRowEncoder.schema.getFieldIndex(ROW_DROPPED_COL)
+ .getOrElse(outputRowEncoder.schema.fields.size)
+ )
+ }
+
+ def processRow(inputRow: InternalRow): Iterator[InternalRow] = {
+ if (targetRowHasNoMatchPred.eval(inputRow)) {
+ // Target row did not match any source row, so just copy it to the output
+ Iterator(noopCopyProj.apply(inputRow))
+ } else {
+ // identify which set of clauses to execute: matched or not-matched ones
+ val (predicates, projections, noopAction) = if (sourceRowHasNoMatchPred.eval(inputRow)) {
+ // Source row did not match with any target row, so insert the new source row
+ (notMatchedPreds, notMatchedProjs, deleteRowProj)
+ } else {
+ // Source row matched with target row, so update the target row
+ (matchedPreds, matchedProjs, noopCopyProj)
+ }
+
+ // find (predicate, projection) pair whose predicate satisfies inputRow
+ val pair = (predicates zip projections).find {
+ case (predicate, _) => predicate.eval(inputRow)
+ }
+
+ pair match {
+ case Some((_, projections)) =>
+ projections.map(_.apply(inputRow)).iterator
+ case None => Iterator(noopAction.apply(inputRow))
+ }
+ }
+ }
+
+ val toRow = joinedRowEncoder.createSerializer()
+ val fromRow = outputRowEncoder.createDeserializer()
+ rowIterator
+ .map(toRow)
+ .flatMap(processRow)
+ .filter(!shouldDeleteRow(_))
+ .map { notDeletedInternalRow =>
+ fromRow(outputProj(notDeletedInternalRow))
+ }
+ }
+ }
+
+  /** Compute the total bytes and the number of distinct partition values among the given AddFiles. */
+ def totalBytesAndDistinctPartitionValues(files: Seq[FileAction]): (Long, Int) = {
+ val distinctValues = new mutable.HashSet[Map[String, String]]()
+ var bytes = 0L
+ val iter = files.collect { case a: AddFile => a }.iterator
+ while (iter.hasNext) {
+ val file = iter.next()
+ distinctValues += file.partitionValues
+ bytes += file.size
+ }
+ // If the only distinct value map is an empty map, then it must be an unpartitioned table.
+ // Return 0 in that case.
+ val numDistinctValues =
+ if (distinctValues.size == 1 && distinctValues.head.isEmpty) 0 else distinctValues.size
+ (bytes, numDistinctValues)
+ }
+}
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimisticTransaction.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimisticTransaction.scala
new file mode 100644
index 00000000000..3e836056b6d
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimisticTransaction.scala
@@ -0,0 +1,312 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from OptimisticTransaction.scala and TransactionalWrite.scala
+ * in the Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.sql.transaction.tahoe.rapids
+
+import java.net.URI
+
+import scala.collection.mutable.ListBuffer
+
+import com.databricks.sql.transaction.tahoe._
+import com.databricks.sql.transaction.tahoe.actions.{AddFile, FileAction}
+import com.databricks.sql.transaction.tahoe.constraints.{Constraint, Constraints}
+import com.databricks.sql.transaction.tahoe.schema.InvariantViolationException
+import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf
+import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.delta._
+import org.apache.commons.lang3.exception.ExceptionUtils
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormatWriter}
+import org.apache.spark.sql.functions.to_json
+import org.apache.spark.sql.rapids.{BasicColumnarWriteJobStatsTracker, ColumnarWriteJobStatsTracker, GpuFileFormatWriter, GpuWriteJobStatsTracker}
+import org.apache.spark.sql.rapids.delta.GpuIdentityColumn
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.{Clock, SerializableConfiguration}
+
+/**
+ * Used to perform a set of reads in a transaction and then commit a set of updates to the
+ * state of the log. All reads from the DeltaLog, MUST go through this instance rather
+ * than directly to the DeltaLog otherwise they will not be check for logical conflicts
+ * with concurrent updates.
+ *
+ * This class is not thread-safe.
+ *
+ * @param deltaLog The Delta Log for the table this transaction is modifying.
+ * @param snapshot The snapshot that this transaction is reading at.
+ * @param rapidsConf RAPIDS Accelerator config settings.
+ */
+class GpuOptimisticTransaction(
+ deltaLog: DeltaLog,
+ snapshot: Snapshot,
+ rapidsConf: RapidsConf)(implicit clock: Clock)
+ extends GpuOptimisticTransactionBase(deltaLog, snapshot, rapidsConf)(clock) {
+
+ /** Creates a new OptimisticTransaction.
+ *
+ * @param deltaLog The Delta Log for the table this transaction is modifying.
+ * @param rapidsConf RAPIDS Accelerator config settings
+ */
+ def this(deltaLog: DeltaLog, rapidsConf: RapidsConf)(implicit clock: Clock) = {
+ this(deltaLog, deltaLog.update(), rapidsConf)
+ }
+
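+  /**
+   * Build the expression that serializes the per-file statistics produced by the GPU stats
+   * collector to a JSON string, by analyzing a `to_json(statsCollector)` projection over the
+   * stats data schema.
+   */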
+ private def getGpuStatsColExpr(
+ statsDataSchema: Seq[Attribute],
+ statsCollection: GpuStatisticsCollection): Expression = {
+ Dataset.ofRows(spark, LocalRelation(statsDataSchema))
+ .select(to_json(statsCollection.statsCollector))
+ .queryExecution.analyzed.expressions.head
+ }
+
+ /** Return the pair of optional stats tracker and stats collection class */
+ private def getOptionalGpuStatsTrackerAndStatsCollection(
+ output: Seq[Attribute],
+ partitionSchema: StructType, data: DataFrame): (
+ Option[GpuDeltaJobStatisticsTracker],
+ Option[GpuStatisticsCollection]) = {
+ if (spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_COLLECT_STATS)) {
+
+ val (statsDataSchema, statsCollectionSchema) = getStatsSchema(output, partitionSchema)
+
+ val indexedCols = DeltaConfigs.DATA_SKIPPING_NUM_INDEXED_COLS.fromMetaData(metadata)
+ val prefixLength =
+ spark.sessionState.conf.getConf(DeltaSQLConf.DATA_SKIPPING_STRING_PREFIX_LENGTH)
+ val tableSchema = {
+ // If collecting stats using the table schema, then pass in statsCollectionSchema.
+ // Otherwise pass in statsDataSchema to collect stats using the DataFrame schema.
+ if (spark.sessionState.conf.getConf(DeltaSQLConf
+ .DELTA_COLLECT_STATS_USING_TABLE_SCHEMA)) {
+ statsCollectionSchema.toStructType
+ } else {
+ statsDataSchema.toStructType
+ }
+ }
+
+ val _spark = spark
+ val protocol = deltaLog.unsafeVolatileSnapshot.protocol
+
+ val statsCollection = new GpuStatisticsCollection {
+ override val spark = _spark
+ override val deletionVectorsSupported =
+ protocol.isFeatureSupported(DeletionVectorsTableFeature)
+ override val tableDataSchema = tableSchema
+ override val dataSchema = statsDataSchema.toStructType
+ override val numIndexedCols = indexedCols
+ override val stringPrefixLength: Int = prefixLength
+ }
+
+ val statsColExpr = getGpuStatsColExpr(statsDataSchema, statsCollection)
+
+ val statsSchema = statsCollection.statCollectionSchema
+ val explodedDataSchema = statsCollection.explodedDataSchema
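+      // Closure used by the GPU stats tracker to record the statistics of a written batch
+      // in the provided InternalRow.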
+ val batchStatsToRow = (batch: ColumnarBatch, row: InternalRow) => {
+ GpuStatisticsCollection.batchStatsToRow(statsSchema, explodedDataSchema, batch, row)
+ }
+ (Some(new GpuDeltaJobStatisticsTracker(statsDataSchema, statsColExpr, batchStatsToRow)),
+ Some(statsCollection))
+ } else {
+ (None, None)
+ }
+ }
+
+ override def writeFiles(
+ inputData: Dataset[_],
+ writeOptions: Option[DeltaOptions],
+ additionalConstraints: Seq[Constraint]): Seq[FileAction] = {
+ hasWritten = true
+
+ val spark = inputData.sparkSession
+ val (data, partitionSchema) = performCDCPartition(inputData)
+ val outputPath = deltaLog.dataPath
+
+ val (normalizedQueryExecution, output, generatedColumnConstraints, dataHighWaterMarks) = {
+ // TODO: is none ok to pass here?
+ normalizeData(deltaLog, None, data)
+ }
+ val highWaterMarks = trackHighWaterMarks.getOrElse(dataHighWaterMarks)
+
+ // Build a new plan with a stub GpuDeltaWrite node to work around undesired transitions between
+ // columns and rows when AQE is involved. Without this node in the plan, AdaptiveSparkPlanExec
+ // could be the root node of the plan. In that case we do not have enough context to know
+ // whether the AdaptiveSparkPlanExec should be columnar or not, since the GPU overrides do not
+ // see how the parent is using the AdaptiveSparkPlanExec outputs. By using this stub node that
+ // appears to be a data writing node to AQE (it derives from V2CommandExec), the
+ // AdaptiveSparkPlanExec will be planned as a child of this new node. That provides enough
+ // context to plan the AQE sub-plan properly with respect to columnar and row transitions.
+ // We could force the AQE node to be columnar here by explicitly replacing the node, but that
+ // breaks the connection between the queryExecution and the node that will actually execute.
+ val gpuWritePlan = Dataset.ofRows(spark, RapidsDeltaWrite(normalizedQueryExecution.logical))
+ val queryExecution = gpuWritePlan.queryExecution
+
+ val partitioningColumns = getPartitioningColumns(partitionSchema, output)
+
+ val committer = getCommitter(outputPath)
+
+    // If statistics collection is enabled, then create a stats tracker that will be injected
+    // during the GpuFileFormatWriter.write call below and will collect per-file stats using
+    // GpuStatisticsCollection
+ val (optionalStatsTracker, _) = getOptionalGpuStatsTrackerAndStatsCollection(output,
+ partitionSchema, data)
+
+ // schema should be normalized, therefore we can do an equality check
+ val (statsDataSchema, _) = getStatsSchema(output, partitionSchema)
+ val identityTracker = GpuIdentityColumn.createIdentityColumnStatsTracker(
+ spark,
+ statsDataSchema,
+ metadata.schema,
+ highWaterMarks)
+
+ val constraints =
+ Constraints.getAll(metadata, spark) ++ generatedColumnConstraints ++ additionalConstraints
+
+ val isOptimize = isOptimizeCommand(queryExecution.analyzed)
+
+ SQLExecution.withNewExecutionId(queryExecution, Option("deltaTransactionalWrite")) {
+ val outputSpec = FileFormatWriter.OutputSpec(
+ outputPath.toString,
+ Map.empty,
+ output)
+
+ // Remove any unnecessary row conversions added as part of Spark planning
+ val queryPhysicalPlan = queryExecution.executedPlan match {
+ case GpuColumnarToRowExec(child, _) => child
+ case p => p
+ }
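+      // Capture the GPU Delta write exec node, if present, so its metrics can be wired into
+      // the write job stats trackers created below.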
+ val gpuRapidsWrite = queryPhysicalPlan match {
+ case g: GpuRapidsDeltaWriteExec => Some(g)
+ case _ => None
+ }
+
+ val empty2NullPlan = convertEmptyToNullIfNeeded(queryPhysicalPlan,
+ partitioningColumns, constraints)
+ val optimizedPlan =
+ applyOptimizeWriteIfNeeded(spark, empty2NullPlan, partitionSchema, isOptimize)
+ val planWithInvariants = addInvariantChecks(optimizedPlan, constraints)
+ val physicalPlan = convertToGpu(planWithInvariants)
+
+ val statsTrackers: ListBuffer[ColumnarWriteJobStatsTracker] = ListBuffer()
+
+ val hadoopConf = spark.sessionState.newHadoopConfWithOptions(
+ metadata.configuration ++ deltaLog.options)
+
+ if (spark.conf.get(DeltaSQLConf.DELTA_HISTORY_METRICS_ENABLED)) {
+ val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+ val basicWriteJobStatsTracker = new BasicColumnarWriteJobStatsTracker(
+ serializableHadoopConf,
+ BasicWriteJobStatsTracker.metrics)
+ registerSQLMetrics(spark, basicWriteJobStatsTracker.driverSideMetrics)
+ statsTrackers.append(basicWriteJobStatsTracker)
+ gpuRapidsWrite.foreach { grw =>
+ val tracker = new GpuWriteJobStatsTracker(serializableHadoopConf,
+ grw.basicMetrics, grw.taskMetrics)
+ statsTrackers.append(tracker)
+ }
+ }
+
+ // Retain only a minimal selection of Spark writer options to avoid any potential
+ // compatibility issues
+ val options = writeOptions match {
+ case None => Map.empty[String, String]
+ case Some(writeOptions) =>
+ writeOptions.options.filterKeys { key =>
+ key.equalsIgnoreCase(DeltaOptions.MAX_RECORDS_PER_FILE) ||
+ key.equalsIgnoreCase(DeltaOptions.COMPRESSION)
+ }.toMap
+ }
+ val deltaFileFormat = deltaLog.fileFormat(deltaLog.unsafeVolatileSnapshot.protocol, metadata)
+ val gpuFileFormat = if (deltaFileFormat.getClass == classOf[DeltaParquetFileFormat]) {
+ new GpuParquetFileFormat
+ } else {
+ throw new IllegalStateException(s"file format $deltaFileFormat is not supported")
+ }
+
+ try {
+ logDebug(s"Physical plan for write:\n$physicalPlan")
+ GpuFileFormatWriter.write(
+ sparkSession = spark,
+ plan = physicalPlan,
+ fileFormat = gpuFileFormat,
+ committer = committer,
+ outputSpec = outputSpec,
+ hadoopConf = hadoopConf,
+ partitionColumns = partitioningColumns,
+ bucketSpec = None,
+ statsTrackers = optionalStatsTracker.toSeq ++ identityTracker.toSeq ++ statsTrackers,
+ options = options,
+ rapidsConf.stableSort,
+ rapidsConf.concurrentWriterPartitionFlushSize)
+ } catch {
+ case s: SparkException =>
+ // Pull an InvariantViolationException up to the top level if it was the root cause.
+ val violationException = ExceptionUtils.getRootCause(s)
+ if (violationException.isInstanceOf[InvariantViolationException]) {
+ throw violationException
+ } else {
+ throw s
+ }
+ }
+ }
+
+ val resultFiles = committer.addedStatuses.map { a =>
+ a.copy(stats = optionalStatsTracker.map(
+ _.recordedStats(new Path(new URI(a.path)).getName)).getOrElse(a.stats))
+ }.filter {
+ // In some cases, we can write out an empty `inputData`. Some examples of this (though, they
+ // may be fixed in the future) are the MERGE command when you delete with empty source, or
+ // empty target, or on disjoint tables. This is hard to catch before the write without
+ // collecting the DF ahead of time. Instead, we can return only the AddFiles that
+ // a) actually add rows, or
+ // b) don't have any stats so we don't know the number of rows at all
+ case a: AddFile => a.numLogicalRecords.forall(_ > 0)
+ case _ => true
+ }
+
+ identityTracker.foreach { tracker =>
+ updatedIdentityHighWaterMarks.appendAll(tracker.highWaterMarks.toSeq)
+ }
+ val fileActions = resultFiles.toSeq ++ committer.changeFiles
+
+ // Check if auto-compaction is enabled.
+ // (Auto compaction checks are derived from the work in
+ // https://github.com/delta-io/delta/pull/1156).
+ lazy val autoCompactEnabled =
+ spark.sessionState.conf
+ .getConf[String](DeltaSQLConf.DELTA_AUTO_COMPACT_ENABLED)
+ .getOrElse {
+ DeltaConfigs.AUTO_COMPACT.fromMetaData(metadata)
+ .getOrElse("false")
+ }.toBoolean
+
+ if (!isOptimize && autoCompactEnabled && fileActions.nonEmpty) {
+ registerPostCommitHook(GpuDoAutoCompaction)
+ }
+
+ fileActions
+ }
+}
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimizeExecutor.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimizeExecutor.scala
new file mode 100644
index 00000000000..df619c3b1a7
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimizeExecutor.scala
@@ -0,0 +1,405 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from:
+ * 1. DoAutoCompaction.scala from PR#1156 at https://github.com/delta-io/delta/pull/1156,
+ * 2. OptimizeTableCommand.scala from the Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.databricks.sql.transaction.tahoe.rapids
+
+import java.util.ConcurrentModificationException
+
+import scala.annotation.tailrec
+import scala.collection.mutable.ArrayBuffer
+
+import com.databricks.sql.io.skipping.MultiDimClustering
+import com.databricks.sql.transaction.tahoe._
+import com.databricks.sql.transaction.tahoe.DeltaOperations.Operation
+import com.databricks.sql.transaction.tahoe.actions.{Action, AddFile, FileAction, RemoveFile}
+import com.databricks.sql.transaction.tahoe.commands.DeltaCommand
+import com.databricks.sql.transaction.tahoe.commands.optimize._
+import com.databricks.sql.transaction.tahoe.files.SQLMetricsReporting
+import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf
+import com.nvidia.spark.rapids.delta.RapidsDeltaSQLConf
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext.SPARK_JOB_GROUP_ID
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetrics.createMetric
+import org.apache.spark.util.ThreadUtils
+
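+/**
+ * GPU version of the Delta OPTIMIZE executor: compacts small files (and optionally Z-orders
+ * data) for the table being modified by the given transaction.
+ *
+ * @param sparkSession Spark session in use.
+ * @param txn transaction used to read candidate files and commit the rewritten files.
+ * @param partitionPredicate predicates limiting which partitions are optimized.
+ * @param zOrderByColumns columns to Z-order by; non-empty enables multi-dimensional clustering.
+ * @param prevCommitActions actions from the previous commit; non-empty indicates auto-compaction.
+ */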
+class GpuOptimizeExecutor(
+ sparkSession: SparkSession,
+ txn: OptimisticTransaction,
+ partitionPredicate: Seq[Expression],
+ zOrderByColumns: Seq[String],
+ prevCommitActions: Seq[Action])
+ extends DeltaCommand with SQLMetricsReporting with Serializable {
+
+ /** Timestamp to use in [[FileAction]] */
+ private val operationTimestamp = System.currentTimeMillis
+
+ private val isMultiDimClustering = zOrderByColumns.nonEmpty
+ private val isAutoCompact = prevCommitActions.nonEmpty
+ private val optimizeType = GpuOptimizeType(isMultiDimClustering, isAutoCompact)
+
+ def optimize(): Seq[Row] = {
+ recordDeltaOperation(txn.deltaLog, "delta.optimize") {
+ val maxFileSize = optimizeType.maxFileSize
+ require(maxFileSize > 0, "maxFileSize must be > 0")
+
+ val minNumFilesInDir = optimizeType.minNumFiles
+ val (candidateFiles, filesToProcess) = optimizeType.targetFiles
+ val partitionSchema = txn.metadata.partitionSchema
+
+ // select all files in case of multi-dimensional clustering
+ val partitionsToCompact = filesToProcess
+ .groupBy(_.partitionValues)
+ .filter { case (_, filesInPartition) => filesInPartition.size >= minNumFilesInDir }
+ .toSeq
+
+ val groupedJobs = groupFilesIntoBins(partitionsToCompact, maxFileSize)
+ val jobs = optimizeType.targetBins(groupedJobs)
+
+ val maxThreads =
+ sparkSession.sessionState.conf.getConf(DeltaSQLConf.DELTA_OPTIMIZE_MAX_THREADS)
+ val updates = ThreadUtils.parmap(jobs, "OptimizeJob", maxThreads) { partitionBinGroup =>
+ runOptimizeBinJob(txn, partitionBinGroup._1, partitionBinGroup._2, maxFileSize)
+ }.flatten
+
+ val addedFiles = updates.collect { case a: AddFile => a }
+ val removedFiles = updates.collect { case r: RemoveFile => r }
+ if (addedFiles.nonEmpty) {
+ val operation = DeltaOperations.Optimize(partitionPredicate, zOrderByColumns)
+ val metrics = createMetrics(sparkSession.sparkContext, addedFiles, removedFiles)
+ commitAndRetry(txn, operation, updates, metrics) { newTxn =>
+ val newPartitionSchema = newTxn.metadata.partitionSchema
+ val candidateSetOld = candidateFiles.map(_.path).toSet
+ val candidateSetNew = newTxn.filterFiles(partitionPredicate).map(_.path).toSet
+
+ // As long as all of the files that we compacted are still part of the table,
+ // and the partitioning has not changed it is valid to continue to try
+ // and commit this checkpoint.
+ if (candidateSetOld.subsetOf(candidateSetNew) && partitionSchema == newPartitionSchema) {
+ true
+ } else {
+ val deleted = candidateSetOld -- candidateSetNew
+ logWarning(s"The following compacted files were delete " +
+ s"during checkpoint ${deleted.mkString(",")}. Aborting the compaction.")
+ false
+ }
+ }
+ }
+
+ val optimizeStats = OptimizeStats()
+ optimizeStats.addedFilesSizeStats.merge(addedFiles)
+ optimizeStats.removedFilesSizeStats.merge(removedFiles)
+ optimizeStats.numPartitionsOptimized = jobs.map(j => j._1).distinct.size
+ optimizeStats.numBatches = jobs.size
+ optimizeStats.totalConsideredFiles = candidateFiles.size
+ optimizeStats.totalFilesSkipped = optimizeStats.totalConsideredFiles - removedFiles.size
+ optimizeStats.totalClusterParallelism = sparkSession.sparkContext.defaultParallelism
+
+ if (isMultiDimClustering) {
+ val inputFileStats =
+ ZOrderFileStats(removedFiles.size, removedFiles.map(_.size.getOrElse(0L)).sum)
+ optimizeStats.zOrderStats = Some(ZOrderStats(
+ strategyName = "all", // means process all files in a partition
+ inputCubeFiles = ZOrderFileStats(0, 0),
+ inputOtherFiles = inputFileStats,
+ inputNumCubes = 0,
+ mergedFiles = inputFileStats,
+          // There will be one z-cube for each partition
+ numOutputCubes = optimizeStats.numPartitionsOptimized))
+ }
+
+ return Seq(Row(txn.deltaLog.dataPath.toString, optimizeStats.toOptimizeMetrics))
+ }
+ }
+
+ /**
+ * Utility methods to group files into bins for optimize.
+ *
+   * @param partitionsToCompact List of files to compact, grouped by partition.
+ * Partition is defined by the partition values (partCol -> partValue)
+ * @param maxTargetFileSize Max size (in bytes) of the compaction output file.
+ * @return Sequence of bins. Each bin contains one or more files from the same
+ * partition and targeted for one output file.
+ */
+ private def groupFilesIntoBins(
+ partitionsToCompact: Seq[(Map[String, String], Seq[AddFile])],
+ maxTargetFileSize: Long): Seq[(Map[String, String], Seq[AddFile])] = {
+
+ partitionsToCompact.flatMap {
+ case (partition, files) =>
+ val bins = new ArrayBuffer[Seq[AddFile]]()
+
+ val currentBin = new ArrayBuffer[AddFile]()
+ var currentBinSize = 0L
+
+ files.sortBy(_.size).foreach { file =>
+ // Generally, a bin is a group of existing files, whose total size does not exceed the
+ // desired maxFileSize. They will be coalesced into a single output file.
+ // However, if isMultiDimClustering = true, all files in a partition will be read by the
+          // same job, the data will be range-partitioned, and numFiles = totalFileSize / maxFileSize
+          // output files will be produced. See below.
+ if (file.size + currentBinSize > maxTargetFileSize && !isMultiDimClustering) {
+ bins += currentBin.toVector
+ currentBin.clear()
+ currentBin += file
+ currentBinSize = file.size
+ } else {
+ currentBin += file
+ currentBinSize += file.size
+ }
+ }
+
+ if (currentBin.nonEmpty) {
+ bins += currentBin.toVector
+ }
+
+ bins.map(b => (partition, b))
+ // select bins that have at least two files or in case of multi-dim clustering
+ // select all bins
+ .filter(_._2.size > 1 || isMultiDimClustering)
+ }
+ }
+
+ /**
+ * Utility method to run a Spark job to compact the files in given bin
+ *
+ * @param txn [[OptimisticTransaction]] instance in use to commit the changes to DeltaLog.
+   * @param partition Partition values of the partition that the files in [[bin]] belong to.
+ * @param bin List of files to compact into one large file.
+ * @param maxFileSize Targeted output file size in bytes
+ */
+ private def runOptimizeBinJob(
+ txn: OptimisticTransaction,
+ partition: Map[String, String],
+ bin: Seq[AddFile],
+ maxFileSize: Long): Seq[FileAction] = {
+ val baseTablePath = txn.deltaLog.dataPath
+
+ val input = txn.deltaLog.createDataFrame(txn.snapshot, bin, actionTypeOpt = Some("Optimize"))
+ val repartitionDF = if (isMultiDimClustering) {
+ val totalSize = bin.map(_.size).sum
+ val approxNumFiles = Math.max(1, totalSize / maxFileSize).toInt
+ MultiDimClustering.cluster(
+ input,
+ approxNumFiles,
+ zOrderByColumns)
+ } else {
+ val useRepartition = sparkSession.sessionState.conf.getConf(
+ DeltaSQLConf.DELTA_OPTIMIZE_REPARTITION_ENABLED)
+ if (useRepartition) {
+ input.repartition(numPartitions = 1)
+ } else {
+ input.coalesce(numPartitions = 1)
+ }
+ }
+
+ val partitionDesc = partition.toSeq.map(entry => entry._1 + "=" + entry._2).mkString(",")
+
+ val partitionName = if (partition.isEmpty) "" else s" in partition ($partitionDesc)"
+ val description = s"$baseTablePath Optimizing ${bin.size} files" + partitionName
+ sparkSession.sparkContext.setJobGroup(
+ sparkSession.sparkContext.getLocalProperty(SPARK_JOB_GROUP_ID),
+ description)
+
+ val addFiles = txn.writeFiles(repartitionDF).collect {
+ case a: AddFile =>
+ a.copy(dataChange = false)
+ case other =>
+ throw new IllegalStateException(
+ s"Unexpected action $other with type ${other.getClass}. File compaction job output" +
+ s"should only have AddFiles")
+ }
+ val removeFiles = bin.map(f => f.removeWithTimestamp(operationTimestamp, dataChange = false))
+ val updates = addFiles ++ removeFiles
+ updates
+ }
+
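+  /** A bin of files targeted for compaction into a single output file, keyed by partition values. */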
+ private type PartitionedBin = (Map[String, String], Seq[AddFile])
+
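+  /**
+   * Strategy for a particular optimize flavor: how candidate files are selected, the minimum
+   * number of files per partition to act on, the target output file size, and which bins are
+   * ultimately processed.
+   */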
+ private trait GpuOptimizeType {
+ def minNumFiles: Long
+
+ def maxFileSize: Long =
+ sparkSession.sessionState.conf.getConf(DeltaSQLConf.DELTA_OPTIMIZE_MAX_FILE_SIZE)
+
+ def targetFiles: (Seq[AddFile], Seq[AddFile])
+
+ def targetBins(jobs: Seq[PartitionedBin]): Seq[PartitionedBin] = jobs
+ }
+
+ private case class GpuCompaction() extends GpuOptimizeType {
+ def minNumFiles: Long = 2
+
+ def targetFiles: (Seq[AddFile], Seq[AddFile]) = {
+ val minFileSize = sparkSession.sessionState.conf.getConf(
+ DeltaSQLConf.DELTA_OPTIMIZE_MIN_FILE_SIZE)
+ require(minFileSize > 0, "minFileSize must be > 0")
+ val candidateFiles = txn.filterFiles(partitionPredicate)
+ val filesToProcess = candidateFiles.filter(_.size < minFileSize)
+ (candidateFiles, filesToProcess)
+ }
+ }
+
+ private case class GpuMultiDimOrdering() extends GpuOptimizeType {
+ def minNumFiles: Long = 1
+
+ def targetFiles: (Seq[AddFile], Seq[AddFile]) = {
+ // select all files in case of multi-dimensional clustering
+ val candidateFiles = txn.filterFiles(partitionPredicate)
+ (candidateFiles, candidateFiles)
+ }
+ }
+
+ private case class GpuAutoCompaction() extends GpuOptimizeType {
+ def minNumFiles: Long = {
+ val minNumFiles =
+ sparkSession.sessionState.conf.getConf(DeltaSQLConf.DELTA_AUTO_COMPACT_MIN_NUM_FILES)
+ require(minNumFiles > 0, "minNumFiles must be > 0")
+ minNumFiles
+ }
+
+ override def maxFileSize: Long =
+ sparkSession.sessionState.conf.getConf(DeltaSQLConf.DELTA_AUTO_COMPACT_MAX_FILE_SIZE)
+ .getOrElse(128 * 1024 * 1024)
+
+ override def targetFiles: (Seq[AddFile], Seq[AddFile]) = {
+ val autoCompactTarget =
+ sparkSession.sessionState.conf.getConf(RapidsDeltaSQLConf.AUTO_COMPACT_TARGET)
+ // Filter the candidate files according to autoCompact.target config.
+ lazy val addedFiles = prevCommitActions.collect { case a: AddFile => a }
+ val candidateFiles = autoCompactTarget match {
+ case "table" =>
+ txn.filterFiles()
+ case "commit" =>
+ addedFiles
+ case "partition" =>
+ val eligiblePartitions = addedFiles.map(_.partitionValues).toSet
+ txn.filterFiles().filter(f => eligiblePartitions.contains(f.partitionValues))
+ case _ =>
+ logError(s"Invalid config for autoCompact.target: $autoCompactTarget. " +
+ s"Falling back to the default value 'table'.")
+ txn.filterFiles()
+ }
+ val filesToProcess = candidateFiles.filter(_.size < maxFileSize)
+ (candidateFiles, filesToProcess)
+ }
+
+ override def targetBins(jobs: Seq[PartitionedBin]): Seq[PartitionedBin] = {
+ var acc = 0L
+ val maxCompactBytes =
+ sparkSession.sessionState.conf.getConf(RapidsDeltaSQLConf.AUTO_COMPACT_MAX_COMPACT_BYTES)
+      // Bins with more files take priority over bins with fewer files.
+ jobs
+ .sortBy { case (_, filesInBin) => -filesInBin.length }
+ .takeWhile { case (_, filesInBin) =>
+ acc += filesInBin.map(_.size).sum
+ acc <= maxCompactBytes
+ }
+ }
+ }
+
+ private object GpuOptimizeType {
+
+ def apply(isMultiDimClustering: Boolean, isAutoCompact: Boolean): GpuOptimizeType = {
+ if (isMultiDimClustering) {
+ GpuMultiDimOrdering()
+ } else if (isAutoCompact) {
+ GpuAutoCompaction()
+ } else {
+ GpuCompaction()
+ }
+ }
+ }
+
+ /**
+ * Attempts to commit the given actions to the log. In the case of a concurrent update,
+ * the given function will be invoked with a new transaction to allow custom conflict
+ * detection logic to indicate it is safe to try again, by returning `true`.
+ *
+ * This function will continue to try to commit to the log as long as `f` returns `true`,
+ * otherwise throws a subclass of [[ConcurrentModificationException]].
+ */
+ @tailrec
+ private def commitAndRetry(
+ txn: OptimisticTransaction,
+ optimizeOperation: Operation,
+ actions: Seq[Action],
+ metrics: Map[String, SQLMetric])(f: OptimisticTransaction => Boolean)
+ : Unit = {
+ try {
+ txn.registerSQLMetrics(sparkSession, metrics)
+ txn.commit(actions, optimizeOperation)
+ } catch {
+ case e: ConcurrentModificationException =>
+ val newTxn = txn.deltaLog.startTransaction()
+ if (f(newTxn)) {
+ logInfo("Retrying commit after checking for semantic conflicts with concurrent updates.")
+ commitAndRetry(newTxn, optimizeOperation, actions, metrics)(f)
+ } else {
+ logWarning("Semantic conflicts detected. Aborting operation.")
+ throw e
+ }
+ }
+ }
+
+ /** Create a map of SQL metrics for adding to the commit history. */
+ private def createMetrics(
+ sparkContext: SparkContext,
+ addedFiles: Seq[AddFile],
+ removedFiles: Seq[RemoveFile]): Map[String, SQLMetric] = {
+
+ def setAndReturnMetric(description: String, value: Long) = {
+ val metric = createMetric(sparkContext, description)
+ metric.set(value)
+ metric
+ }
+
+ def totalSize(actions: Seq[FileAction]): Long = {
+ var totalSize = 0L
+ actions.foreach { file =>
+ val fileSize = file match {
+ case addFile: AddFile => addFile.size
+ case removeFile: RemoveFile => removeFile.size.getOrElse(0L)
+ case default =>
+ throw new IllegalArgumentException(s"Unknown FileAction type: ${default.getClass}")
+ }
+ totalSize += fileSize
+ }
+ totalSize
+ }
+
+ val sizeStats = FileSizeStatsWithHistogram.create(addedFiles.map(_.size).sorted)
+ Map[String, SQLMetric](
+ "minFileSize" -> setAndReturnMetric("minimum file size", sizeStats.get.min),
+ "p25FileSize" -> setAndReturnMetric("25th percentile file size", sizeStats.get.p25),
+ "p50FileSize" -> setAndReturnMetric("50th percentile file size", sizeStats.get.p50),
+ "p75FileSize" -> setAndReturnMetric("75th percentile file size", sizeStats.get.p75),
+ "maxFileSize" -> setAndReturnMetric("maximum file size", sizeStats.get.max),
+ "numAddedFiles" -> setAndReturnMetric("total number of files added.", addedFiles.size),
+ "numRemovedFiles" -> setAndReturnMetric("total number of files removed.", removedFiles.size),
+ "numAddedBytes" -> setAndReturnMetric("total number of bytes added", totalSize(addedFiles)),
+ "numRemovedBytes" ->
+ setAndReturnMetric("total number of bytes removed", totalSize(removedFiles)))
+ }
+}
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuUpdateCommand.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuUpdateCommand.scala
new file mode 100644
index 00000000000..531e34e9f3c
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuUpdateCommand.scala
@@ -0,0 +1,276 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * This file was derived from UpdateCommand.scala
+ * in the Delta Lake project at https://github.com/delta-io/delta.
+ *
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.sql.transaction.tahoe.rapids
+
+import com.databricks.sql.transaction.tahoe.{DeltaLog, DeltaOperations, DeltaTableUtils, DeltaUDF, OptimisticTransaction}
+import com.databricks.sql.transaction.tahoe.DeltaCommitTag._
+import com.databricks.sql.transaction.tahoe.RowTracking
+import com.databricks.sql.transaction.tahoe.actions.{AddCDCFile, AddFile, FileAction}
+import com.databricks.sql.transaction.tahoe.commands.{DeltaCommand, DMLUtils, UpdateCommand, UpdateMetric}
+import com.databricks.sql.transaction.tahoe.files.{TahoeBatchFileIndex, TahoeFileIndex}
+import com.nvidia.spark.rapids.delta.GpuDeltaMetricUpdateUDF
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{Column, Dataset, Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics.{createMetric, createTimingMetric}
+import org.apache.spark.sql.functions.input_file_name
+import org.apache.spark.sql.types.LongType
+
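+/**
+ * GPU version of Delta Lake's UpdateCommand: performs an UPDATE on a Delta table by rewriting
+ * the data files that contain rows matching the condition.
+ *
+ * @param gpuDeltaLog the GpuDeltaLog for the table being updated.
+ * @param tahoeFileIndex file index of the target table.
+ * @param target logical plan of the target table.
+ * @param updateExpressions expressions producing the updated value for each output column.
+ * @param condition optional condition selecting the rows to update.
+ */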
+case class GpuUpdateCommand(
+ gpuDeltaLog: GpuDeltaLog,
+ tahoeFileIndex: TahoeFileIndex,
+ target: LogicalPlan,
+ updateExpressions: Seq[Expression],
+ condition: Option[Expression])
+ extends LeafRunnableCommand with DeltaCommand {
+
+ override val output: Seq[Attribute] = {
+ Seq(AttributeReference("num_affected_rows", LongType)())
+ }
+
+ override def innerChildren: Seq[QueryPlan[_]] = Seq(target)
+
+ @transient private lazy val sc: SparkContext = SparkContext.getOrCreate()
+
+ override lazy val metrics = Map[String, SQLMetric](
+ "numAddedFiles" -> createMetric(sc, "number of files added."),
+ "numRemovedFiles" -> createMetric(sc, "number of files removed."),
+ "numUpdatedRows" -> createMetric(sc, "number of rows updated."),
+ "numCopiedRows" -> createMetric(sc, "number of rows copied."),
+ "executionTimeMs" ->
+ createTimingMetric(sc, "time taken to execute the entire operation"),
+ "scanTimeMs" ->
+ createTimingMetric(sc, "time taken to scan the files for matches"),
+ "rewriteTimeMs" ->
+ createTimingMetric(sc, "time taken to rewrite the matched files"),
+ "numAddedChangeFiles" -> createMetric(sc, "number of change data capture files generated"),
+ "changeFileBytes" -> createMetric(sc, "total size of change data capture files generated"),
+ "numTouchedRows" -> createMetric(sc, "number of rows touched (copied + updated)"),
+ "numDeletionVectorsAdded" -> createMetric(sc, "number of deletion vectors added."),
+ "numDeletionVectorsRemoved" -> createMetric(sc, "number of deletion vectors removed.")
+ )
+
+ final override def run(sparkSession: SparkSession): Seq[Row] = {
+ recordDeltaOperation(tahoeFileIndex.deltaLog, "delta.dml.update") {
+ val deltaLog = tahoeFileIndex.deltaLog
+ gpuDeltaLog.withNewTransaction { txn =>
+ DeltaLog.assertRemovable(txn.snapshot)
+ performUpdate(sparkSession, deltaLog, txn)
+ }
+      // Re-cache all cached plans (including this relation itself, if it's cached) that refer to
+ // this data source relation.
+ sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target)
+ }
+ Seq(Row(metrics("numUpdatedRows").value))
+ }
+
+ private def performUpdate(
+ sparkSession: SparkSession, deltaLog: DeltaLog, txn: OptimisticTransaction): Unit = {
+ import com.databricks.sql.transaction.tahoe.implicits._
+
+ var numTouchedFiles: Long = 0
+ var numRewrittenFiles: Long = 0
+ var numAddedChangeFiles: Long = 0
+ var changeFileBytes: Long = 0
+ var scanTimeMs: Long = 0
+ var rewriteTimeMs: Long = 0
+
+ val startTime = System.nanoTime()
+ val numFilesTotal = txn.snapshot.numOfFiles
+
+ val updateCondition = condition.getOrElse(Literal.TrueLiteral)
+ val (metadataPredicates, dataPredicates) =
+ DeltaTableUtils.splitMetadataAndDataPredicates(
+ updateCondition, txn.metadata.partitionColumns, sparkSession)
+ val candidateFiles = txn.filterFiles(metadataPredicates ++ dataPredicates)
+ val nameToAddFile = generateCandidateFileMap(deltaLog.dataPath, candidateFiles)
+
+ scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+
+ val filesToRewrite: Seq[AddFile] = if (candidateFiles.isEmpty) {
+ // Case 1: Do nothing if no row qualifies the partition predicates
+ // that are part of Update condition
+ Nil
+ } else if (dataPredicates.isEmpty) {
+ // Case 2: Update all the rows from the files that are in the specified partitions
+ // when the data filter is empty
+ candidateFiles
+ } else {
+ // Case 3: Find all the affected files using the user-specified condition
+ val fileIndex = new TahoeBatchFileIndex(
+ sparkSession, "update", candidateFiles, deltaLog, tahoeFileIndex.path, txn.snapshot)
+ // Keep everything from the resolved target except a new TahoeFileIndex
+ // that only involves the affected files instead of all files.
+ val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex)
+ val data = Dataset.ofRows(sparkSession, newTarget)
+ val updatedRowCount = metrics("numUpdatedRows")
+ val updatedRowUdf = DeltaUDF.boolean {
+ new GpuDeltaMetricUpdateUDF(updatedRowCount)
+ }.asNondeterministic()
+ val pathsToRewrite =
+ withStatusCode("DELTA", UpdateCommand.FINDING_TOUCHED_FILES_MSG) {
+ data.filter(new Column(updateCondition))
+ .select(input_file_name())
+ .filter(updatedRowUdf())
+ .distinct()
+ .as[String]
+ .collect()
+ }
+
+ scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000
+
+ pathsToRewrite.map(getTouchedFile(deltaLog.dataPath, _, nameToAddFile)).toSeq
+ }
+
+ numTouchedFiles = filesToRewrite.length
+
+ val newActions = if (filesToRewrite.isEmpty) {
+ // Do nothing if no row qualifies the UPDATE condition
+ Nil
+ } else {
+ // Generate the new files containing the updated values
+ withStatusCode("DELTA", UpdateCommand.rewritingFilesMsg(filesToRewrite.size)) {
+ rewriteFiles(sparkSession, txn, tahoeFileIndex.path,
+ filesToRewrite.map(_.path), nameToAddFile, updateCondition)
+ }
+ }
+
+ rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs
+
+ val (changeActions, addActions) = newActions.partition(_.isInstanceOf[AddCDCFile])
+ numRewrittenFiles = addActions.size
+ numAddedChangeFiles = changeActions.size
+ changeFileBytes = changeActions.collect { case f: AddCDCFile => f.size }.sum
+
+ val totalActions = if (filesToRewrite.isEmpty) {
+ // Do nothing if no row qualifies the UPDATE condition
+ Nil
+ } else {
+ // Delete the old files and return those delete actions along with the new AddFile actions for
+ // files containing the updated values
+ val operationTimestamp = System.currentTimeMillis()
+ val deleteActions = filesToRewrite.map(_.removeWithTimestamp(operationTimestamp))
+
+ deleteActions ++ newActions
+ }
+
+ if (totalActions.nonEmpty) {
+ metrics("numAddedFiles").set(numRewrittenFiles)
+ metrics("numAddedChangeFiles").set(numAddedChangeFiles)
+ metrics("changeFileBytes").set(changeFileBytes)
+ metrics("numRemovedFiles").set(numTouchedFiles)
+ metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000)
+ metrics("scanTimeMs").set(scanTimeMs)
+ metrics("rewriteTimeMs").set(rewriteTimeMs)
+ // In the case where the numUpdatedRows is not captured, we can siphon out the metrics from
+ // the BasicWriteStatsTracker. This is for case 2 where the update condition contains only
+ // metadata predicates and so the entire partition is re-written.
+ val outputRows = txn.getMetric("numOutputRows").map(_.value).getOrElse(-1L)
+ if (metrics("numUpdatedRows").value == 0 && outputRows != 0 &&
+ metrics("numCopiedRows").value == 0) {
+ // We know that numTouchedRows = numCopiedRows + numUpdatedRows.
+ // Since an entire partition was re-written, no rows were copied.
+          // So numTouchedRows == numUpdatedRows
+ metrics("numUpdatedRows").set(metrics("numTouchedRows").value)
+ } else {
+ // This is for case 3 where the update condition contains both metadata and data predicates
+ // so relevant files will have some rows updated and some rows copied. We don't need to
+ // consider case 1 here, where no files match the update condition, as we know that
+ // `totalActions` is empty.
+ metrics("numCopiedRows").set(
+ metrics("numTouchedRows").value - metrics("numUpdatedRows").value)
+ }
+ metrics("numDeletionVectorsAdded").set(0)
+ metrics("numDeletionVectorsRemoved").set(0)
+ txn.registerSQLMetrics(sparkSession, metrics)
+ val tags = DMLUtils.TaggedCommitData.EMPTY
+ .withTag(PreservedRowTrackingTag, RowTracking.isEnabled(txn.protocol, txn.metadata))
+ .withTag(NoRowsCopiedTag, metrics("numCopiedRows").value == 0)
+ txn.commitIfNeeded(totalActions, DeltaOperations.Update(condition), tags.stringTags)
+ // This is needed to make the SQL metrics visible in the Spark UI
+ val executionId = sparkSession.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(
+ sparkSession.sparkContext, executionId, metrics.values.toSeq)
+ }
+
+ recordDeltaEvent(
+ deltaLog,
+ "delta.dml.update.stats",
+ data = UpdateMetric(
+ condition = condition.map(_.sql).getOrElse("true"),
+ numFilesTotal,
+ numTouchedFiles,
+ numRewrittenFiles,
+ numAddedChangeFiles,
+ changeFileBytes,
+ scanTimeMs,
+ rewriteTimeMs)
+ )
+ }
+
+ /**
+ * Scan all the affected files and write out the updated files.
+ *
+ * When CDF is enabled, includes the generation of CDC preimage and postimage columns for
+ * changed rows.
+ *
+ * @return the list of [[AddFile]]s and [[AddCDCFile]]s that have been written.
+ */
+ private def rewriteFiles(
+ spark: SparkSession,
+ txn: OptimisticTransaction,
+ rootPath: Path,
+ inputLeafFiles: Seq[String],
+ nameToAddFileMap: Map[String, AddFile],
+ condition: Expression): Seq[FileAction] = {
+    // nameToAddFileMap contains the map from the relative file path to its AddFile action
+ val baseRelation = buildBaseRelation(
+ spark, txn, "update", rootPath, inputLeafFiles, nameToAddFileMap)
+ val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location)
+ val targetDf = Dataset.ofRows(spark, newTarget)
+
+ // Number of total rows that we have seen, i.e. are either copying or updating (sum of both).
+ // This will be used later, along with numUpdatedRows, to determine numCopiedRows.
+ val numTouchedRows = metrics("numTouchedRows")
+ val numTouchedRowsUdf = DeltaUDF.boolean {
+ new GpuDeltaMetricUpdateUDF(numTouchedRows)
+ }.asNondeterministic()
+
+ val updatedDataFrame = UpdateCommand.withUpdatedColumns(
+ target.output,
+ updateExpressions,
+ condition,
+ targetDf
+ .filter(numTouchedRowsUdf())
+ .withColumn(UpdateCommand.CONDITION_COLUMN_NAME, new Column(condition)),
+ UpdateCommand.shouldOutputCdc(txn))
+
+ txn.writeFiles(updatedDataFrame)
+ }
+}
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala
new file mode 100644
index 00000000000..32b7656b45b
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.delta
+
+/**
+ * Implements the Delta Probe interface for probing the Delta Lake provider on Databricks.
+ * @note This is instantiated via reflection from ShimLoader.
+ */
+class DeltaProbeImpl extends DeltaProbe {
+ // Delta Lake is built-in for Databricks instances, so no probing is necessary.
+ override def getDeltaProvider: DeltaProvider = DeltaSpark341DBProvider
+}
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark341DBProvider.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark341DBProvider.scala
new file mode 100644
index 00000000000..cd204fa0440
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark341DBProvider.scala
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta
+
+import com.databricks.sql.transaction.tahoe.rapids.GpuDeltaCatalog
+import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec}
+
+import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
+import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec}
+
+object DeltaSpark341DBProvider extends DatabricksDeltaProviderBase {
+
+ override def convertToGpu(
+ cpuExec: AtomicCreateTableAsSelectExec,
+ meta: AtomicCreateTableAsSelectExecMeta): GpuExec = {
+ GpuAtomicCreateTableAsSelectExec(
+ cpuExec.output,
+ new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
+ cpuExec.ident,
+ cpuExec.partitioning,
+ cpuExec.query,
+ cpuExec.tableSpec,
+ cpuExec.writeOptions,
+ cpuExec.ifNotExists)
+ }
+
+ override def convertToGpu(
+ cpuExec: AtomicReplaceTableAsSelectExec,
+ meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
+ GpuAtomicReplaceTableAsSelectExec(
+ cpuExec.output,
+ new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
+ cpuExec.ident,
+ cpuExec.partitioning,
+ cpuExec.query,
+ cpuExec.tableSpec,
+ cpuExec.writeOptions,
+ cpuExec.orCreate,
+ cpuExec.invalidateCache)
+ }
+}
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala
new file mode 100644
index 00000000000..969d005b573
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta
+
+import com.databricks.sql.transaction.tahoe.{DeltaColumnMappingMode, DeltaParquetFileFormat, IdMapping}
+import com.databricks.sql.transaction.tahoe.DeltaParquetFileFormat.IS_ROW_DELETED_COLUMN_NAME
+import com.nvidia.spark.rapids.SparkPlanMeta
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+case class GpuDeltaParquetFileFormat(
+ override val columnMappingMode: DeltaColumnMappingMode,
+ override val referenceSchema: StructType,
+ isSplittable: Boolean) extends GpuDeltaParquetFileFormatBase {
+
+ if (columnMappingMode == IdMapping) {
+ val requiredReadConf = SQLConf.PARQUET_FIELD_ID_READ_ENABLED
+ require(SparkSession.getActiveSession.exists(_.sessionState.conf.getConf(requiredReadConf)),
+ s"${requiredReadConf.key} must be enabled to support Delta id column mapping mode")
+ val requiredWriteConf = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED
+ require(SparkSession.getActiveSession.exists(_.sessionState.conf.getConf(requiredWriteConf)),
+ s"${requiredWriteConf.key} must be enabled to support Delta id column mapping mode")
+ }
+
+ override def isSplitable(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ path: Path): Boolean = isSplittable
+}
+
+object GpuDeltaParquetFileFormat {
+ def tagSupportForGpuFileSourceScan(meta: SparkPlanMeta[FileSourceScanExec]): Unit = {
+ val format = meta.wrapped.relation.fileFormat.asInstanceOf[DeltaParquetFileFormat]
+ val requiredSchema = meta.wrapped.requiredSchema
+ if (requiredSchema.exists(_.name == IS_ROW_DELETED_COLUMN_NAME)) {
+ meta.willNotWorkOnGpu(
+ s"reading metadata column $IS_ROW_DELETED_COLUMN_NAME is not supported")
+ }
+ if (format.hasDeletionVectorMap()) {
+ meta.willNotWorkOnGpu("deletion vectors are not supported")
+ }
+ }
+
+ def convertToGpu(fmt: DeltaParquetFileFormat): GpuDeltaParquetFileFormat = {
+ GpuDeltaParquetFileFormat(fmt.columnMappingMode, fmt.referenceSchema, fmt.isSplittable)
+ }
+}
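
The require() checks above mean a Delta table that uses id column mapping can only be read or written on the GPU when Parquet field-id support is enabled in the session. A hedged example of setting those configs up front (the table path is a placeholder):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

val spark = SparkSession.builder().getOrCreate()
// Enable Parquet field-id resolution, which Delta id column mapping relies on.
spark.conf.set(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key, "true")
spark.conf.set(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, "true")
val df = spark.read.format("delta").load("/data/id_mapped_table") // placeholder path
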
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeleteCommandMetaShim.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeleteCommandMetaShim.scala
new file mode 100644
index 00000000000..96863c71ad0
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeleteCommandMetaShim.scala
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta.shims
+
+import com.databricks.sql.transaction.tahoe.commands.DeletionVectorUtils
+import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf
+import com.nvidia.spark.rapids.delta.{DeleteCommandEdgeMeta, DeleteCommandMeta}
+
+object DeleteCommandMetaShim {
+ def tagForGpu(meta: DeleteCommandMeta): Unit = {
+ val dvFeatureEnabled = DeletionVectorUtils.deletionVectorsWritable(
+ meta.deleteCmd.deltaLog.unsafeVolatileSnapshot)
+ if (dvFeatureEnabled && meta.deleteCmd.conf.getConf(
+ DeltaSQLConf.DELETE_USE_PERSISTENT_DELETION_VECTORS)) {
+ // https://github.com/NVIDIA/spark-rapids/issues/8654
+ meta.willNotWorkOnGpu("Deletion vector writes are not supported on GPU")
+ }
+ }
+
+ def tagForGpu(meta: DeleteCommandEdgeMeta): Unit = {
+ val dvFeatureEnabled = DeletionVectorUtils.deletionVectorsWritable(
+ meta.deleteCmd.deltaLog.unsafeVolatileSnapshot)
+ if (dvFeatureEnabled && meta.deleteCmd.conf.getConf(
+ DeltaSQLConf.DELETE_USE_PERSISTENT_DELETION_VECTORS)) {
+ // https://github.com/NVIDIA/spark-rapids/issues/8654
+ meta.willNotWorkOnGpu("Deletion vector writes are not supported on GPU")
+ }
+ }
+}
\ No newline at end of file
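
Both overloads above tag the DELETE for CPU fallback when deletion vectors are writable on the table and persistent deletion vectors are enabled for DELETE. A sketch of how a user could keep such a DELETE on the GPU by turning the session conf off; whether that is acceptable depends on the table's requirements, and the table name below is a placeholder:

import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
// Disable persistent deletion vectors for DELETE so the command is not tagged off the GPU.
spark.conf.set(DeltaSQLConf.DELETE_USE_PERSISTENT_DELETION_VECTORS.key, "false")
spark.sql("DELETE FROM demo_delta_table WHERE id < 100")
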
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeltaLogShim.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeltaLogShim.scala
new file mode 100644
index 00000000000..0bd231e05a6
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeltaLogShim.scala
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta.shims
+
+import com.databricks.sql.transaction.tahoe.DeltaLog
+import com.databricks.sql.transaction.tahoe.actions.Metadata
+
+import org.apache.spark.sql.execution.datasources.FileFormat
+
+object DeltaLogShim {
+ def fileFormat(deltaLog: DeltaLog): FileFormat = {
+ deltaLog.fileFormat(deltaLog.unsafeVolatileSnapshot.protocol,
+ deltaLog.unsafeVolatileSnapshot.metadata)
+ }
+ def getMetadata(deltaLog: DeltaLog): Metadata = {
+ deltaLog.unsafeVolatileSnapshot.metadata
+ }
+}
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/InvariantViolationExceptionShim.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/InvariantViolationExceptionShim.scala
new file mode 100644
index 00000000000..58714a91fd4
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/InvariantViolationExceptionShim.scala
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.delta.shims
+
+import com.databricks.sql.transaction.tahoe.constraints.Constraints._
+import com.databricks.sql.transaction.tahoe.schema.DeltaInvariantViolationException
+import com.databricks.sql.transaction.tahoe.schema.InvariantViolationException
+
+object InvariantViolationExceptionShim {
+ def apply(c: Check, m: Map[String, Any]): InvariantViolationException = {
+ DeltaInvariantViolationException(c, m)
+ }
+
+ def apply(c: NotNull): InvariantViolationException = {
+ DeltaInvariantViolationException(c)
+ }
+}
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala
new file mode 100644
index 00000000000..8e13a9e4b5a
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta.shims
+
+import com.databricks.sql.transaction.tahoe.commands.{MergeIntoCommand, MergeIntoCommandEdge}
+import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaLog, GpuMergeIntoCommand}
+import com.nvidia.spark.rapids.RapidsConf
+import com.nvidia.spark.rapids.delta.{MergeIntoCommandEdgeMeta, MergeIntoCommandMeta}
+
+import org.apache.spark.sql.execution.command.RunnableCommand
+
+object MergeIntoCommandMetaShim {
+ def tagForGpu(meta: MergeIntoCommandMeta, mergeCmd: MergeIntoCommand): Unit = {
+ // see https://github.com/NVIDIA/spark-rapids/issues/8415 for more information
+ if (mergeCmd.notMatchedBySourceClauses.nonEmpty) {
+ meta.willNotWorkOnGpu("notMatchedBySourceClauses not supported on GPU")
+ }
+ }
+
+ def tagForGpu(meta: MergeIntoCommandEdgeMeta, mergeCmd: MergeIntoCommandEdge): Unit = {
+ // see https://github.com/NVIDIA/spark-rapids/issues/8415 for more information
+ if (mergeCmd.notMatchedBySourceClauses.nonEmpty) {
+ meta.willNotWorkOnGpu("notMatchedBySourceClauses not supported on GPU")
+ }
+ }
+
+ def convertToGpu(mergeCmd: MergeIntoCommand, conf: RapidsConf): RunnableCommand = {
+ GpuMergeIntoCommand(
+ mergeCmd.source,
+ mergeCmd.target,
+ new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf),
+ mergeCmd.condition,
+ mergeCmd.matchedClauses,
+ mergeCmd.notMatchedClauses,
+ mergeCmd.notMatchedBySourceClauses,
+ mergeCmd.migratedSchema)(conf)
+ }
+
+ def convertToGpu(mergeCmd: MergeIntoCommandEdge, conf: RapidsConf): RunnableCommand = {
+ GpuMergeIntoCommand(
+ mergeCmd.source,
+ mergeCmd.target,
+ new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf),
+ mergeCmd.condition,
+ mergeCmd.matchedClauses,
+ mergeCmd.notMatchedClauses,
+ mergeCmd.notMatchedBySourceClauses,
+ mergeCmd.migratedSchema)(conf)
+ }
+}
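
The tagForGpu checks above reject any MERGE that carries WHEN NOT MATCHED BY SOURCE clauses (issue #8415). For illustration, with placeholder table names:

import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().getOrCreate()

// Falls back to the CPU: the statement has a WHEN NOT MATCHED BY SOURCE clause.
spark.sql("""
  MERGE INTO target t USING source s ON t.id = s.id
  WHEN MATCHED THEN UPDATE SET *
  WHEN NOT MATCHED BY SOURCE THEN DELETE
""")

// Eligible for the GPU path: only MATCHED / NOT MATCHED clauses are used.
spark.sql("""
  MERGE INTO target t USING source s ON t.id = s.id
  WHEN MATCHED THEN UPDATE SET *
  WHEN NOT MATCHED THEN INSERT *
""")
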
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala
new file mode 100644
index 00000000000..8f5196d7c66
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.delta.shims
+
+import com.databricks.sql.transaction.tahoe.stats.DeltaStatistics
+
+trait ShimUsesMetadataFields {
+ val NUM_RECORDS = DeltaStatistics.NUM_RECORDS
+ val MIN = DeltaStatistics.MIN
+ val MAX = DeltaStatistics.MAX
+ val NULL_COUNT = DeltaStatistics.NULL_COUNT
+}
\ No newline at end of file
diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/ShimDeltaUDF.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/ShimDeltaUDF.scala
new file mode 100644
index 00000000000..fd9052d9691
--- /dev/null
+++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/ShimDeltaUDF.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.delta.shims
+
+import com.databricks.sql.transaction.tahoe.DeltaUDF
+
+import org.apache.spark.sql.expressions.UserDefinedFunction
+
+object ShimDeltaUDF {
+ def stringStringUdf(f: String => String): UserDefinedFunction = DeltaUDF.stringFromString(f)
+}
diff --git a/integration_tests/src/main/python/delta_lake_utils.py b/integration_tests/src/main/python/delta_lake_utils.py
index 61ed3d08a20..9a5545a6e3a 100644
--- a/integration_tests/src/main/python/delta_lake_utils.py
+++ b/integration_tests/src/main/python/delta_lake_utils.py
@@ -33,10 +33,10 @@
delta_writes_enabled_conf = {"spark.rapids.sql.format.delta.write.enabled": "true"}
-delta_write_fallback_allow = "ExecutedCommandExec,DataWritingCommandExec" if is_databricks122_or_later() else "ExecutedCommandExec"
+delta_write_fallback_allow = "ExecutedCommandExec,DataWritingCommandExec,WriteFilesExec" if is_databricks122_or_later() else "ExecutedCommandExec"
delta_write_fallback_check = "DataWritingCommandExec" if is_databricks122_or_later() else "ExecutedCommandExec"
-delta_optimized_write_fallback_allow = "ExecutedCommandExec,DataWritingCommandExec,DeltaOptimizedWriterExec" if is_databricks122_or_later() else "ExecutedCommandExec"
+delta_optimized_write_fallback_allow = "ExecutedCommandExec,DataWritingCommandExec,DeltaOptimizedWriterExec,WriteFilesExec" if is_databricks122_or_later() else "ExecutedCommandExec"
def _fixup_operation_metrics(opm):
"""Update the specified operationMetrics node to facilitate log comparisons"""
diff --git a/integration_tests/src/main/python/delta_lake_write_test.py b/integration_tests/src/main/python/delta_lake_write_test.py
index 3997ae3eba3..ec861ddefd3 100644
--- a/integration_tests/src/main/python/delta_lake_write_test.py
+++ b/integration_tests/src/main/python/delta_lake_write_test.py
@@ -178,6 +178,7 @@ def do_write(spark, path):
@delta_lake
@ignore_order(local=True)
@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
+@pytest.mark.xfail(condition=is_spark_340_or_later() and is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/9676")
def test_delta_atomic_create_table_as_select(spark_tmp_table_factory, spark_tmp_path):
_atomic_write_table_as_select(delta_write_gens, spark_tmp_table_factory, spark_tmp_path, overwrite=False)
@@ -185,6 +186,7 @@ def test_delta_atomic_create_table_as_select(spark_tmp_table_factory, spark_tmp_
@delta_lake
@ignore_order(local=True)
@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
+@pytest.mark.xfail(condition=is_spark_340_or_later() and is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/9676")
def test_delta_atomic_replace_table_as_select(spark_tmp_table_factory, spark_tmp_path):
_atomic_write_table_as_select(delta_write_gens, spark_tmp_table_factory, spark_tmp_path, overwrite=True)
diff --git a/scala2.13/delta-lake/delta-spark341db/pom.xml b/scala2.13/delta-lake/delta-spark341db/pom.xml
new file mode 100644
index 00000000000..e8d7d0dd644
--- /dev/null
+++ b/scala2.13/delta-lake/delta-spark341db/pom.xml
@@ -0,0 +1,296 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  Copyright (c) 2023, NVIDIA CORPORATION.
+
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>com.nvidia</groupId>
+        <artifactId>rapids-4-spark-jdk-profiles_2.13</artifactId>
+        <version>23.12.0-SNAPSHOT</version>
+        <relativePath>../../jdk-profiles/pom.xml</relativePath>
+    </parent>
+
+    <artifactId>rapids-4-spark-delta-spark341db_2.13</artifactId>
+    <name>RAPIDS Accelerator for Apache Spark Databricks 13.3 Delta Lake Support</name>
+    <description>Databricks 13.3 Delta Lake support for the RAPIDS Accelerator for Apache Spark</description>
+    <version>23.12.0-SNAPSHOT</version>
+
+    <properties>
+        <rapids.compressed.artifact>false</rapids.compressed.artifact>
+        <rapids.default.jar.excludePattern>**/*</rapids.default.jar.excludePattern>
+        <rapids.shim.jar.phase>package</rapids.shim.jar.phase>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-sql_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <classifier>${spark.version.classifier}</classifier>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-core_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-annotation_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-network-common_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-network-shuffle_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-launcher_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.avro</groupId>
+            <artifactId>avro-mapred</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.avro</groupId>
+            <artifactId>avro</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hive</groupId>
+            <artifactId>hive-exec</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hive</groupId>
+            <artifactId>hive-serde</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-hive_${scala.binary.version}</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-core</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-annotations</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.json4s</groupId>
+            <artifactId>json4s-ast_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.json4s</groupId>
+            <artifactId>json4s-core_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-io</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.scala-lang</groupId>
+            <artifactId>scala-reflect</artifactId>
+            <version>${scala.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-lang3</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.esotericsoftware.kryo</groupId>
+            <artifactId>kryo-shaded-db</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.parquet</groupId>
+            <artifactId>parquet-hadoop</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.parquet</groupId>
+            <artifactId>parquet-common</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.parquet</groupId>
+            <artifactId>parquet-column</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.parquet</groupId>
+            <artifactId>parquet-format</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.arrow</groupId>
+            <artifactId>arrow-memory</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.arrow</groupId>
+            <artifactId>arrow-vector</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-client</artifactId>
+            <version>${hadoop.client.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.orc</groupId>
+            <artifactId>orc-core</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.orc</groupId>
+            <artifactId>orc-shims</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.orc</groupId>
+            <artifactId>orc-mapreduce</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hive</groupId>
+            <artifactId>hive-storage-api</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.google.protobuf</groupId>
+            <artifactId>protobuf-java</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-common-utils_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql-api_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.codehaus.mojo</groupId>
+                <artifactId>build-helper-maven-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <id>add-common-sources</id>
+                        <phase>generate-sources</phase>
+                        <goals>
+                            <goal>add-source</goal>
+                        </goals>
+                        <configuration>
+                            <sources>
+                                <!-- shared sources from delta-lake/common -->
+                            </sources>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>net.alchim31.maven</groupId>
+                <artifactId>scala-maven-plugin</artifactId>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.rat</groupId>
+                <artifactId>apache-rat-plugin</artifactId>
+            </plugin>
+        </plugins>
+    </build>
+</project>
diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala
index c5059046867..bbd43ee24b0 100644
--- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala
+++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala
@@ -68,7 +68,7 @@ case class GpuAtomicReplaceTableAsSelectExec(
val properties = CatalogV2Util.convertTableProperties(tableSpec)
override def supportsColumnar: Boolean = false
-
+
override protected def run(): Seq[InternalRow] = {
val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable
if (catalog.tableExists(ident)) {
diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala
new file mode 100644
index 00000000000..398bc8e187d
--- /dev/null
+++ b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "341db"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.execution.datasources.v2.rapids
+
+import scala.collection.JavaConverters._
+
+import com.nvidia.spark.rapids.GpuExec
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec}
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * GPU version of AtomicCreateTableAsSelectExec.
+ *
+ * Physical plan node for v2 create table as select, when the catalog is determined to support
+ * staging table creation.
+ *
+ * A new table will be created using the schema of the query, and rows from the query are appended.
+ * The CTAS operation is atomic. The creation of the table is staged and the commit of the write
+ * should bundle the commitment of the metadata and the table contents in a single unit. If the
+ * write fails, the table is instructed to roll back all staged changes.
+ */
+case class GpuAtomicCreateTableAsSelectExec(
+ override val output: Seq[Attribute],
+ catalog: StagingTableCatalog,
+ ident: Identifier,
+ partitioning: Seq[Transform],
+ query: LogicalPlan,
+ tableSpec: TableSpec,
+ writeOptions: Map[String, String],
+ ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec with GpuExec {
+
+ val properties = CatalogV2Util.convertTableProperties(tableSpec)
+
+ override def supportsColumnar: Boolean = false
+
+ override protected def run(): Seq[InternalRow] = {
+ if (catalog.tableExists(ident)) {
+ if (ifNotExists) {
+ return Nil
+ }
+
+ throw QueryCompilationErrors.tableAlreadyExistsError(ident)
+ }
+ val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable
+ val stagedTable = catalog.stageCreate(
+ ident, getV2Columns(schema, catalog.useNullableQuerySchema),
+ partitioning.toArray, properties.asJava)
+
+ writeToTable(catalog, stagedTable, writeOptions, ident, query)
+ }
+
+ override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] =
+ throw new IllegalStateException("Columnar execution not supported")
+}
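
For context, this node replaces the CPU AtomicCreateTableAsSelectExec that Databricks 13.3 plans for an atomic CTAS into a Delta table backed by a staging catalog. A CTAS of roughly this shape would exercise it (the table name is a placeholder):

import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().getOrCreate()
spark.sql("""
  CREATE TABLE demo_ctas
  USING DELTA
  AS SELECT id, id * 2 AS doubled FROM range(1000)
""")
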
diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala
new file mode 100644
index 00000000000..d1380facb86
--- /dev/null
+++ b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "341db"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.execution.datasources.v2.rapids
+
+import scala.collection.JavaConverters._
+
+import com.nvidia.spark.rapids.GpuExec
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec}
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * GPU version of AtomicReplaceTableAsSelectExec.
+ *
+ * Physical plan node for v2 replace table as select when the catalog supports staging
+ * table replacement.
+ *
+ * A new table will be created using the schema of the query, and rows from the query are appended.
+ * If the table exists, its contents and schema should be replaced with the schema and the contents
+ * of the query. This implementation is atomic. The table replacement is staged, and the commit
+ * operation at the end should perform the replacement of the table's metadata and contents. If the
+ * write fails, the table is instructed to roll back staged changes and any previously written table
+ * is left untouched.
+ */
+case class GpuAtomicReplaceTableAsSelectExec(
+ override val output: Seq[Attribute],
+ catalog: StagingTableCatalog,
+ ident: Identifier,
+ partitioning: Seq[Transform],
+ query: LogicalPlan,
+ tableSpec: TableSpec,
+ writeOptions: Map[String, String],
+ orCreate: Boolean,
+ invalidateCache: (TableCatalog, Table, Identifier) => Unit)
+ extends V2CreateTableAsSelectBaseExec with GpuExec {
+
+ val properties = CatalogV2Util.convertTableProperties(tableSpec)
+
+ override def supportsColumnar: Boolean = false
+
+ override protected def run(): Seq[InternalRow] = {
+ val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable
+ if (catalog.tableExists(ident)) {
+ val table = catalog.loadTable(ident)
+ invalidateCache(catalog, table, ident)
+ }
+ val staged = if (orCreate) {
+ catalog.stageCreateOrReplace(
+ ident, schema, partitioning.toArray, properties.asJava)
+ } else if (catalog.tableExists(ident)) {
+ try {
+ catalog.stageReplace(
+ ident, schema, partitioning.toArray, properties.asJava)
+ } catch {
+ case e: NoSuchTableException =>
+ throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e))
+ }
+ } else {
+ throw QueryCompilationErrors.cannotReplaceMissingTableError(ident)
+ }
+ writeToTable(catalog, staged, writeOptions, ident, query)
+ }
+
+ override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] =
+ throw new IllegalStateException("Columnar execution not supported")
+}
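
Likewise, this node replaces the CPU AtomicReplaceTableAsSelectExec. REPLACE TABLE plans with orCreate = false and fails if the table does not exist, while CREATE OR REPLACE TABLE plans with orCreate = true and stages a create-or-replace instead; placeholder table names below:

import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().getOrCreate()
spark.sql("REPLACE TABLE demo_ctas USING DELTA AS SELECT id FROM range(10)")
spark.sql("CREATE OR REPLACE TABLE demo_ctas USING DELTA AS SELECT id, id + 1 AS next FROM range(10)")
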