diff --git a/delta-lake/README.md b/delta-lake/README.md index ff9d2553e31..f48c20fb7c3 100644 --- a/delta-lake/README.md +++ b/delta-lake/README.md @@ -19,6 +19,7 @@ and directory contains the corresponding support code. | Databricks 10.4 | Databricks 10.4 | `delta-spark321db` | | Databricks 11.3 | Databricks 11.3 | `delta-spark330db` | | Databricks 12.2 | Databricks 12.2 | `delta-spark332db` | +| Databricks 13.3 | Databricks 13.3 | `delta-spark341db` | Delta Lake is not supported on all Spark versions, and for Spark versions where it is not supported the `delta-stub` project is used. diff --git a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala index 083ebd09979..2ea31f0ae06 100644 --- a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala +++ b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala @@ -25,7 +25,7 @@ import com.databricks.sql.managedcatalog.UnityCatalogV2Proxy import com.databricks.sql.transaction.tahoe.{DeltaLog, DeltaOptions, DeltaParquetFileFormat} import com.databricks.sql.transaction.tahoe.catalog.{DeltaCatalog, DeltaTableV2} import com.databricks.sql.transaction.tahoe.commands.{DeleteCommand, DeleteCommandEdge, MergeIntoCommand, MergeIntoCommandEdge, UpdateCommand, UpdateCommandEdge, WriteIntoDelta} -import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaCatalog, GpuDeltaLog, GpuWriteIntoDelta} +import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaLog, GpuWriteIntoDelta} import com.databricks.sql.transaction.tahoe.sources.{DeltaDataSource, DeltaSourceUtils} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.delta.shims.DeltaLogShim @@ -38,15 +38,15 @@ import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.{FileFormat, LogicalRelation, SaveIntoDataSourceCommand} import org.apache.spark.sql.execution.datasources.v2.{AppendDataExecV1, AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec, OverwriteByExpressionExecV1} -import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.ExternalSource import org.apache.spark.sql.sources.{CreatableRelationProvider, InsertableRelation} +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * Common implementation of the DeltaProvider interface for all Databricks versions. */ -object DatabricksDeltaProvider extends DeltaProviderImplBase { +trait DatabricksDeltaProviderBase extends DeltaProviderImplBase { override def getCreatableRelationRules: Map[Class[_ <: CreatableRelationProvider], CreatableRelationProviderRule[_ <: CreatableRelationProvider]] = { Seq( @@ -116,6 +116,15 @@ object DatabricksDeltaProvider extends DeltaProviderImplBase { catalogClass == classOf[DeltaCatalog] || catalogClass == classOf[UnityCatalogV2Proxy] } + private def getWriteOptions(options: Any): Map[String, String] = { + // For Databricks 13.3 AtomicCreateTableAsSelectExec writeOptions is a Map[String, String] + // while in all the other versions it's a CaseInsensitiveMap + options match { + case c: CaseInsensitiveStringMap => c.asCaseSensitiveMap().asScala.toMap + case _ => options.asInstanceOf[Map[String, String]] + } + } + override def tagForGpu( cpuExec: AtomicCreateTableAsSelectExec, meta: AtomicCreateTableAsSelectExecMeta): Unit = { @@ -131,22 +140,7 @@ object DatabricksDeltaProvider extends DeltaProviderImplBase { meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider") } RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None, - cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session) - } - - override def convertToGpu( - cpuExec: AtomicCreateTableAsSelectExec, - meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { - GpuAtomicCreateTableAsSelectExec( - cpuExec.output, - new GpuDeltaCatalog(cpuExec.catalog, meta.conf), - cpuExec.ident, - cpuExec.partitioning, - cpuExec.plan, - meta.childPlans.head.convertIfNeeded(), - cpuExec.tableSpec, - cpuExec.writeOptions, - cpuExec.ifNotExists) + getWriteOptions(cpuExec.writeOptions), cpuExec.session) } override def tagForGpu( @@ -164,23 +158,7 @@ object DatabricksDeltaProvider extends DeltaProviderImplBase { meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider") } RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None, - cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session) - } - - override def convertToGpu( - cpuExec: AtomicReplaceTableAsSelectExec, - meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { - GpuAtomicReplaceTableAsSelectExec( - cpuExec.output, - new GpuDeltaCatalog(cpuExec.catalog, meta.conf), - cpuExec.ident, - cpuExec.partitioning, - cpuExec.plan, - meta.childPlans.head.convertIfNeeded(), - cpuExec.tableSpec, - cpuExec.writeOptions, - cpuExec.orCreate, - cpuExec.invalidateCache) + getWriteOptions(cpuExec.writeOptions), cpuExec.session) } private case class DeltaWriteV1Config( @@ -360,13 +338,4 @@ class DeltaCreatableRelationProviderMeta( } override def convertToGpu(): GpuCreatableRelationProvider = new GpuDeltaDataSource(conf) -} - -/** - * Implements the Delta Probe interface for probing the Delta Lake provider on Databricks. - * @note This is instantiated via reflection from ShimLoader. - */ -class DeltaProbeImpl extends DeltaProbe { - // Delta Lake is built-in for Databricks instances, so no probing is necessary. - override def getDeltaProvider: DeltaProvider = DatabricksDeltaProvider -} +} \ No newline at end of file diff --git a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala index 467f16986aa..3968c48bff0 100644 --- a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala +++ b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/shims/package-shims.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,6 @@ package com.nvidia.spark.rapids.delta.shims import com.databricks.sql.expressions.JoinedProjection import com.databricks.sql.transaction.tahoe.DeltaColumnMapping -import com.databricks.sql.transaction.tahoe.stats.UsesMetadataFields import com.databricks.sql.transaction.tahoe.util.JsonUtils import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} @@ -47,6 +46,4 @@ object ShimJoinedProjection { object ShimJsonUtils { def fromJson[T: Manifest](json: String): T = JsonUtils.fromJson[T](json) -} - -trait ShimUsesMetadataFields extends UsesMetadataFields +} \ No newline at end of file diff --git a/delta-lake/common/src/main/scala/org/apache/spark/sql/rapids/delta/DeltaShufflePartitionsUtil.scala b/delta-lake/common/src/main/scala/org/apache/spark/sql/rapids/delta/DeltaShufflePartitionsUtil.scala index c573819f31e..395757f0a30 100644 --- a/delta-lake/common/src/main/scala/org/apache/spark/sql/rapids/delta/DeltaShufflePartitionsUtil.scala +++ b/delta-lake/common/src/main/scala/org/apache/spark/sql/rapids/delta/DeltaShufflePartitionsUtil.scala @@ -179,9 +179,8 @@ object DeltaShufflePartitionsUtil { c.child case _ => p } - case ShuffleExchangeExec(_, child, shuffleOrigin) - if !shuffleOrigin.equals(ENSURE_REQUIREMENTS) => - child + case s: ShuffleExchangeExec if !s.shuffleOrigin.equals(ENSURE_REQUIREMENTS) => + s.child case CoalesceExec(_, child) => child case _ => diff --git a/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala new file mode 100644 index 00000000000..2194522ab82 --- /dev/null +++ b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.delta + +/** + * Implements the Delta Probe interface for probing the Delta Lake provider on Databricks. + * @note This is instantiated via reflection from ShimLoader. + */ +class DeltaProbeImpl extends DeltaProbe { + // Delta Lake is built-in for Databricks instances, so no probing is necessary. + override def getDeltaProvider: DeltaProvider = DeltaSpark321DBProvider +} diff --git a/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark321DBProvider.scala b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark321DBProvider.scala new file mode 100644 index 00000000000..44e5721bafc --- /dev/null +++ b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark321DBProvider.scala @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta + +import com.databricks.sql.transaction.tahoe.rapids.GpuDeltaCatalog +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec} + +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} + +object DeltaSpark321DBProvider extends DatabricksDeltaProviderBase { + + override def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + GpuAtomicCreateTableAsSelectExec( + cpuExec.output, + new GpuDeltaCatalog(cpuExec.catalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.ifNotExists) + } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + GpuAtomicReplaceTableAsSelectExec( + cpuExec.output, + new GpuDeltaCatalog(cpuExec.catalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } +} diff --git a/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala new file mode 100644 index 00000000000..f722837778a --- /dev/null +++ b/delta-lake/delta-spark321db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.delta.shims + +import com.databricks.sql.transaction.tahoe.stats.UsesMetadataFields + +trait ShimUsesMetadataFields extends UsesMetadataFields diff --git a/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala new file mode 100644 index 00000000000..01a386c3769 --- /dev/null +++ b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.delta + +/** + * Implements the Delta Probe interface for probing the Delta Lake provider on Databricks. + * @note This is instantiated via reflection from ShimLoader. + */ +class DeltaProbeImpl extends DeltaProbe { + // Delta Lake is built-in for Databricks instances, so no probing is necessary. + override def getDeltaProvider: DeltaProvider = DeltaSpark330DBProvider +} diff --git a/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark330DBProvider.scala b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark330DBProvider.scala new file mode 100644 index 00000000000..c7a5cf1db27 --- /dev/null +++ b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark330DBProvider.scala @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta + +import com.databricks.sql.transaction.tahoe.rapids.GpuDeltaCatalog +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec} + +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} + +object DeltaSpark330DBProvider extends DatabricksDeltaProviderBase { + + override def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + GpuAtomicCreateTableAsSelectExec( + cpuExec.output, + new GpuDeltaCatalog(cpuExec.catalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.ifNotExists) + } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + GpuAtomicReplaceTableAsSelectExec( + cpuExec.output, + new GpuDeltaCatalog(cpuExec.catalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } +} diff --git a/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala new file mode 100644 index 00000000000..f722837778a --- /dev/null +++ b/delta-lake/delta-spark330db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.delta.shims + +import com.databricks.sql.transaction.tahoe.stats.UsesMetadataFields + +trait ShimUsesMetadataFields extends UsesMetadataFields diff --git a/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala new file mode 100644 index 00000000000..e18cbcff2d3 --- /dev/null +++ b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.delta + +/** + * Implements the Delta Probe interface for probing the Delta Lake provider on Databricks. + * @note This is instantiated via reflection from ShimLoader. + */ +class DeltaProbeImpl extends DeltaProbe { + // Delta Lake is built-in for Databricks instances, so no probing is necessary. + override def getDeltaProvider: DeltaProvider = DeltaSpark332DBProvider +} diff --git a/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark332DBProvider.scala b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark332DBProvider.scala new file mode 100644 index 00000000000..283ad44fb30 --- /dev/null +++ b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark332DBProvider.scala @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta + +import com.databricks.sql.transaction.tahoe.rapids.GpuDeltaCatalog +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec} + +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} + +object DeltaSpark332DBProvider extends DatabricksDeltaProviderBase { + + override def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + GpuAtomicCreateTableAsSelectExec( + cpuExec.output, + new GpuDeltaCatalog(cpuExec.catalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.ifNotExists) + } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + GpuAtomicReplaceTableAsSelectExec( + cpuExec.output, + new GpuDeltaCatalog(cpuExec.catalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } +} diff --git a/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala new file mode 100644 index 00000000000..f722837778a --- /dev/null +++ b/delta-lake/delta-spark332db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.delta.shims + +import com.databricks.sql.transaction.tahoe.stats.UsesMetadataFields + +trait ShimUsesMetadataFields extends UsesMetadataFields diff --git a/delta-lake/delta-spark341db/pom.xml b/delta-lake/delta-spark341db/pom.xml new file mode 100644 index 00000000000..64e920eb8f1 --- /dev/null +++ b/delta-lake/delta-spark341db/pom.xml @@ -0,0 +1,296 @@ + + + + 4.0.0 + + + com.nvidia + rapids-4-spark-jdk-profiles_2.12 + 23.12.0-SNAPSHOT + ../../jdk-profiles/pom.xml + + + rapids-4-spark-delta-spark341db_2.12 + RAPIDS Accelerator for Apache Spark Databricks 13.3 Delta Lake Support + Databricks 13.3 Delta Lake support for the RAPIDS Accelerator for Apache Spark + 23.12.0-SNAPSHOT + + + false + **/* + package + + + + + com.nvidia + rapids-4-spark-sql_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + provided + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-annotation_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-network-common_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-network-shuffle_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-launcher_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${spark.version} + provided + + + org.apache.avro + avro-mapred + ${spark.version} + provided + + + org.apache.avro + avro + ${spark.version} + provided + + + org.apache.hive + hive-exec + ${spark.version} + provided + + + org.apache.hive + hive-serde + ${spark.version} + provided + + + org.apache.spark + spark-hive_${scala.binary.version} + + + com.fasterxml.jackson.core + jackson-core + ${spark.version} + provided + + + com.fasterxml.jackson.core + jackson-annotations + ${spark.version} + provided + + + org.json4s + json4s-ast_${scala.binary.version} + ${spark.version} + provided + + + org.json4s + json4s-core_${scala.binary.version} + ${spark.version} + provided + + + org.apache.commons + commons-io + ${spark.version} + provided + + + org.scala-lang + scala-reflect + ${scala.version} + provided + + + org.apache.commons + commons-lang3 + ${spark.version} + provided + + + com.esotericsoftware.kryo + kryo-shaded-db + ${spark.version} + provided + + + org.apache.parquet + parquet-hadoop + ${spark.version} + provided + + + org.apache.parquet + parquet-common + ${spark.version} + provided + + + org.apache.parquet + parquet-column + ${spark.version} + provided + + + org.apache.parquet + parquet-format + ${spark.version} + provided + + + org.apache.arrow + arrow-memory + ${spark.version} + provided + + + org.apache.arrow + arrow-vector + ${spark.version} + provided + + + org.apache.hadoop + hadoop-client + ${hadoop.client.version} + provided + + + org.apache.orc + orc-core + ${spark.version} + provided + + + org.apache.orc + orc-shims + ${spark.version} + provided + + + org.apache.orc + orc-mapreduce + ${spark.version} + provided + + + org.apache.hive + hive-storage-api + ${spark.version} + provided + + + com.google.protobuf + protobuf-java + ${spark.version} + provided + + + org.apache.spark + spark-common-utils_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-sql-api_${scala.binary.version} + ${spark.version} + provided + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-common-sources + generate-sources + + add-source + + + + ${project.basedir}/../common/src/main/scala + ${project.basedir}/../common/src/main/databricks/scala + + + + + + + net.alchim31.maven + scala-maven-plugin + + + org.apache.rat + apache-rat-plugin + + + + diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala new file mode 100644 index 00000000000..320485eb1ee --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala @@ -0,0 +1,464 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from CreateDeltaTableCommand.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import com.databricks.sql.transaction.tahoe._ +import com.databricks.sql.transaction.tahoe.actions.Metadata +import com.databricks.sql.transaction.tahoe.commands.{TableCreationModes, WriteIntoDelta} +import com.databricks.sql.transaction.tahoe.metering.DeltaLogging +import com.databricks.sql.transaction.tahoe.schema.SchemaUtils +import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf +import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.execution.command.{LeafRunnableCommand, RunnableCommand} +import org.apache.spark.sql.types.StructType + +/** + * Single entry point for all write or declaration operations for Delta tables accessed through + * the table name. + * + * @param table The table identifier for the Delta table + * @param existingTableOpt The existing table for the same identifier if exists + * @param mode The save mode when writing data. Relevant when the query is empty or set to Ignore + * with `CREATE TABLE IF NOT EXISTS`. + * @param query The query to commit into the Delta table if it exist. This can come from + * - CTAS + * - saveAsTable + */ +case class GpuCreateDeltaTableCommand( + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode = TableCreationModes.Create, + tableByPath: Boolean = false, + override val output: Seq[Attribute] = Nil)(@transient rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaLogging { + + override def otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override def run(sparkSession: SparkSession): Seq[Row] = { + val table = this.table + + assert(table.tableType != CatalogTableType.VIEW) + assert(table.identifier.database.isDefined, "Database should've been fixed at analysis") + // There is a subtle race condition here, where the table can be created by someone else + // while this command is running. Nothing we can do about that though :( + val tableExists = existingTableOpt.isDefined + if (mode == SaveMode.Ignore && tableExists) { + // Early exit on ignore + return Nil + } else if (mode == SaveMode.ErrorIfExists && tableExists) { + throw DeltaErrors.tableAlreadyExists(table) + } + + val tableWithLocation = if (tableExists) { + val existingTable = existingTableOpt.get + table.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + throw DeltaErrors.tableLocationMismatch(table, existingTable) + case _ => + } + table.copy( + storage = existingTable.storage, + tableType = existingTable.tableType) + } else if (table.storage.locationUri.isEmpty) { + // We are defining a new managed table + assert(table.tableType == CatalogTableType.MANAGED) + val loc = sparkSession.sessionState.catalog.defaultTablePath(table.identifier) + table.copy(storage = table.storage.copy(locationUri = Some(loc))) + } else { + // 1. We are defining a new external table + // 2. It's a managed table which already has the location populated. This can happen in DSV2 + // CTAS flow. + table + } + + val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED + val tableLocation = new Path(tableWithLocation.location) + val gpuDeltaLog = GpuDeltaLog.forTable(sparkSession, tableLocation, rapidsConf) + val hadoopConf = gpuDeltaLog.deltaLog.newDeltaHadoopConf() + val fs = tableLocation.getFileSystem(hadoopConf) + val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf) + var result: Seq[Row] = Nil + + recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.ddl.createTable") { + val txn = gpuDeltaLog.startTransaction() + val opStartTs = System.currentTimeMillis() + if (query.isDefined) { + // If the mode is Ignore or ErrorIfExists, the table must not exist, or we would return + // earlier. And the data should not exist either, to match the behavior of + // Ignore/ErrorIfExists mode. This means the table path should not exist or is empty. + if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { + assert(!tableExists) + // We may have failed a previous write. The retry should still succeed even if we have + // garbage data + if (txn.readVersion > -1 || !fs.exists(gpuDeltaLog.deltaLog.logPath)) { + assertPathEmpty(hadoopConf, tableWithLocation) + } + } + // We are either appending/overwriting with saveAsTable or creating a new table with CTAS or + // we are creating a table as part of a RunnableCommand + query.get match { + case writer: WriteIntoDelta => + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, writer.data.schema.asNullable) + } + val actions = writer.write(txn, sparkSession) + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + case cmd: RunnableCommand => + result = cmd.run(sparkSession) + case other => + // When using V1 APIs, the `other` plan is not yet optimized, therefore, it is safe + // to once again go through analysis + val data = Dataset.ofRows(sparkSession, other) + + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, other.schema.asNullable) + } + + val actions = WriteIntoDelta( + deltaLog = gpuDeltaLog.deltaLog, + mode = mode, + options, + partitionColumns = table.partitionColumnNames, + configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull), + data = data).write(txn, sparkSession) + + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + } + } else { + def createTransactionLogOrVerify(): Unit = { + if (isManagedTable) { + // When creating a managed table, the table path should not exist or is empty, or + // users would be surprised to see the data, or see the data directory being dropped + // after the table is dropped. + assertPathEmpty(hadoopConf, tableWithLocation) + } + + // This is either a new table, or, we never defined the schema of the table. While it is + // unexpected that `txn.metadata.schema` to be empty when txn.readVersion >= 0, we still + // guard against it, in case of checkpoint corruption bugs. + val noExistingMetadata = txn.readVersion == -1 || txn.metadata.schema.isEmpty + if (noExistingMetadata) { + assertTableSchemaDefined(fs, tableLocation, tableWithLocation, txn, sparkSession) + assertPathEmpty(hadoopConf, tableWithLocation) + // This is a user provided schema. + // Doesn't come from a query, Follow nullability invariants. + val newMetadata = getProvidedMetadata(tableWithLocation, table.schema.json) + txn.updateMetadataForNewTable(newMetadata) + + val op = getOperation(newMetadata, isManagedTable, None) + txn.commit(Nil, op) + } else { + verifyTableMetadata(txn, tableWithLocation) + } + } + // We are defining a table using the Create or Replace Table statements. + operation match { + case TableCreationModes.Create => + require(!tableExists, "Can't recreate a table when it exists") + createTransactionLogOrVerify() + + case TableCreationModes.CreateOrReplace if !tableExists => + // If the table doesn't exist, CREATE OR REPLACE must provide a schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + createTransactionLogOrVerify() + case _ => + // When the operation is a REPLACE or CREATE OR REPLACE, then the schema shouldn't be + // empty, since we'll use the entry to replace the schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + // We need to replace + replaceMetadataIfNecessary(txn, tableWithLocation, options, tableWithLocation.schema) + // Truncate the table + val operationTimestamp = System.currentTimeMillis() + val removes = txn.filterFiles().map(_.removeWithTimestamp(operationTimestamp)) + val op = getOperation(txn.metadata, isManagedTable, None) + txn.commit(removes, op) + } + } + + // We would have failed earlier on if we couldn't ignore the existence of the table + // In addition, we just might using saveAsTable to append to the table, so ignore the creation + // if it already exists. + // Note that someone may have dropped and recreated the table in a separate location in the + // meantime... Unfortunately we can't do anything there at the moment, because Hive sucks. + logInfo(s"Table is path-based table: $tableByPath. Update catalog with mode: $operation") + updateCatalog( + sparkSession, + tableWithLocation, + gpuDeltaLog.deltaLog.update(checkIfUpdatedSinceTs = Some(opStartTs)), + txn) + + result + } + } + + private def getProvidedMetadata(table: CatalogTable, schemaString: String): Metadata = { + Metadata( + description = table.comment.orNull, + schemaString = schemaString, + partitionColumns = table.partitionColumnNames, + configuration = table.properties, + createdTime = Some(System.currentTimeMillis())) + } + + private def assertPathEmpty( + hadoopConf: Configuration, + tableWithLocation: CatalogTable): Unit = { + val path = new Path(tableWithLocation.location) + val fs = path.getFileSystem(hadoopConf) + // Verify that the table location associated with CREATE TABLE doesn't have any data. Note that + // we intentionally diverge from this behavior w.r.t regular datasource tables (that silently + // overwrite any previous data) + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createTableWithNonEmptyLocation( + tableWithLocation.identifier.toString, + tableWithLocation.location.toString) + } + } + + private def assertTableSchemaDefined( + fs: FileSystem, + path: Path, + table: CatalogTable, + txn: OptimisticTransaction, + sparkSession: SparkSession): Unit = { + // If we allow creating an empty schema table and indeed the table is new, we just need to + // make sure: + // 1. txn.readVersion == -1 to read a new table + // 2. for external tables: path must either doesn't exist or is completely empty + val allowCreatingTableWithEmptySchema = sparkSession.sessionState + .conf.getConf(DeltaSQLConf.DELTA_ALLOW_CREATE_EMPTY_SCHEMA_TABLE) && txn.readVersion == -1 + + // Users did not specify the schema. We expect the schema exists in Delta. + if (table.schema.isEmpty) { + if (table.tableType == CatalogTableType.EXTERNAL) { + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createExternalTableWithoutLogException( + path, table.identifier.quotedString, sparkSession) + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createExternalTableWithoutSchemaException( + path, table.identifier.quotedString, sparkSession) + } + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createManagedTableWithoutSchemaException( + table.identifier.quotedString, sparkSession) + } + } + } + + /** + * Verify against our transaction metadata that the user specified the right metadata for the + * table. + */ + private def verifyTableMetadata( + txn: OptimisticTransaction, + tableDesc: CatalogTable): Unit = { + val existingMetadata = txn.metadata + val path = new Path(tableDesc.location) + + // The delta log already exists. If they give any configuration, we'll make sure it all matches. + // Otherwise we'll just go with the metadata already present in the log. + // The schema compatibility checks will be made in `WriteIntoDelta` for CreateTable + // with a query + if (txn.readVersion > -1) { + if (tableDesc.schema.nonEmpty) { + // We check exact alignment on create table if everything is provided + // However, if in column mapping mode, we can safely ignore the related metadata fields in + // existing metadata because new table desc will not have related metadata assigned yet + val differences = SchemaUtils.reportDifferences( + DeltaColumnMapping.dropColumnMappingMetadata(existingMetadata.schema), + tableDesc.schema) + if (differences.nonEmpty) { + throw DeltaErrors.createTableWithDifferentSchemaException( + path, tableDesc.schema, existingMetadata.schema, differences) + } + } + + // If schema is specified, we must make sure the partitioning matches, even the partitioning + // is not specified. + if (tableDesc.schema.nonEmpty && + tableDesc.partitionColumnNames != existingMetadata.partitionColumns) { + throw DeltaErrors.createTableWithDifferentPartitioningException( + path, tableDesc.partitionColumnNames, existingMetadata.partitionColumns) + } + + if (tableDesc.properties.nonEmpty && tableDesc.properties != existingMetadata.configuration) { + throw DeltaErrors.createTableWithDifferentPropertiesException( + path, tableDesc.properties, existingMetadata.configuration) + } + } + } + + /** + * Based on the table creation operation, and parameters, we can resolve to different operations. + * A lot of this is needed for legacy reasons in Databricks Runtime. + * @param metadata The table metadata, which we are creating or replacing + * @param isManagedTable Whether we are creating or replacing a managed table + * @param options Write options, if this was a CTAS/RTAS + */ + private def getOperation( + metadata: Metadata, + isManagedTable: Boolean, + options: Option[DeltaOptions]): DeltaOperations.Operation = operation match { + // This is legacy saveAsTable behavior in Databricks Runtime + case TableCreationModes.Create if existingTableOpt.isDefined && query.isDefined => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // DataSourceV2 table creation + // CREATE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Create => + DeltaOperations.CreateTable(metadata, isManagedTable, query.isDefined) + + // DataSourceV2 table replace + // REPLACE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Replace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = false, query.isDefined) + + // Legacy saveAsTable with Overwrite mode + case TableCreationModes.CreateOrReplace if options.exists(_.replaceWhere.isDefined) => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // New DataSourceV2 saveAsTable with overwrite mode behavior + case TableCreationModes.CreateOrReplace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = true, query.isDefined, + options.flatMap(_.userMetadata)) + } + + /** + * Similar to getOperation, here we disambiguate the catalog alterations we need to do based + * on the table operation, and whether we have reached here through legacy code or DataSourceV2 + * code paths. + */ + private def updateCatalog( + spark: SparkSession, + table: CatalogTable, + snapshot: Snapshot, + txn: OptimisticTransaction): Unit = { + val cleaned = cleanupTableDefinition(table, snapshot) + operation match { + case _ if tableByPath => // do nothing with the metastore if this is by path + case TableCreationModes.Create => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = existingTableOpt.isDefined, + validateLocation = false) + case TableCreationModes.Replace | TableCreationModes.CreateOrReplace + if existingTableOpt.isDefined => + spark.sessionState.catalog.alterTable(table) + case TableCreationModes.Replace => + val ident = Identifier.of(table.identifier.database.toArray, table.identifier.table) + throw DeltaErrors.cannotReplaceMissingTableException(ident) + case TableCreationModes.CreateOrReplace => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = false, + validateLocation = false) + } + } + + /** Clean up the information we pass on to store in the catalog. */ + private def cleanupTableDefinition(table: CatalogTable, snapshot: Snapshot): CatalogTable = { + // These actually have no effect on the usability of Delta, but feature flagging legacy + // behavior for now + val storageProps = if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) { + // Legacy behavior + table.storage + } else { + table.storage.copy(properties = Map.empty) + } + + table.copy( + schema = new StructType(), + properties = Map.empty, + partitionColumnNames = Nil, + // Remove write specific options when updating the catalog + storage = storageProps, + tracksPartitionsInCatalog = true) + } + + /** + * With DataFrameWriterV2, methods like `replace()` or `createOrReplace()` mean that the + * metadata of the table should be replaced. If overwriteSchema=false is provided with these + * methods, then we will verify that the metadata match exactly. + */ + private def replaceMetadataIfNecessary( + txn: OptimisticTransaction, + tableDesc: CatalogTable, + options: DeltaOptions, + schema: StructType): Unit = { + val isReplace = (operation == TableCreationModes.CreateOrReplace || + operation == TableCreationModes.Replace) + // If a user explicitly specifies not to overwrite the schema, during a replace, we should + // tell them that it's not supported + val dontOverwriteSchema = options.options.contains(DeltaOptions.OVERWRITE_SCHEMA_OPTION) && + !options.canOverwriteSchema + if (isReplace && dontOverwriteSchema) { + throw DeltaErrors.illegalUsageException(DeltaOptions.OVERWRITE_SCHEMA_OPTION, "replacing") + } + if (txn.readVersion > -1L && isReplace && !dontOverwriteSchema) { + // When a table already exists, and we're using the DataFrameWriterV2 API to replace + // or createOrReplace a table, we blindly overwrite the metadata. + txn.updateMetadataForNewTable(getProvidedMetadata(table, schema.json)) + } + } + + /** + * Horrible hack to differentiate between DataFrameWriterV1 and V2 so that we can decide + * what to do with table metadata. In DataFrameWriterV1, mode("overwrite").saveAsTable, + * behaves as a CreateOrReplace table, but we have asked for "overwriteSchema" as an + * explicit option to overwrite partitioning or schema information. With DataFrameWriterV2, + * the behavior asked for by the user is clearer: .createOrReplace(), which means that we + * should overwrite schema and/or partitioning. Therefore we have this hack. + */ + private def isV1Writer: Boolean = { + Thread.currentThread().getStackTrace.exists(_.toString.contains( + classOf[DataFrameWriter[_]].getCanonicalName + ".")) + } +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeleteCommand.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeleteCommand.scala new file mode 100644 index 00000000000..97ae9dc46ab --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeleteCommand.scala @@ -0,0 +1,369 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeleteCommand.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import com.databricks.sql.transaction.tahoe.{DeltaConfigs, DeltaLog, DeltaOperations, DeltaTableUtils, DeltaUDF, OptimisticTransaction} +import com.databricks.sql.transaction.tahoe.DeltaCommitTag._ +import com.databricks.sql.transaction.tahoe.RowTracking +import com.databricks.sql.transaction.tahoe.actions.{AddCDCFile, FileAction} +import com.databricks.sql.transaction.tahoe.commands.{DeleteCommandMetrics, DeleteMetric, DeltaCommand, DMLUtils} +import com.databricks.sql.transaction.tahoe.commands.MergeIntoCommandBase.totalBytesAndDistinctPartitionValues +import com.databricks.sql.transaction.tahoe.files.TahoeBatchFileIndex +import com.databricks.sql.transaction.tahoe.rapids.GpuDeleteCommand.{rewritingFilesMsg, FINDING_TOUCHED_FILES_MSG} +import com.nvidia.spark.rapids.delta.GpuDeltaMetricUpdateUDF + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, Not} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.functions.input_file_name +import org.apache.spark.sql.types.LongType + +/** + * GPU version of Delta Lake DeleteCommand. + * + * Performs a Delete based on the search condition + * + * Algorithm: + * 1) Scan all the files and determine which files have + * the rows that need to be deleted. + * 2) Traverse the affected files and rebuild the touched files. + * 3) Use the Delta protocol to atomically write the remaining rows to new files and remove + * the affected files that are identified in step 1. + */ +case class GpuDeleteCommand( + gpuDeltaLog: GpuDeltaLog, + target: LogicalPlan, + condition: Option[Expression]) + extends LeafRunnableCommand with DeltaCommand with DeleteCommandMetrics { + + override def innerChildren: Seq[QueryPlan[_]] = Seq(target) + + override val output: Seq[Attribute] = Seq(AttributeReference("num_affected_rows", LongType)()) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + + // DeleteCommandMetrics does not include deletion vector metrics, so add them here because + // the commit command needs to collect these metrics for inclusion in the delta log event + override lazy val metrics = createMetrics ++ Map( + "numDeletionVectorsAdded" -> SQLMetrics.createMetric(sc, "number of deletion vectors added."), + "numDeletionVectorsRemoved" -> + SQLMetrics.createMetric(sc, "number of deletion vectors removed.") + ) + + final override def run(sparkSession: SparkSession): Seq[Row] = { + val deltaLog = gpuDeltaLog.deltaLog + recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.dml.delete") { + gpuDeltaLog.withNewTransaction { txn => + DeltaLog.assertRemovable(txn.snapshot) + val deleteCommitTags = performDelete(sparkSession, deltaLog, txn) + val deleteActions = deleteCommitTags.actions + if (deleteActions.nonEmpty) { + txn.commitIfNeeded(deleteActions, DeltaOperations.Delete(condition.toSeq), + deleteCommitTags.stringTags) + } + } + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to + // this data source relation. + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target) + } + + // Adjust for deletes at partition boundaries. Deletes at partition boundaries is a metadata + // operation, therefore we don't actually have any information around how many rows were deleted + // While this info may exist in the file statistics, it's not guaranteed that we have these + // statistics. To avoid any performance regressions, we currently just return a -1 in such cases + if (metrics("numRemovedFiles").value > 0 && metrics("numDeletedRows").value == 0) { + Seq(Row(-1L)) + } else { + Seq(Row(metrics("numDeletedRows").value)) + } + } + + def performDelete( + sparkSession: SparkSession, + deltaLog: DeltaLog, + txn: OptimisticTransaction): DMLUtils.TaggedCommitData = { + import com.databricks.sql.transaction.tahoe.implicits._ + + var numRemovedFiles: Long = 0 + var numAddedFiles: Long = 0 + var numAddedChangeFiles: Long = 0 + var scanTimeMs: Long = 0 + var rewriteTimeMs: Long = 0 + var numBytesAdded: Long = 0 + var changeFileBytes: Long = 0 + var numBytesRemoved: Long = 0 + var numFilesBeforeSkipping: Long = 0 + var numBytesBeforeSkipping: Long = 0 + var numFilesAfterSkipping: Long = 0 + var numBytesAfterSkipping: Long = 0 + var numPartitionsAfterSkipping: Option[Long] = None + var numPartitionsRemovedFrom: Option[Long] = None + var numPartitionsAddedTo: Option[Long] = None + var numDeletedRows: Option[Long] = None + var numCopiedRows: Option[Long] = None + + val startTime = System.nanoTime() + val numFilesTotal = txn.snapshot.numOfFiles + + val deleteActions: Seq[FileAction] = condition match { + case None => + // Case 1: Delete the whole table if the condition is true + val allFiles = txn.filterFiles(Nil) + + numRemovedFiles = allFiles.size + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + val (numBytes, numPartitions) = totalBytesAndDistinctPartitionValues(allFiles) + numBytesRemoved = numBytes + numFilesBeforeSkipping = numRemovedFiles + numBytesBeforeSkipping = numBytes + numFilesAfterSkipping = numRemovedFiles + numBytesAfterSkipping = numBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numPartitions) + numPartitionsRemovedFrom = Some(numPartitions) + numPartitionsAddedTo = Some(0) + } + val operationTimestamp = System.currentTimeMillis() + allFiles.map(_.removeWithTimestamp(operationTimestamp)) + case Some(cond) => + val (metadataPredicates, otherPredicates) = + DeltaTableUtils.splitMetadataAndDataPredicates( + cond, txn.metadata.partitionColumns, sparkSession) + + numFilesBeforeSkipping = txn.snapshot.numOfFiles + numBytesBeforeSkipping = txn.snapshot.sizeInBytes + + if (otherPredicates.isEmpty) { + // Case 2: The condition can be evaluated using metadata only. + // Delete a set of files without the need of scanning any data files. + val operationTimestamp = System.currentTimeMillis() + val candidateFiles = txn.filterFiles(metadataPredicates) + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + numRemovedFiles = candidateFiles.size + numBytesRemoved = candidateFiles.map(_.size).sum + numFilesAfterSkipping = candidateFiles.size + val (numCandidateBytes, numCandidatePartitions) = + totalBytesAndDistinctPartitionValues(candidateFiles) + numBytesAfterSkipping = numCandidateBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numCandidatePartitions) + numPartitionsRemovedFrom = Some(numCandidatePartitions) + numPartitionsAddedTo = Some(0) + } + candidateFiles.map(_.removeWithTimestamp(operationTimestamp)) + } else { + // Case 3: Delete the rows based on the condition. + val candidateFiles = txn.filterFiles(metadataPredicates ++ otherPredicates) + + numFilesAfterSkipping = candidateFiles.size + val (numCandidateBytes, numCandidatePartitions) = + totalBytesAndDistinctPartitionValues(candidateFiles) + numBytesAfterSkipping = numCandidateBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numCandidatePartitions) + } + + val nameToAddFileMap = generateCandidateFileMap(deltaLog.dataPath, candidateFiles) + + val fileIndex = new TahoeBatchFileIndex( + sparkSession, "delete", candidateFiles, deltaLog, deltaLog.dataPath, txn.snapshot) + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex) + val data = Dataset.ofRows(sparkSession, newTarget) + val deletedRowCount = metrics("numDeletedRows") + val deletedRowUdf = DeltaUDF.boolean { + new GpuDeltaMetricUpdateUDF(deletedRowCount) + }.asNondeterministic() + val filesToRewrite = + withStatusCode("DELTA", FINDING_TOUCHED_FILES_MSG) { + if (candidateFiles.isEmpty) { + Array.empty[String] + } else { + data.filter(new Column(cond)) + .select(input_file_name()) + .filter(deletedRowUdf()) + .distinct() + .as[String] + .collect() + } + } + + numRemovedFiles = filesToRewrite.length + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + if (filesToRewrite.isEmpty) { + // Case 3.1: no row matches and no delete will be triggered + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsRemovedFrom = Some(0) + numPartitionsAddedTo = Some(0) + } + Nil + } else { + // Case 3.2: some files need an update to remove the deleted files + // Do the second pass and just read the affected files + val baseRelation = buildBaseRelation( + sparkSession, txn, "delete", deltaLog.dataPath, filesToRewrite, nameToAddFileMap) + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location) + val targetDF = Dataset.ofRows(sparkSession, newTarget) + val filterCond = Not(EqualNullSafe(cond, Literal.TrueLiteral)) + val rewrittenActions = rewriteFiles(txn, targetDF, filterCond, filesToRewrite.length) + val (changeFiles, rewrittenFiles) = rewrittenActions + .partition(_.isInstanceOf[AddCDCFile]) + numAddedFiles = rewrittenFiles.size + val removedFiles = filesToRewrite.map(f => + getTouchedFile(deltaLog.dataPath, f, nameToAddFileMap)) + val (removedBytes, removedPartitions) = + totalBytesAndDistinctPartitionValues(removedFiles) + numBytesRemoved = removedBytes + val (rewrittenBytes, rewrittenPartitions) = + totalBytesAndDistinctPartitionValues(rewrittenFiles) + numBytesAdded = rewrittenBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsRemovedFrom = Some(removedPartitions) + numPartitionsAddedTo = Some(rewrittenPartitions) + } + numAddedChangeFiles = changeFiles.size + changeFileBytes = changeFiles.collect { case f: AddCDCFile => f.size }.sum + rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs + numDeletedRows = Some(metrics("numDeletedRows").value) + numCopiedRows = Some(metrics("numTouchedRows").value - metrics("numDeletedRows").value) + + val operationTimestamp = System.currentTimeMillis() + removeFilesFromPaths(deltaLog, nameToAddFileMap, filesToRewrite, + operationTimestamp) ++ rewrittenActions + } + } + } + metrics("numRemovedFiles").set(numRemovedFiles) + metrics("numAddedFiles").set(numAddedFiles) + val executionTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + metrics("executionTimeMs").set(executionTimeMs) + metrics("scanTimeMs").set(scanTimeMs) + metrics("rewriteTimeMs").set(rewriteTimeMs) + metrics("numAddedChangeFiles").set(numAddedChangeFiles) + metrics("changeFileBytes").set(changeFileBytes) + metrics("numAddedBytes").set(numBytesAdded) + metrics("numRemovedBytes").set(numBytesRemoved) + metrics("numFilesBeforeSkipping").set(numFilesBeforeSkipping) + metrics("numBytesBeforeSkipping").set(numBytesBeforeSkipping) + metrics("numFilesAfterSkipping").set(numFilesAfterSkipping) + metrics("numBytesAfterSkipping").set(numBytesAfterSkipping) + numPartitionsAfterSkipping.foreach(metrics("numPartitionsAfterSkipping").set) + numPartitionsAddedTo.foreach(metrics("numPartitionsAddedTo").set) + numPartitionsRemovedFrom.foreach(metrics("numPartitionsRemovedFrom").set) + numCopiedRows.foreach(metrics("numCopiedRows").set) + metrics("numDeletionVectorsAdded").set(0) + metrics("numDeletionVectorsRemoved").set(0) + txn.registerSQLMetrics(sparkSession, metrics) + // This is needed to make the SQL metrics visible in the Spark UI + val executionId = sparkSession.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkSession.sparkContext, executionId, metrics.values.toSeq) + + recordDeltaEvent( + deltaLog, + "delta.dml.delete.stats", + data = DeleteMetric( + condition = condition.map(_.sql).getOrElse("true"), + numFilesTotal, + numFilesAfterSkipping, + numAddedFiles, + numRemovedFiles, + numAddedFiles, + numAddedChangeFiles = numAddedChangeFiles, + numFilesBeforeSkipping, + numBytesBeforeSkipping, + numFilesAfterSkipping, + numBytesAfterSkipping, + numPartitionsAfterSkipping, + numPartitionsAddedTo, + numPartitionsRemovedFrom, + numCopiedRows, + numDeletedRows, + numBytesAdded, + numBytesRemoved, + changeFileBytes = changeFileBytes, + scanTimeMs, + rewriteTimeMs) + ) + + DMLUtils.TaggedCommitData(deleteActions) + .withTag(PreservedRowTrackingTag, RowTracking.isEnabled(txn.protocol, txn.metadata)) + .withTag(NoRowsCopiedTag, metrics("numCopiedRows").value == 0) + } + + /** + * Returns the list of `AddFile`s and `AddCDCFile`s that have been re-written. + */ + private def rewriteFiles( + txn: OptimisticTransaction, + baseData: DataFrame, + filterCondition: Expression, + numFilesToRewrite: Long): Seq[FileAction] = { + val shouldWriteCdc = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(txn.metadata) + + // number of total rows that we have seen / are either copying or deleting (sum of both). + val numTouchedRows = metrics("numTouchedRows") + val numTouchedRowsUdf = DeltaUDF.boolean { + new GpuDeltaMetricUpdateUDF(numTouchedRows) + }.asNondeterministic() + + withStatusCode( + "DELTA", rewritingFilesMsg(numFilesToRewrite)) { + val dfToWrite = if (shouldWriteCdc) { + import com.databricks.sql.transaction.tahoe.commands.cdc.CDCReader._ + // The logic here ends up being surprisingly elegant, with all source rows ending up in + // the output. Recall that we flipped the user-provided delete condition earlier, before the + // call to `rewriteFiles`. All rows which match this latest `filterCondition` are retained + // as table data, while all rows which don't match are removed from the rewritten table data + // but do get included in the output as CDC events. + baseData + .filter(numTouchedRowsUdf()) + .withColumn( + CDC_TYPE_COLUMN_NAME, + new Column(If(filterCondition, CDC_TYPE_NOT_CDC, CDC_TYPE_DELETE)) + ) + } else { + baseData + .filter(numTouchedRowsUdf()) + .filter(new Column(filterCondition)) + } + + txn.writeFiles(dfToWrite) + } + } +} + +object GpuDeleteCommand { + val FINDING_TOUCHED_FILES_MSG: String = "Finding files to rewrite for DELETE operation" + + def rewritingFilesMsg(numFilesToRewrite: Long): String = + s"Rewriting $numFilesToRewrite files for DELETE operation" +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala new file mode 100644 index 00000000000..8c62f4f3fd7 --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeltaDataSource.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import java.util + +import com.databricks.sql.transaction.tahoe.{DeltaConfigs, DeltaErrors} +import com.databricks.sql.transaction.tahoe.commands.TableCreationModes +import com.databricks.sql.transaction.tahoe.metering.DeltaLogging +import com.databricks.sql.transaction.tahoe.sources.DeltaSourceUtils +import com.nvidia.spark.rapids.RapidsConf + +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, Table} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.types.StructType + +class GpuDeltaCatalog( + override val cpuCatalog: StagingTableCatalog, + override val rapidsConf: RapidsConf) + extends GpuDeltaCatalogBase with SupportsPathIdentifier with DeltaLogging { + + override protected def buildGpuCreateDeltaTableCommand( + rapidsConf: RapidsConf, + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode, + tableByPath: Boolean): LeafRunnableCommand = { + GpuCreateDeltaTableCommand( + table, + existingTableOpt, + mode, + query, + operation, + tableByPath = tableByPath + )(rapidsConf) + } + + override protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = { + // If this is a path identifier, we cannot return an existing CatalogTable. The Create command + // will check the file system itself + if (isPathIdentifier(table)) return None + val tableExists = catalog.tableExists(table) + if (tableExists) { + val oldTable = catalog.getTableMetadata(table) + if (oldTable.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s"$table is a view. You may not write data into a view.") + } + if (!DeltaSourceUtils.isDeltaTable(oldTable.provider)) { + throw new AnalysisException(s"$table is not a Delta table. Please drop this " + + "table first if you would like to recreate it with Delta Lake.") + } + Some(oldTable) + } else { + None + } + } + + override protected def verifyTableAndSolidify( + tableDesc: CatalogTable, + query: Option[LogicalPlan]): CatalogTable = { + + if (tableDesc.bucketSpec.isDefined) { + throw DeltaErrors.operationNotSupportedException("Bucketing", tableDesc.identifier) + } + + val schema = query.map { plan => + assert(tableDesc.schema.isEmpty, "Can't specify table schema in CTAS.") + plan.schema.asNullable + }.getOrElse(tableDesc.schema) + + PartitioningUtils.validatePartitionColumn( + schema, + tableDesc.partitionColumnNames, + caseSensitive = false) // Delta is case insensitive + + val validatedConfigurations = DeltaConfigs.validateConfigurations(tableDesc.properties) + + val db = tableDesc.identifier.database.getOrElse(catalog.getCurrentDatabase) + val tableIdentWithDB = tableDesc.identifier.copy(database = Some(db)) + tableDesc.copy( + identifier = tableIdentWithDB, + schema = schema, + properties = validatedConfigurations) + } + + override protected def createGpuStagedDeltaTableV2( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode): StagedTable = { + new GpuStagedDeltaTableV2WithLogging(ident, schema, partitions, properties, operation) + } + + override def loadTable(ident: Identifier, timestamp: Long): Table = { + cpuCatalog.loadTable(ident, timestamp) + } + + override def loadTable(ident: Identifier, version: String): Table = { + cpuCatalog.loadTable(ident, version) + } + + /** + * Creates a Delta table using GPU for writing the data + * + * @param ident The identifier of the table + * @param schema The schema of the table + * @param partitions The partition transforms for the table + * @param allTableProperties The table properties that configure the behavior of the table or + * provide information about the table + * @param writeOptions Options specific to the write during table creation or replacement + * @param sourceQuery A query if this CREATE request came from a CTAS or RTAS + * @param operation The specific table creation mode, whether this is a + * Create/Replace/Create or Replace + */ + override def createDeltaTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + allTableProperties: util.Map[String, String], + writeOptions: Map[String, String], + sourceQuery: Option[DataFrame], + operation: TableCreationModes.CreationMode + ): Table = recordFrameProfile( + "DeltaCatalog", "createDeltaTable") { + super.createDeltaTable( + ident, + schema, + partitions, + allTableProperties, + writeOptions, + sourceQuery, + operation) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = + recordFrameProfile("DeltaCatalog", "createTable") { + super.createTable(ident, schema, partitions, properties) + } + + override def stageReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageReplace") { + super.stageReplace(ident, schema, partitions, properties) + } + + override def stageCreateOrReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageCreateOrReplace") { + super.stageCreateOrReplace(ident, schema, partitions, properties) + } + + override def stageCreate( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageCreate") { + super.stageCreate(ident, schema, partitions, properties) + } + + /** + * A staged Delta table, which creates a HiveMetaStore entry and appends data if this was a + * CTAS/RTAS command. We have a ugly way of using this API right now, but it's the best way to + * maintain old behavior compatibility between Databricks Runtime and OSS Delta Lake. + */ + protected class GpuStagedDeltaTableV2WithLogging( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode) + extends GpuStagedDeltaTableV2(ident, schema, partitions, properties, operation) { + + override def commitStagedChanges(): Unit = recordFrameProfile( + "DeltaCatalog", "commitStagedChanges") { + super.commitStagedChanges() + } + } +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDoAutoCompaction.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDoAutoCompaction.scala new file mode 100644 index 00000000000..9726511ad44 --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDoAutoCompaction.scala @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DoAutoCompaction.scala + * from https://github.com/delta-io/delta/pull/1156 + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import com.databricks.sql.transaction.tahoe._ +import com.databricks.sql.transaction.tahoe.actions.Action +import com.databricks.sql.transaction.tahoe.hooks.PostCommitHook +import com.databricks.sql.transaction.tahoe.metering.DeltaLogging + +import org.apache.spark.sql.SparkSession + +object GpuDoAutoCompaction extends PostCommitHook + with DeltaLogging + with Serializable { + override val name: String = "Triggers compaction if necessary" + + override def run(spark: SparkSession, + txn: OptimisticTransactionImpl, + committedVersion: Long, + postCommitSnapshot: Snapshot, + committedActions: Seq[Action]): Unit = { + val gpuTxn = txn.asInstanceOf[GpuOptimisticTransaction] + val newTxn = new GpuDeltaLog(gpuTxn.deltaLog, gpuTxn.rapidsConf).startTransaction() + // Note: The Databricks AutoCompact PostCommitHook cannot be used here + // (with a GpuOptimisticTransaction). It appears that AutoCompact creates a new transaction, + // thereby circumventing GpuOptimisticTransaction (which intercepts Parquet writes + // to go through the GPU). + new GpuOptimizeExecutor(spark, newTxn, Seq.empty, Seq.empty, committedActions).optimize() + } + + override def handleError(error: Throwable, version: Long): Unit = + throw DeltaErrors.postCommitHookFailedException(this, version, name, error) +} \ No newline at end of file diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuMergeIntoCommand.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuMergeIntoCommand.scala new file mode 100644 index 00000000000..a0a4e4263c3 --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuMergeIntoCommand.scala @@ -0,0 +1,1187 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from MergeIntoCommand.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import com.databricks.sql.transaction.tahoe._ +import com.databricks.sql.transaction.tahoe.DeltaOperations.MergePredicate +import com.databricks.sql.transaction.tahoe.actions.{AddCDCFile, AddFile, FileAction} +import com.databricks.sql.transaction.tahoe.commands.DeltaCommand +import com.databricks.sql.transaction.tahoe.schema.ImplicitMetadataOperation +import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf +import com.databricks.sql.transaction.tahoe.util.{AnalysisHelper, SetAccumulator} +import com.fasterxml.jackson.databind.annotation.JsonDeserialize +import com.nvidia.spark.rapids.{BaseExprMeta, GpuOverrides, RapidsConf} +import com.nvidia.spark.rapids.delta._ + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BasePredicate, Expression, Literal, NamedExpression, PredicateHelper, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.plans.logical.{DeltaMergeIntoClause, DeltaMergeIntoMatchedClause, DeltaMergeIntoMatchedDeleteClause, DeltaMergeIntoMatchedUpdateClause, DeltaMergeIntoNotMatchedBySourceClause, DeltaMergeIntoNotMatchedClause, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataTypes, LongType, StringType, StructType} +case class GpuMergeDataSizes( + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + rows: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + files: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + bytes: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + partitions: Option[Long] = None) + +/** + * Represents the state of a single merge clause: + * - merge clause's (optional) predicate + * - action type (insert, update, delete) + * - action's expressions + */ +case class GpuMergeClauseStats( + condition: Option[String], + actionType: String, + actionExpr: Seq[String]) + +object GpuMergeClauseStats { + def apply(mergeClause: DeltaMergeIntoClause): GpuMergeClauseStats = { + GpuMergeClauseStats( + condition = mergeClause.condition.map(_.sql), + mergeClause.clauseType.toLowerCase(), + actionExpr = mergeClause.actions.map(_.sql)) + } +} + +/** State for a GPU merge operation */ +case class GpuMergeStats( + // Merge condition expression + conditionExpr: String, + + // Expressions used in old MERGE stats, now always Null + updateConditionExpr: String, + updateExprs: Seq[String], + insertConditionExpr: String, + insertExprs: Seq[String], + deleteConditionExpr: String, + + // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED + matchedStats: Seq[GpuMergeClauseStats], + notMatchedStats: Seq[GpuMergeClauseStats], + + // Data sizes of source and target at different stages of processing + source: GpuMergeDataSizes, + targetBeforeSkipping: GpuMergeDataSizes, + targetAfterSkipping: GpuMergeDataSizes, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + sourceRowsInSecondScan: Option[Long], + + // Data change sizes + targetFilesRemoved: Long, + targetFilesAdded: Long, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetChangeFilesAdded: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetChangeFileBytes: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetBytesRemoved: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetBytesAdded: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetPartitionsRemovedFrom: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetPartitionsAddedTo: Option[Long], + targetRowsCopied: Long, + targetRowsUpdated: Long, + targetRowsInserted: Long, + targetRowsDeleted: Long +) + +object GpuMergeStats { + + def fromMergeSQLMetrics( + metrics: Map[String, SQLMetric], + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], + isPartitioned: Boolean): GpuMergeStats = { + + def metricValueIfPartitioned(metricName: String): Option[Long] = { + if (isPartitioned) Some(metrics(metricName).value) else None + } + + GpuMergeStats( + // Merge condition expression + conditionExpr = condition.sql, + + // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED + matchedStats = matchedClauses.map(GpuMergeClauseStats(_)), + notMatchedStats = notMatchedClauses.map(GpuMergeClauseStats(_)), + + // Data sizes of source and target at different stages of processing + source = GpuMergeDataSizes(rows = Some(metrics("numSourceRows").value)), + targetBeforeSkipping = + GpuMergeDataSizes( + files = Some(metrics("numTargetFilesBeforeSkipping").value), + bytes = Some(metrics("numTargetBytesBeforeSkipping").value)), + targetAfterSkipping = + GpuMergeDataSizes( + files = Some(metrics("numTargetFilesAfterSkipping").value), + bytes = Some(metrics("numTargetBytesAfterSkipping").value), + partitions = metricValueIfPartitioned("numTargetPartitionsAfterSkipping")), + sourceRowsInSecondScan = + metrics.get("numSourceRowsInSecondScan").map(_.value).filter(_ >= 0), + + // Data change sizes + targetFilesAdded = metrics("numTargetFilesAdded").value, + targetChangeFilesAdded = metrics.get("numTargetChangeFilesAdded").map(_.value), + targetChangeFileBytes = metrics.get("numTargetChangeFileBytes").map(_.value), + targetFilesRemoved = metrics("numTargetFilesRemoved").value, + targetBytesAdded = Some(metrics("numTargetBytesAdded").value), + targetBytesRemoved = Some(metrics("numTargetBytesRemoved").value), + targetPartitionsRemovedFrom = metricValueIfPartitioned("numTargetPartitionsRemovedFrom"), + targetPartitionsAddedTo = metricValueIfPartitioned("numTargetPartitionsAddedTo"), + targetRowsCopied = metrics("numTargetRowsCopied").value, + targetRowsUpdated = metrics("numTargetRowsUpdated").value, + targetRowsInserted = metrics("numTargetRowsInserted").value, + targetRowsDeleted = metrics("numTargetRowsDeleted").value, + + // Deprecated fields + updateConditionExpr = null, + updateExprs = null, + insertConditionExpr = null, + insertExprs = null, + deleteConditionExpr = null) + } +} + +/** + * GPU version of Delta Lake's MergeIntoCommand. + * + * Performs a merge of a source query/table into a Delta table. + * + * Issues an error message when the ON search_condition of the MERGE statement can match + * a single row from the target table with multiple rows of the source table-reference. + * + * Algorithm: + * + * Phase 1: Find the input files in target that are touched by the rows that satisfy + * the condition and verify that no two source rows match with the same target row. + * This is implemented as an inner-join using the given condition. See [[findTouchedFiles]] + * for more details. + * + * Phase 2: Read the touched files again and write new files with updated and/or inserted rows. + * + * Phase 3: Use the Delta protocol to atomically remove the touched files and add the new files. + * + * @param source Source data to merge from + * @param target Target table to merge into + * @param gpuDeltaLog Delta log to use + * @param condition Condition for a source row to match with a target row + * @param matchedClauses All info related to matched clauses. + * @param notMatchedClauses All info related to not matched clause. + * @param migratedSchema The final schema of the target - may be changed by schema evolution. + */ +case class GpuMergeIntoCommand( + @transient source: LogicalPlan, + @transient target: LogicalPlan, + @transient gpuDeltaLog: GpuDeltaLog, + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], + notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause], + migratedSchema: Option[StructType])( + @transient val rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaCommand with PredicateHelper with AnalysisHelper with ImplicitMetadataOperation { + + import GpuMergeIntoCommand._ + + import SQLMetrics._ + import com.databricks.sql.transaction.tahoe.commands.cdc.CDCReader._ + + override val otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) + override val canOverwriteSchema: Boolean = false + + override val output: Seq[Attribute] = Seq( + AttributeReference("num_affected_rows", LongType)(), + AttributeReference("num_updated_rows", LongType)(), + AttributeReference("num_deleted_rows", LongType)(), + AttributeReference("num_inserted_rows", LongType)()) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + @transient private lazy val targetDeltaLog: DeltaLog = gpuDeltaLog.deltaLog + /** + * Map to get target output attributes by name. + * The case sensitivity of the map is set accordingly to Spark configuration. + */ + @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = { + val attrMap: Map[String, Attribute] = target + .outputSet.view + .map(attr => attr.name -> attr).toMap + if (conf.caseSensitiveAnalysis) { + attrMap + } else { + CaseInsensitiveMap(attrMap) + } + } + + /** Whether this merge statement has only a single insert (NOT MATCHED) clause. */ + private def isSingleInsertOnly: Boolean = matchedClauses.isEmpty && notMatchedClauses.length == 1 + /** Whether this merge statement has only MATCHED clauses. */ + private def isMatchedOnly: Boolean = notMatchedClauses.isEmpty && matchedClauses.nonEmpty + + // We over-count numTargetRowsDeleted when there are multiple matches; + // this is the amount of the overcount, so we can subtract it to get a correct final metric. + private var multipleMatchDeleteOnlyOvercount: Option[Long] = None + + override lazy val metrics = Map[String, SQLMetric]( + "numSourceRows" -> createMetric(sc, "number of source rows"), + "numSourceRowsInSecondScan" -> + createMetric(sc, "number of source rows (during repeated scan)"), + "numTargetRowsCopied" -> createMetric(sc, "number of target rows rewritten unmodified"), + "numTargetRowsInserted" -> createMetric(sc, "number of inserted rows"), + "numTargetRowsUpdated" -> createMetric(sc, "number of updated rows"), + "numTargetRowsDeleted" -> createMetric(sc, "number of deleted rows"), + "numTargetFilesBeforeSkipping" -> createMetric(sc, "number of target files before skipping"), + "numTargetFilesAfterSkipping" -> createMetric(sc, "number of target files after skipping"), + "numTargetFilesRemoved" -> createMetric(sc, "number of files removed to target"), + "numTargetFilesAdded" -> createMetric(sc, "number of files added to target"), + "numTargetChangeFilesAdded" -> + createMetric(sc, "number of change data capture files generated"), + "numTargetChangeFileBytes" -> + createMetric(sc, "total size of change data capture files generated"), + "numTargetBytesBeforeSkipping" -> createMetric(sc, "number of target bytes before skipping"), + "numTargetBytesAfterSkipping" -> createMetric(sc, "number of target bytes after skipping"), + "numTargetBytesRemoved" -> createMetric(sc, "number of target bytes removed"), + "numTargetBytesAdded" -> createMetric(sc, "number of target bytes added"), + "numTargetPartitionsAfterSkipping" -> + createMetric(sc, "number of target partitions after skipping"), + "numTargetPartitionsRemovedFrom" -> + createMetric(sc, "number of target partitions from which files were removed"), + "numTargetPartitionsAddedTo" -> + createMetric(sc, "number of target partitions to which files were added"), + "executionTimeMs" -> + createMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createMetric(sc, "time taken to rewrite the matched files")) + + override def run(spark: SparkSession): Seq[Row] = { + recordDeltaOperation(targetDeltaLog, "delta.dml.merge") { + val startTime = System.nanoTime() + gpuDeltaLog.withNewTransaction { deltaTxn => + if (target.schema.size != deltaTxn.metadata.schema.size) { + throw DeltaErrors.schemaChangedSinceAnalysis( + atAnalysis = target.schema, latestSchema = deltaTxn.metadata.schema) + } + + if (canMergeSchema) { + updateMetadata( + spark, deltaTxn, migratedSchema.getOrElse(target.schema), + deltaTxn.metadata.partitionColumns, deltaTxn.metadata.configuration, + isOverwriteMode = false, rearrangeOnly = false) + } + + val deltaActions = { + if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) { + writeInsertsOnlyWhenNoMatchedClauses(spark, deltaTxn) + } else { + val filesToRewrite = findTouchedFiles(spark, deltaTxn) + val newWrittenFiles = withStatusCode("DELTA", "Writing merged data") { + writeAllChanges(spark, deltaTxn, filesToRewrite) + } + filesToRewrite.map(_.remove) ++ newWrittenFiles + } + } + + // Metrics should be recorded before commit (where they are written to delta logs). + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + deltaTxn.registerSQLMetrics(spark, metrics) + + // This is a best-effort sanity check. + if (metrics("numSourceRowsInSecondScan").value >= 0 && + metrics("numSourceRows").value != metrics("numSourceRowsInSecondScan").value) { + log.warn(s"Merge source has ${metrics("numSourceRows").value} rows in initial scan but " + + s"${metrics("numSourceRowsInSecondScan").value} rows in second scan") + if (conf.getConf(DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED)) { + throw DeltaErrors.sourceNotDeterministicInMergeException(spark) + } + } + + deltaTxn.commit( + deltaActions, + DeltaOperations.Merge( + Option(condition), + matchedClauses.map(DeltaOperations.MergePredicate(_)), + notMatchedClauses.map(DeltaOperations.MergePredicate(_)), + // We do not support notMatchedBySourcePredicates yet and fall back to CPU + // See https://github.com/NVIDIA/spark-rapids/issues/8415 + notMatchedBySourcePredicates = Seq.empty[MergePredicate] + )) + + // Record metrics + val stats = GpuMergeStats.fromMergeSQLMetrics( + metrics, condition, matchedClauses, notMatchedClauses, + deltaTxn.metadata.partitionColumns.nonEmpty) + recordDeltaEvent(targetDeltaLog, "delta.dml.merge.stats", data = stats) + + } + spark.sharedState.cacheManager.recacheByPlan(spark, target) + } + // This is needed to make the SQL metrics visible in the Spark UI. Also this needs + // to be outside the recordMergeOperation because this method will update some metric. + val executionId = spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(spark.sparkContext, executionId, metrics.values.toSeq) + Seq(Row(metrics("numTargetRowsUpdated").value + metrics("numTargetRowsDeleted").value + + metrics("numTargetRowsInserted").value, metrics("numTargetRowsUpdated").value, + metrics("numTargetRowsDeleted").value, metrics("numTargetRowsInserted").value)) + } + + /** + * Find the target table files that contain the rows that satisfy the merge condition. This is + * implemented as an inner-join between the source query/table and the target table using + * the merge condition. + */ + private def findTouchedFiles( + spark: SparkSession, + deltaTxn: OptimisticTransaction + ): Seq[AddFile] = recordMergeOperation(sqlMetricName = "scanTimeMs") { + + // Accumulator to collect all the distinct touched files + val touchedFilesAccum = new SetAccumulator[String]() + spark.sparkContext.register(touchedFilesAccum, TOUCHED_FILES_ACCUM_NAME) + + // UDFs to records touched files names and add them to the accumulator + val recordTouchedFileName = udf(new GpuDeltaRecordTouchedFileNameUDF(touchedFilesAccum)) + .asNondeterministic() + + // Skip data based on the merge condition + val targetOnlyPredicates = + splitConjunctivePredicates(condition).filter(_.references.subsetOf(target.outputSet)) + val dataSkippedFiles = deltaTxn.filterFiles(targetOnlyPredicates) + + // UDF to increment metrics + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows") + val sourceDF = Dataset.ofRows(spark, source) + .filter(new Column(incrSourceRowCountExpr)) + + // Apply inner join to between source and target using the merge condition to find matches + // In addition, we attach two columns + // - a monotonically increasing row id for target rows to later identify whether the same + // target row is modified by multiple user or not + // - the target file name the row is from to later identify the files touched by matched rows + val targetDF = Dataset.ofRows(spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles)) + .withColumn(ROW_ID_COL, monotonically_increasing_id()) + .withColumn(FILE_NAME_COL, input_file_name()) + val joinToFindTouchedFiles = sourceDF.join(targetDF, new Column(condition), "inner") + + // Process the matches from the inner join to record touched files and find multiple matches + val collectTouchedFiles = joinToFindTouchedFiles + .select(col(ROW_ID_COL), recordTouchedFileName(col(FILE_NAME_COL)).as("one")) + + // Calculate frequency of matches per source row + val matchedRowCounts = collectTouchedFiles.groupBy(ROW_ID_COL).agg(sum("one").as("count")) + + // Get multiple matches and simultaneously collect (using touchedFilesAccum) the file names + // multipleMatchCount = # of target rows with more than 1 matching source row (duplicate match) + // multipleMatchSum = total # of duplicate matched rows + import spark.implicits._ + val (multipleMatchCount, multipleMatchSum) = matchedRowCounts + .filter("count > 1") + .select(coalesce(count("*"), lit(0)), coalesce(sum("count"), lit(0))) + .as[(Long, Long)] + .collect() + .head + + val hasMultipleMatches = multipleMatchCount > 0 + + // Throw error if multiple matches are ambiguous or cannot be computed correctly. + val canBeComputedUnambiguously = { + // Multiple matches are not ambiguous when there is only one unconditional delete as + // all the matched row pairs in the 2nd join in `writeAllChanges` will get deleted. + val isUnconditionalDelete = matchedClauses.headOption match { + case Some(DeltaMergeIntoMatchedDeleteClause(None)) => true + case _ => false + } + matchedClauses.size == 1 && isUnconditionalDelete + } + + if (hasMultipleMatches && !canBeComputedUnambiguously) { + throw DeltaErrors.multipleSourceRowMatchingTargetRowInMergeException(spark) + } + + if (hasMultipleMatches) { + // This is only allowed for delete-only queries. + // This query will count the duplicates for numTargetRowsDeleted in Job 2, + // because we count matches after the join and not just the target rows. + // We have to compensate for this by subtracting the duplicates later, + // so we need to record them here. + val duplicateCount = multipleMatchSum - multipleMatchCount + multipleMatchDeleteOnlyOvercount = Some(duplicateCount) + } + + // Get the AddFiles using the touched file names. + val touchedFileNames = touchedFilesAccum.value.iterator().asScala.toSeq + logTrace(s"findTouchedFiles: matched files:\n\t${touchedFileNames.mkString("\n\t")}") + + val nameToAddFileMap = generateCandidateFileMap(targetDeltaLog.dataPath, dataSkippedFiles) + val touchedAddFiles = touchedFileNames.map(f => + getTouchedFile(targetDeltaLog.dataPath, f, nameToAddFileMap)) + + // When the target table is empty, and the optimizer optimized away the join entirely + // numSourceRows will be incorrectly 0. We need to scan the source table once to get the correct + // metric here. + if (metrics("numSourceRows").value == 0 && + (dataSkippedFiles.isEmpty || targetDF.take(1).isEmpty)) { + val numSourceRows = sourceDF.count() + metrics("numSourceRows").set(numSourceRows) + } + + // Update metrics + metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles + metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + val (removedBytes, removedPartitions) = totalBytesAndDistinctPartitionValues(touchedAddFiles) + metrics("numTargetFilesRemoved") += touchedAddFiles.size + metrics("numTargetBytesRemoved") += removedBytes + metrics("numTargetPartitionsRemovedFrom") += removedPartitions + touchedAddFiles + } + + /** + * This is an optimization of the case when there is no update clause for the merge. + * We perform an left anti join on the source data to find the rows to be inserted. + * + * This will currently only optimize for the case when there is a _single_ notMatchedClause. + */ + private def writeInsertsOnlyWhenNoMatchedClauses( + spark: SparkSession, + deltaTxn: OptimisticTransaction + ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + + // UDFs to update metrics + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows") + val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted") + + val outputColNames = getTargetOutputCols(deltaTxn).map(_.name) + // we use head here since we know there is only a single notMatchedClause + val outputExprs = notMatchedClauses.head.resolvedActions.map(_.expr) + val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) => + new Column(Alias(expr, name)()) + } + + // source DataFrame + val sourceDF = Dataset.ofRows(spark, source) + .filter(new Column(incrSourceRowCountExpr)) + .filter(new Column(notMatchedClauses.head.condition.getOrElse(Literal.TrueLiteral))) + + // Skip data based on the merge condition + val conjunctivePredicates = splitConjunctivePredicates(condition) + val targetOnlyPredicates = + conjunctivePredicates.filter(_.references.subsetOf(target.outputSet)) + val dataSkippedFiles = deltaTxn.filterFiles(targetOnlyPredicates) + + // target DataFrame + val targetDF = Dataset.ofRows( + spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles)) + + val insertDf = sourceDF.join(targetDF, new Column(condition), "leftanti") + .select(outputCols: _*) + .filter(new Column(incrInsertedCountExpr)) + + val newFiles = deltaTxn + .writeFiles(repartitionIfNeeded(spark, insertDf, deltaTxn.metadata.partitionColumns)) + + // Update metrics + metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles + metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + metrics("numTargetFilesRemoved") += 0 + metrics("numTargetBytesRemoved") += 0 + metrics("numTargetPartitionsRemovedFrom") += 0 + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + metrics("numTargetBytesAdded") += addedBytes + metrics("numTargetPartitionsAddedTo") += addedPartitions + newFiles + } + + /** + * Write new files by reading the touched files and updating/inserting data using the source + * query/table. This is implemented using a full|right-outer-join using the merge condition. + * + * Note that unlike the insert-only code paths with just one control column INCR_ROW_COUNT_COL, + * this method has two additional control columns ROW_DROPPED_COL for dropping deleted rows and + * CDC_TYPE_COL_NAME used for handling CDC when enabled. + */ + private def writeAllChanges( + spark: SparkSession, + deltaTxn: OptimisticTransaction, + filesToRewrite: Seq[AddFile] + ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} + + val cdcEnabled = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(deltaTxn.metadata) + + var targetOutputCols = getTargetOutputCols(deltaTxn) + var outputRowSchema = deltaTxn.metadata.schema + + // When we have duplicate matches (only allowed when the whenMatchedCondition is a delete with + // no match condition) we will incorrectly generate duplicate CDC rows. + // Duplicate matches can be due to: + // - Duplicate rows in the source w.r.t. the merge condition + // - A target-only or source-only merge condition, which essentially turns our join into a cross + // join with the target/source satisfiying the merge condition. + // These duplicate matches are dropped from the main data output since this is a delete + // operation, but the duplicate CDC rows are not removed by default. + // See https://github.com/delta-io/delta/issues/1274 + + // We address this specific scenario by adding row ids to the target before performing our join. + // There should only be one CDC delete row per target row so we can use these row ids to dedupe + // the duplicate CDC delete rows. + + // We also need to address the scenario when there are duplicate matches with delete and we + // insert duplicate rows. Here we need to additionally add row ids to the source before the + // join to avoid dropping these valid duplicate inserted rows and their corresponding cdc rows. + + // When there is an insert clause, we set SOURCE_ROW_ID_COL=null for all delete rows because we + // need to drop the duplicate matches. + val isDeleteWithDuplicateMatchesAndCdc = multipleMatchDeleteOnlyOvercount.nonEmpty && cdcEnabled + + // Generate a new logical plan that has same output attributes exprIds as the target plan. + // This allows us to apply the existing resolved update/insert expressions. + val newTarget = buildTargetPlanWithFiles(deltaTxn, filesToRewrite) + val joinType = if (isMatchedOnly && + spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED)) { + "rightOuter" + } else { + "fullOuter" + } + + logDebug(s"""writeAllChanges using $joinType join: + | source.output: ${source.outputSet} + | target.output: ${target.outputSet} + | condition: $condition + | newTarget.output: ${newTarget.outputSet} + """.stripMargin) + + // UDFs to update metrics + // Make UDFs that appear in the custom join processor node deterministic, as they always + // return true and update a metric. Catalyst precludes non-deterministic UDFs that are not + // allowed outside a very specific set of Catalyst nodes (Project, Filter, Window, Aggregate). + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRowsInSecondScan") + val incrUpdatedCountExpr = makeMetricUpdateUDF("numTargetRowsUpdated", deterministic = true) + val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted", deterministic = true) + val incrNoopCountExpr = makeMetricUpdateUDF("numTargetRowsCopied", deterministic = true) + val incrDeletedCountExpr = makeMetricUpdateUDF("numTargetRowsDeleted", deterministic = true) + + // Apply an outer join to find both, matches and non-matches. We are adding two boolean fields + // with value `true`, one to each side of the join. Whether this field is null or not after + // the outer join, will allow us to identify whether the resultant joined row was a + // matched inner result or an unmatched result with null on one side. + // We add row IDs to the targetDF if we have a delete-when-matched clause with duplicate + // matches and CDC is enabled, and additionally add row IDs to the source if we also have an + // insert clause. See above at isDeleteWithDuplicateMatchesAndCdc definition for more details. + var sourceDF = Dataset.ofRows(spark, source) + .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr)) + var targetDF = Dataset.ofRows(spark, newTarget) + .withColumn(TARGET_ROW_PRESENT_COL, lit(true)) + if (isDeleteWithDuplicateMatchesAndCdc) { + targetDF = targetDF.withColumn(TARGET_ROW_ID_COL, monotonically_increasing_id()) + if (notMatchedClauses.nonEmpty) { // insert clause + sourceDF = sourceDF.withColumn(SOURCE_ROW_ID_COL, monotonically_increasing_id()) + } + } + val joinedDF = sourceDF.join(targetDF, new Column(condition), joinType) + val joinedPlan = joinedDF.queryExecution.analyzed + + def resolveOnJoinedPlan(exprs: Seq[Expression]): Seq[Expression] = { + tryResolveReferencesForExpressions(spark, exprs, joinedPlan) + } + + // ==== Generate the expressions to process full-outer join output and generate target rows ==== + // If there are N columns in the target table, there will be N + 3 columns after processing + // - N columns for target table + // - ROW_DROPPED_COL to define whether the generated row should dropped or written + // - INCR_ROW_COUNT_COL containing a UDF to update the output row row counter + // - CDC_TYPE_COLUMN_NAME containing the type of change being performed in a particular row + + // To generate these N + 3 columns, we will generate N + 3 expressions and apply them to the + // rows in the joinedDF. The CDC column will be either used for CDC generation or dropped before + // performing the final write, and the other two will always be dropped after executing the + // metrics UDF and filtering on ROW_DROPPED_COL. + + // We produce rows for both the main table data (with CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC), + // and rows for the CDC data which will be output to CDCReader.CDC_LOCATION. + // See [[CDCReader]] for general details on how partitioning on the CDC type column works. + + // In the following two functions `matchedClauseOutput` and `notMatchedClauseOutput`, we + // produce a Seq[Expression] for each intended output row. + // Depending on the clause and whether CDC is enabled, we output between 0 and 3 rows, as a + // Seq[Seq[Expression]] + + // There is one corner case outlined above at isDeleteWithDuplicateMatchesAndCdc definition. + // When we have a delete-ONLY merge with duplicate matches we have N + 4 columns: + // N target cols, TARGET_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL, CDC_TYPE_COLUMN_NAME + // When we have a delete-when-matched merge with duplicate matches + an insert clause, we have + // N + 5 columns: + // N target cols, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL, + // CDC_TYPE_COLUMN_NAME + // These ROW_ID_COL will always be dropped before the final write. + + if (isDeleteWithDuplicateMatchesAndCdc) { + targetOutputCols = targetOutputCols :+ UnresolvedAttribute(TARGET_ROW_ID_COL) + outputRowSchema = outputRowSchema.add(TARGET_ROW_ID_COL, DataTypes.LongType) + if (notMatchedClauses.nonEmpty) { // there is an insert clause, make SRC_ROW_ID_COL=null + targetOutputCols = targetOutputCols :+ Alias(Literal(null), SOURCE_ROW_ID_COL)() + outputRowSchema = outputRowSchema.add(SOURCE_ROW_ID_COL, DataTypes.LongType) + } + } + + if (cdcEnabled) { + outputRowSchema = outputRowSchema + .add(ROW_DROPPED_COL, DataTypes.BooleanType) + .add(INCR_ROW_COUNT_COL, DataTypes.BooleanType) + .add(CDC_TYPE_COLUMN_NAME, DataTypes.StringType) + } + + def matchedClauseOutput(clause: DeltaMergeIntoMatchedClause): Seq[Seq[Expression]] = { + val exprs = clause match { + case u: DeltaMergeIntoMatchedUpdateClause => + // Generate update expressions and set ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC + val mainDataOutput = u.resolvedActions.map(_.expr) :+ FalseLiteral :+ + incrUpdatedCountExpr :+ CDC_TYPE_NOT_CDC_LITERAL + if (cdcEnabled) { + // For update preimage, we have do a no-op copy with ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_PREIMAGE and INCR_ROW_COUNT_COL as a no-op + // (because the metric will be incremented in `mainDataOutput`) + val preImageOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+ + Literal(CDC_TYPE_UPDATE_PREIMAGE) + // For update postimage, we have the same expressions as for mainDataOutput but with + // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in + // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_POSTIMAGE + val postImageOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+ + Literal(CDC_TYPE_UPDATE_POSTIMAGE) + Seq(mainDataOutput, preImageOutput, postImageOutput) + } else { + Seq(mainDataOutput) + } + case _: DeltaMergeIntoMatchedDeleteClause => + // Generate expressions to set the ROW_DELETED_COL = true and CDC_TYPE_COLUMN_NAME = + // CDC_TYPE_NOT_CDC + val mainDataOutput = targetOutputCols :+ TrueLiteral :+ incrDeletedCountExpr :+ + CDC_TYPE_NOT_CDC_LITERAL + if (cdcEnabled) { + // For delete we do a no-op copy with ROW_DELETED_COL = false, INCR_ROW_COUNT_COL as a + // no-op (because the metric will be incremented in `mainDataOutput`) and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_DELETE + val deleteCdcOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+ + Literal(CDC_TYPE_DELETE) + Seq(mainDataOutput, deleteCdcOutput) + } else { + Seq(mainDataOutput) + } + } + exprs.map(resolveOnJoinedPlan) + } + + def notMatchedClauseOutput(clause: DeltaMergeIntoNotMatchedClause): Seq[Seq[Expression]] = { + // Generate insert expressions and set ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC + val insertExprs = clause.resolvedActions.map(_.expr) + val mainDataOutput = resolveOnJoinedPlan( + if (isDeleteWithDuplicateMatchesAndCdc) { + // Must be delete-when-matched merge with duplicate matches + insert clause + // Therefore we must keep the target row id and source row id. Since this is a not-matched + // clause we know the target row-id will be null. See above at + // isDeleteWithDuplicateMatchesAndCdc definition for more details. + insertExprs :+ + Alias(Literal(null), TARGET_ROW_ID_COL)() :+ UnresolvedAttribute(SOURCE_ROW_ID_COL) :+ + FalseLiteral :+ incrInsertedCountExpr :+ CDC_TYPE_NOT_CDC_LITERAL + } else { + insertExprs :+ FalseLiteral :+ incrInsertedCountExpr :+ CDC_TYPE_NOT_CDC_LITERAL + } + ) + if (cdcEnabled) { + // For insert we have the same expressions as for mainDataOutput, but with + // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in + // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_INSERT + val insertCdcOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+ Literal(CDC_TYPE_INSERT) + Seq(mainDataOutput, insertCdcOutput) + } else { + Seq(mainDataOutput) + } + } + + def clauseCondition(clause: DeltaMergeIntoClause): Expression = { + // if condition is None, then expression always evaluates to true + val condExpr = clause.condition.getOrElse(TrueLiteral) + resolveOnJoinedPlan(Seq(condExpr)).head + } + + val targetRowHasNoMatch = resolveOnJoinedPlan(Seq(col(SOURCE_ROW_PRESENT_COL).isNull.expr)).head + val sourceRowHasNoMatch = resolveOnJoinedPlan(Seq(col(TARGET_ROW_PRESENT_COL).isNull.expr)).head + val matchedConditions = matchedClauses.map(clauseCondition) + val matchedOutputs = matchedClauses.map(matchedClauseOutput) + val notMatchedConditions = notMatchedClauses.map(clauseCondition) + val notMatchedOutputs = notMatchedClauses.map(notMatchedClauseOutput) + // TODO support notMatchedBySourceClauses which is new in DBR 12.2 + // https://github.com/NVIDIA/spark-rapids/issues/8415 + val notMatchedBySourceConditions = Seq.empty + val notMatchedBySourceOutputs = Seq.empty + val noopCopyOutput = + resolveOnJoinedPlan(targetOutputCols :+ FalseLiteral :+ incrNoopCountExpr :+ + CDC_TYPE_NOT_CDC_LITERAL) + val deleteRowOutput = + resolveOnJoinedPlan(targetOutputCols :+ TrueLiteral :+ TrueLiteral :+ + CDC_TYPE_NOT_CDC_LITERAL) + var outputDF = addMergeJoinProcessor(spark, joinedPlan, outputRowSchema, + targetRowHasNoMatch = targetRowHasNoMatch, + sourceRowHasNoMatch = sourceRowHasNoMatch, + matchedConditions = matchedConditions, + matchedOutputs = matchedOutputs, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + notMatchedBySourceConditions = notMatchedBySourceConditions, + notMatchedBySourceOutputs = notMatchedBySourceOutputs, + noopCopyOutput = noopCopyOutput, + deleteRowOutput = deleteRowOutput) + + if (isDeleteWithDuplicateMatchesAndCdc) { + // When we have a delete when matched clause with duplicate matches we have to remove + // duplicate CDC rows. This scenario is further explained at + // isDeleteWithDuplicateMatchesAndCdc definition. + + // To remove duplicate CDC rows generated by the duplicate matches we dedupe by + // TARGET_ROW_ID_COL since there should only be one CDC delete row per target row. + // When there is an insert clause in addition to the delete clause we additionally dedupe by + // SOURCE_ROW_ID_COL and CDC_TYPE_COLUMN_NAME to avoid dropping valid duplicate inserted rows + // and their corresponding CDC rows. + val columnsToDedupeBy = if (notMatchedClauses.nonEmpty) { // insert clause + Seq(TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, CDC_TYPE_COLUMN_NAME) + } else { + Seq(TARGET_ROW_ID_COL) + } + outputDF = outputDF + .dropDuplicates(columnsToDedupeBy) + .drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL) + } else { + outputDF = outputDF.drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL) + } + + logDebug("writeAllChanges: join output plan:\n" + outputDF.queryExecution) + + // Write to Delta + val newFiles = deltaTxn + .writeFiles(repartitionIfNeeded(spark, outputDF, deltaTxn.metadata.partitionColumns)) + + // Update metrics + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + metrics("numTargetChangeFilesAdded") += newFiles.count(_.isInstanceOf[AddCDCFile]) + metrics("numTargetChangeFileBytes") += newFiles.collect{ case f: AddCDCFile => f.size }.sum + metrics("numTargetBytesAdded") += addedBytes + metrics("numTargetPartitionsAddedTo") += addedPartitions + if (multipleMatchDeleteOnlyOvercount.isDefined) { + // Compensate for counting duplicates during the query. + val actualRowsDeleted = + metrics("numTargetRowsDeleted").value - multipleMatchDeleteOnlyOvercount.get + assert(actualRowsDeleted >= 0) + metrics("numTargetRowsDeleted").set(actualRowsDeleted) + } + + newFiles + } + + private def addMergeJoinProcessor( + spark: SparkSession, + joinedPlan: LogicalPlan, + outputRowSchema: StructType, + targetRowHasNoMatch: Expression, + sourceRowHasNoMatch: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Seq[Expression]]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Seq[Expression]]], + notMatchedBySourceConditions: Seq[Expression], + notMatchedBySourceOutputs: Seq[Seq[Seq[Expression]]], + noopCopyOutput: Seq[Expression], + deleteRowOutput: Seq[Expression]): Dataset[Row] = { + def wrap(e: Expression): BaseExprMeta[Expression] = { + GpuOverrides.wrapExpr(e, rapidsConf, None) + } + + val targetRowHasNoMatchMeta = wrap(targetRowHasNoMatch) + val sourceRowHasNoMatchMeta = wrap(sourceRowHasNoMatch) + val matchedConditionsMetas = matchedConditions.map(wrap) + val matchedOutputsMetas = matchedOutputs.map(_.map(_.map(wrap))) + val notMatchedConditionsMetas = notMatchedConditions.map(wrap) + val notMatchedOutputsMetas = notMatchedOutputs.map(_.map(_.map(wrap))) + val notMatchedBySourceConditionsMetas = notMatchedBySourceConditions.map(wrap) + val notMatchedBySourceOutputsMetas = notMatchedBySourceOutputs.map(_.map(_.map(wrap))) + val noopCopyOutputMetas = noopCopyOutput.map(wrap) + val deleteRowOutputMetas = deleteRowOutput.map(wrap) + val allMetas = Seq(targetRowHasNoMatchMeta, sourceRowHasNoMatchMeta) ++ + matchedConditionsMetas ++ matchedOutputsMetas.flatten.flatten ++ + notMatchedConditionsMetas ++ notMatchedOutputsMetas.flatten.flatten ++ + notMatchedBySourceConditionsMetas ++ notMatchedBySourceOutputsMetas.flatten.flatten ++ + noopCopyOutputMetas ++ deleteRowOutputMetas + allMetas.foreach(_.tagForGpu()) + val canReplace = allMetas.forall(_.canExprTreeBeReplaced) && rapidsConf.isOperatorEnabled( + "spark.rapids.sql.exec.RapidsProcessDeltaMergeJoinExec", false, false) + if (rapidsConf.shouldExplainAll || (rapidsConf.shouldExplain && !canReplace)) { + val exprExplains = allMetas.map(_.explain(rapidsConf.shouldExplainAll)) + val execWorkInfo = if (canReplace) { + "will run on GPU" + } else { + "cannot run on GPU because not all merge processing expressions can be replaced" + } + logWarning(s" $execWorkInfo:\n" + + s" ${exprExplains.mkString(" ")}") + } + + if (canReplace) { + val processedJoinPlan = RapidsProcessDeltaMergeJoin( + joinedPlan, + outputRowSchema.toAttributes, + targetRowHasNoMatch = targetRowHasNoMatch, + sourceRowHasNoMatch = sourceRowHasNoMatch, + matchedConditions = matchedConditions, + matchedOutputs = matchedOutputs, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + notMatchedBySourceConditions = notMatchedBySourceConditions, + notMatchedBySourceOutputs = notMatchedBySourceOutputs, + noopCopyOutput = noopCopyOutput, + deleteRowOutput = deleteRowOutput) + Dataset.ofRows(spark, processedJoinPlan) + } else { + val joinedRowEncoder = RowEncoder(joinedPlan.schema) + val outputRowEncoder = RowEncoder(outputRowSchema).resolveAndBind() + + val processor = new JoinedRowProcessor( + targetRowHasNoMatch = targetRowHasNoMatch, + sourceRowHasNoMatch = sourceRowHasNoMatch, + matchedConditions = matchedConditions, + matchedOutputs = matchedOutputs, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + noopCopyOutput = noopCopyOutput, + deleteRowOutput = deleteRowOutput, + joinedAttributes = joinedPlan.output, + joinedRowEncoder = joinedRowEncoder, + outputRowEncoder = outputRowEncoder) + + Dataset.ofRows(spark, joinedPlan).mapPartitions(processor.processPartition)(outputRowEncoder) + } + } + + /** + * Build a new logical plan using the given `files` that has the same output columns (exprIds) + * as the `target` logical plan, so that existing update/insert expressions can be applied + * on this new plan. + */ + private def buildTargetPlanWithFiles( + deltaTxn: OptimisticTransaction, + files: Seq[AddFile]): LogicalPlan = { + val targetOutputCols = getTargetOutputCols(deltaTxn) + val targetOutputColsMap = { + val colsMap: Map[String, NamedExpression] = targetOutputCols.view + .map(col => col.name -> col).toMap + if (conf.caseSensitiveAnalysis) { + colsMap + } else { + CaseInsensitiveMap(colsMap) + } + } + + val plan = { + // We have to do surgery to use the attributes from `targetOutputCols` to scan the table. + // In cases of schema evolution, they may not be the same type as the original attributes. + val original = + deltaTxn.deltaLog.createDataFrame(deltaTxn.snapshot, files).queryExecution.analyzed + val transformed = original.transform { + case LogicalRelation(base, _, catalogTbl, isStreaming) => + LogicalRelation( + base, + // We can ignore the new columns which aren't yet AttributeReferences. + targetOutputCols.collect { case a: AttributeReference => a }, + catalogTbl, + isStreaming) + } + + // In case of schema evolution & column mapping, we would also need to rebuild the file format + // because under column mapping, the reference schema within DeltaParquetFileFormat + // that is used to populate metadata needs to be updated + if (deltaTxn.metadata.columnMappingMode != NoMapping) { + val updatedFileFormat = deltaTxn.deltaLog.fileFormat( + deltaTxn.deltaLog.unsafeVolatileSnapshot.protocol, deltaTxn.metadata) + DeltaTableUtils.replaceFileFormat(transformed, updatedFileFormat) + } else { + transformed + } + } + + // For each plan output column, find the corresponding target output column (by name) and + // create an alias + val aliases = plan.output.map { + case newAttrib: AttributeReference => + val existingTargetAttrib = targetOutputColsMap.get(newAttrib.name) + .getOrElse { + throw new AnalysisException( + s"Could not find ${newAttrib.name} among the existing target output " + + targetOutputCols.mkString(",")) + }.asInstanceOf[AttributeReference] + + if (existingTargetAttrib.exprId == newAttrib.exprId) { + // It's not valid to alias an expression to its own exprId (this is considered a + // non-unique exprId by the analyzer), so we just use the attribute directly. + newAttrib + } else { + Alias(newAttrib, existingTargetAttrib.name)(exprId = existingTargetAttrib.exprId) + } + } + + Project(aliases, plan) + } + + /** Expressions to increment SQL metrics */ + private def makeMetricUpdateUDF(name: String, deterministic: Boolean = false): Expression = { + // only capture the needed metric in a local variable + val metric = metrics(name) + var u = udf(new GpuDeltaMetricUpdateUDF(metric)) + if (!deterministic) { + u = u.asNondeterministic() + } + u.apply().expr + } + + private def getTargetOutputCols(txn: OptimisticTransaction): Seq[NamedExpression] = { + txn.metadata.schema.map { col => + targetOutputAttributesMap + .get(col.name) + .map { a => + AttributeReference(col.name, col.dataType, col.nullable)(a.exprId) + } + .getOrElse(Alias(Literal(null), col.name)() + ) + } + } + + /** + * Repartitions the output DataFrame by the partition columns if table is partitioned + * and `merge.repartitionBeforeWrite.enabled` is set to true. + */ + protected def repartitionIfNeeded( + spark: SparkSession, + df: DataFrame, + partitionColumns: Seq[String]): DataFrame = { + if (partitionColumns.nonEmpty && spark.conf.get(DeltaSQLConf.MERGE_REPARTITION_BEFORE_WRITE)) { + df.repartition(partitionColumns.map(col): _*) + } else { + df + } + } + + /** + * Execute the given `thunk` and return its result while recording the time taken to do it. + * + * @param sqlMetricName name of SQL metric to update with the time taken by the thunk + * @param thunk the code to execute + */ + private def recordMergeOperation[A](sqlMetricName: String)(thunk: => A): A = { + val startTimeNs = System.nanoTime() + val r = thunk + val timeTakenMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + if (sqlMetricName != null && timeTakenMs > 0) { + metrics(sqlMetricName) += timeTakenMs + } + r + } +} + +object GpuMergeIntoCommand { + /** + * Spark UI will track all normal accumulators along with Spark tasks to show them on Web UI. + * However, the accumulator used by `MergeIntoCommand` can store a very large value since it + * tracks all files that need to be rewritten. We should ask Spark UI to not remember it, + * otherwise, the UI data may consume lots of memory. Hence, we use the prefix `internal.metrics.` + * to make this accumulator become an internal accumulator, so that it will not be tracked by + * Spark UI. + */ + val TOUCHED_FILES_ACCUM_NAME = "internal.metrics.MergeIntoDelta.touchedFiles" + + val ROW_ID_COL = "_row_id_" + val TARGET_ROW_ID_COL = "_target_row_id_" + val SOURCE_ROW_ID_COL = "_source_row_id_" + val FILE_NAME_COL = "_file_name_" + val SOURCE_ROW_PRESENT_COL = "_source_row_present_" + val TARGET_ROW_PRESENT_COL = "_target_row_present_" + val ROW_DROPPED_COL = GpuDeltaMergeConstants.ROW_DROPPED_COL + val INCR_ROW_COUNT_COL = "_incr_row_count_" + + // Some Delta versions use Literal(null) which translates to a literal of NullType instead + // of the Literal(null, StringType) which is needed, so using a fixed version here + // rather than the version from Delta Lake. + val CDC_TYPE_NOT_CDC_LITERAL = Literal(null, StringType) + + /** + * @param targetRowHasNoMatch whether a joined row is a target row with no match in the source + * table + * @param sourceRowHasNoMatch whether a joined row is a source row with no match in the target + * table + * @param matchedConditions condition for each match clause + * @param matchedOutputs corresponding output for each match clause. for each clause, we + * have 1-3 output rows, each of which is a sequence of expressions + * to apply to the joined row + * @param notMatchedConditions condition for each not-matched clause + * @param notMatchedOutputs corresponding output for each not-matched clause. for each clause, + * we have 1-2 output rows, each of which is a sequence of + * expressions to apply to the joined row + * @param noopCopyOutput no-op expression to copy a target row to the output + * @param deleteRowOutput expression to drop a row from the final output. this is used for + * source rows that don't match any not-matched clauses + * @param joinedAttributes schema of our outer-joined dataframe + * @param joinedRowEncoder joinedDF row encoder + * @param outputRowEncoder final output row encoder + */ + class JoinedRowProcessor( + targetRowHasNoMatch: Expression, + sourceRowHasNoMatch: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Seq[Expression]]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Seq[Expression]]], + noopCopyOutput: Seq[Expression], + deleteRowOutput: Seq[Expression], + joinedAttributes: Seq[Attribute], + joinedRowEncoder: ExpressionEncoder[Row], + outputRowEncoder: ExpressionEncoder[Row]) extends Serializable { + + private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = { + UnsafeProjection.create(exprs, joinedAttributes) + } + + private def generatePredicate(expr: Expression): BasePredicate = { + GeneratePredicate.generate(expr, joinedAttributes) + } + + def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = { + + val targetRowHasNoMatchPred = generatePredicate(targetRowHasNoMatch) + val sourceRowHasNoMatchPred = generatePredicate(sourceRowHasNoMatch) + val matchedPreds = matchedConditions.map(generatePredicate) + val matchedProjs = matchedOutputs.map(_.map(generateProjection)) + val notMatchedPreds = notMatchedConditions.map(generatePredicate) + val notMatchedProjs = notMatchedOutputs.map(_.map(generateProjection)) + val noopCopyProj = generateProjection(noopCopyOutput) + val deleteRowProj = generateProjection(deleteRowOutput) + val outputProj = UnsafeProjection.create(outputRowEncoder.schema) + + // this is accessing ROW_DROPPED_COL. If ROW_DROPPED_COL is not in outputRowEncoder.schema + // then CDC must be disabled and it's the column after our output cols + def shouldDeleteRow(row: InternalRow): Boolean = { + row.getBoolean( + outputRowEncoder.schema.getFieldIndex(ROW_DROPPED_COL) + .getOrElse(outputRowEncoder.schema.fields.size) + ) + } + + def processRow(inputRow: InternalRow): Iterator[InternalRow] = { + if (targetRowHasNoMatchPred.eval(inputRow)) { + // Target row did not match any source row, so just copy it to the output + Iterator(noopCopyProj.apply(inputRow)) + } else { + // identify which set of clauses to execute: matched or not-matched ones + val (predicates, projections, noopAction) = if (sourceRowHasNoMatchPred.eval(inputRow)) { + // Source row did not match with any target row, so insert the new source row + (notMatchedPreds, notMatchedProjs, deleteRowProj) + } else { + // Source row matched with target row, so update the target row + (matchedPreds, matchedProjs, noopCopyProj) + } + + // find (predicate, projection) pair whose predicate satisfies inputRow + val pair = (predicates zip projections).find { + case (predicate, _) => predicate.eval(inputRow) + } + + pair match { + case Some((_, projections)) => + projections.map(_.apply(inputRow)).iterator + case None => Iterator(noopAction.apply(inputRow)) + } + } + } + + val toRow = joinedRowEncoder.createSerializer() + val fromRow = outputRowEncoder.createDeserializer() + rowIterator + .map(toRow) + .flatMap(processRow) + .filter(!shouldDeleteRow(_)) + .map { notDeletedInternalRow => + fromRow(outputProj(notDeletedInternalRow)) + } + } + } + + /** Count the number of distinct partition values among the AddFiles in the given set. */ + def totalBytesAndDistinctPartitionValues(files: Seq[FileAction]): (Long, Int) = { + val distinctValues = new mutable.HashSet[Map[String, String]]() + var bytes = 0L + val iter = files.collect { case a: AddFile => a }.iterator + while (iter.hasNext) { + val file = iter.next() + distinctValues += file.partitionValues + bytes += file.size + } + // If the only distinct value map is an empty map, then it must be an unpartitioned table. + // Return 0 in that case. + val numDistinctValues = + if (distinctValues.size == 1 && distinctValues.head.isEmpty) 0 else distinctValues.size + (bytes, numDistinctValues) + } +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimisticTransaction.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimisticTransaction.scala new file mode 100644 index 00000000000..3e836056b6d --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimisticTransaction.scala @@ -0,0 +1,312 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from OptimisticTransaction.scala and TransactionalWrite.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import java.net.URI + +import scala.collection.mutable.ListBuffer + +import com.databricks.sql.transaction.tahoe._ +import com.databricks.sql.transaction.tahoe.actions.{AddFile, FileAction} +import com.databricks.sql.transaction.tahoe.constraints.{Constraint, Constraints} +import com.databricks.sql.transaction.tahoe.schema.InvariantViolationException +import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.delta._ +import org.apache.commons.lang3.exception.ExceptionUtils +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormatWriter} +import org.apache.spark.sql.functions.to_json +import org.apache.spark.sql.rapids.{BasicColumnarWriteJobStatsTracker, ColumnarWriteJobStatsTracker, GpuFileFormatWriter, GpuWriteJobStatsTracker} +import org.apache.spark.sql.rapids.delta.GpuIdentityColumn +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.{Clock, SerializableConfiguration} + +/** + * Used to perform a set of reads in a transaction and then commit a set of updates to the + * state of the log. All reads from the DeltaLog, MUST go through this instance rather + * than directly to the DeltaLog otherwise they will not be check for logical conflicts + * with concurrent updates. + * + * This class is not thread-safe. + * + * @param deltaLog The Delta Log for the table this transaction is modifying. + * @param snapshot The snapshot that this transaction is reading at. + * @param rapidsConf RAPIDS Accelerator config settings. + */ +class GpuOptimisticTransaction( + deltaLog: DeltaLog, + snapshot: Snapshot, + rapidsConf: RapidsConf)(implicit clock: Clock) + extends GpuOptimisticTransactionBase(deltaLog, snapshot, rapidsConf)(clock) { + + /** Creates a new OptimisticTransaction. + * + * @param deltaLog The Delta Log for the table this transaction is modifying. + * @param rapidsConf RAPIDS Accelerator config settings + */ + def this(deltaLog: DeltaLog, rapidsConf: RapidsConf)(implicit clock: Clock) = { + this(deltaLog, deltaLog.update(), rapidsConf) + } + + private def getGpuStatsColExpr( + statsDataSchema: Seq[Attribute], + statsCollection: GpuStatisticsCollection): Expression = { + Dataset.ofRows(spark, LocalRelation(statsDataSchema)) + .select(to_json(statsCollection.statsCollector)) + .queryExecution.analyzed.expressions.head + } + + /** Return the pair of optional stats tracker and stats collection class */ + private def getOptionalGpuStatsTrackerAndStatsCollection( + output: Seq[Attribute], + partitionSchema: StructType, data: DataFrame): ( + Option[GpuDeltaJobStatisticsTracker], + Option[GpuStatisticsCollection]) = { + if (spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_COLLECT_STATS)) { + + val (statsDataSchema, statsCollectionSchema) = getStatsSchema(output, partitionSchema) + + val indexedCols = DeltaConfigs.DATA_SKIPPING_NUM_INDEXED_COLS.fromMetaData(metadata) + val prefixLength = + spark.sessionState.conf.getConf(DeltaSQLConf.DATA_SKIPPING_STRING_PREFIX_LENGTH) + val tableSchema = { + // If collecting stats using the table schema, then pass in statsCollectionSchema. + // Otherwise pass in statsDataSchema to collect stats using the DataFrame schema. + if (spark.sessionState.conf.getConf(DeltaSQLConf + .DELTA_COLLECT_STATS_USING_TABLE_SCHEMA)) { + statsCollectionSchema.toStructType + } else { + statsDataSchema.toStructType + } + } + + val _spark = spark + val protocol = deltaLog.unsafeVolatileSnapshot.protocol + + val statsCollection = new GpuStatisticsCollection { + override val spark = _spark + override val deletionVectorsSupported = + protocol.isFeatureSupported(DeletionVectorsTableFeature) + override val tableDataSchema = tableSchema + override val dataSchema = statsDataSchema.toStructType + override val numIndexedCols = indexedCols + override val stringPrefixLength: Int = prefixLength + } + + val statsColExpr = getGpuStatsColExpr(statsDataSchema, statsCollection) + + val statsSchema = statsCollection.statCollectionSchema + val explodedDataSchema = statsCollection.explodedDataSchema + val batchStatsToRow = (batch: ColumnarBatch, row: InternalRow) => { + GpuStatisticsCollection.batchStatsToRow(statsSchema, explodedDataSchema, batch, row) + } + (Some(new GpuDeltaJobStatisticsTracker(statsDataSchema, statsColExpr, batchStatsToRow)), + Some(statsCollection)) + } else { + (None, None) + } + } + + override def writeFiles( + inputData: Dataset[_], + writeOptions: Option[DeltaOptions], + additionalConstraints: Seq[Constraint]): Seq[FileAction] = { + hasWritten = true + + val spark = inputData.sparkSession + val (data, partitionSchema) = performCDCPartition(inputData) + val outputPath = deltaLog.dataPath + + val (normalizedQueryExecution, output, generatedColumnConstraints, dataHighWaterMarks) = { + // TODO: is none ok to pass here? + normalizeData(deltaLog, None, data) + } + val highWaterMarks = trackHighWaterMarks.getOrElse(dataHighWaterMarks) + + // Build a new plan with a stub GpuDeltaWrite node to work around undesired transitions between + // columns and rows when AQE is involved. Without this node in the plan, AdaptiveSparkPlanExec + // could be the root node of the plan. In that case we do not have enough context to know + // whether the AdaptiveSparkPlanExec should be columnar or not, since the GPU overrides do not + // see how the parent is using the AdaptiveSparkPlanExec outputs. By using this stub node that + // appears to be a data writing node to AQE (it derives from V2CommandExec), the + // AdaptiveSparkPlanExec will be planned as a child of this new node. That provides enough + // context to plan the AQE sub-plan properly with respect to columnar and row transitions. + // We could force the AQE node to be columnar here by explicitly replacing the node, but that + // breaks the connection between the queryExecution and the node that will actually execute. + val gpuWritePlan = Dataset.ofRows(spark, RapidsDeltaWrite(normalizedQueryExecution.logical)) + val queryExecution = gpuWritePlan.queryExecution + + val partitioningColumns = getPartitioningColumns(partitionSchema, output) + + val committer = getCommitter(outputPath) + + // If Statistics Collection is enabled, then create a stats tracker that will be injected during + // the FileFormatWriter.write call below and will collect per-file stats using + // StatisticsCollection + val (optionalStatsTracker, _) = getOptionalGpuStatsTrackerAndStatsCollection(output, + partitionSchema, data) + + // schema should be normalized, therefore we can do an equality check + val (statsDataSchema, _) = getStatsSchema(output, partitionSchema) + val identityTracker = GpuIdentityColumn.createIdentityColumnStatsTracker( + spark, + statsDataSchema, + metadata.schema, + highWaterMarks) + + val constraints = + Constraints.getAll(metadata, spark) ++ generatedColumnConstraints ++ additionalConstraints + + val isOptimize = isOptimizeCommand(queryExecution.analyzed) + + SQLExecution.withNewExecutionId(queryExecution, Option("deltaTransactionalWrite")) { + val outputSpec = FileFormatWriter.OutputSpec( + outputPath.toString, + Map.empty, + output) + + // Remove any unnecessary row conversions added as part of Spark planning + val queryPhysicalPlan = queryExecution.executedPlan match { + case GpuColumnarToRowExec(child, _) => child + case p => p + } + val gpuRapidsWrite = queryPhysicalPlan match { + case g: GpuRapidsDeltaWriteExec => Some(g) + case _ => None + } + + val empty2NullPlan = convertEmptyToNullIfNeeded(queryPhysicalPlan, + partitioningColumns, constraints) + val optimizedPlan = + applyOptimizeWriteIfNeeded(spark, empty2NullPlan, partitionSchema, isOptimize) + val planWithInvariants = addInvariantChecks(optimizedPlan, constraints) + val physicalPlan = convertToGpu(planWithInvariants) + + val statsTrackers: ListBuffer[ColumnarWriteJobStatsTracker] = ListBuffer() + + val hadoopConf = spark.sessionState.newHadoopConfWithOptions( + metadata.configuration ++ deltaLog.options) + + if (spark.conf.get(DeltaSQLConf.DELTA_HISTORY_METRICS_ENABLED)) { + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + val basicWriteJobStatsTracker = new BasicColumnarWriteJobStatsTracker( + serializableHadoopConf, + BasicWriteJobStatsTracker.metrics) + registerSQLMetrics(spark, basicWriteJobStatsTracker.driverSideMetrics) + statsTrackers.append(basicWriteJobStatsTracker) + gpuRapidsWrite.foreach { grw => + val tracker = new GpuWriteJobStatsTracker(serializableHadoopConf, + grw.basicMetrics, grw.taskMetrics) + statsTrackers.append(tracker) + } + } + + // Retain only a minimal selection of Spark writer options to avoid any potential + // compatibility issues + val options = writeOptions match { + case None => Map.empty[String, String] + case Some(writeOptions) => + writeOptions.options.filterKeys { key => + key.equalsIgnoreCase(DeltaOptions.MAX_RECORDS_PER_FILE) || + key.equalsIgnoreCase(DeltaOptions.COMPRESSION) + }.toMap + } + val deltaFileFormat = deltaLog.fileFormat(deltaLog.unsafeVolatileSnapshot.protocol, metadata) + val gpuFileFormat = if (deltaFileFormat.getClass == classOf[DeltaParquetFileFormat]) { + new GpuParquetFileFormat + } else { + throw new IllegalStateException(s"file format $deltaFileFormat is not supported") + } + + try { + logDebug(s"Physical plan for write:\n$physicalPlan") + GpuFileFormatWriter.write( + sparkSession = spark, + plan = physicalPlan, + fileFormat = gpuFileFormat, + committer = committer, + outputSpec = outputSpec, + hadoopConf = hadoopConf, + partitionColumns = partitioningColumns, + bucketSpec = None, + statsTrackers = optionalStatsTracker.toSeq ++ identityTracker.toSeq ++ statsTrackers, + options = options, + rapidsConf.stableSort, + rapidsConf.concurrentWriterPartitionFlushSize) + } catch { + case s: SparkException => + // Pull an InvariantViolationException up to the top level if it was the root cause. + val violationException = ExceptionUtils.getRootCause(s) + if (violationException.isInstanceOf[InvariantViolationException]) { + throw violationException + } else { + throw s + } + } + } + + val resultFiles = committer.addedStatuses.map { a => + a.copy(stats = optionalStatsTracker.map( + _.recordedStats(new Path(new URI(a.path)).getName)).getOrElse(a.stats)) + }.filter { + // In some cases, we can write out an empty `inputData`. Some examples of this (though, they + // may be fixed in the future) are the MERGE command when you delete with empty source, or + // empty target, or on disjoint tables. This is hard to catch before the write without + // collecting the DF ahead of time. Instead, we can return only the AddFiles that + // a) actually add rows, or + // b) don't have any stats so we don't know the number of rows at all + case a: AddFile => a.numLogicalRecords.forall(_ > 0) + case _ => true + } + + identityTracker.foreach { tracker => + updatedIdentityHighWaterMarks.appendAll(tracker.highWaterMarks.toSeq) + } + val fileActions = resultFiles.toSeq ++ committer.changeFiles + + // Check if auto-compaction is enabled. + // (Auto compaction checks are derived from the work in + // https://github.com/delta-io/delta/pull/1156). + lazy val autoCompactEnabled = + spark.sessionState.conf + .getConf[String](DeltaSQLConf.DELTA_AUTO_COMPACT_ENABLED) + .getOrElse { + DeltaConfigs.AUTO_COMPACT.fromMetaData(metadata) + .getOrElse("false") + }.toBoolean + + if (!isOptimize && autoCompactEnabled && fileActions.nonEmpty) { + registerPostCommitHook(GpuDoAutoCompaction) + } + + fileActions + } +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimizeExecutor.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimizeExecutor.scala new file mode 100644 index 00000000000..df619c3b1a7 --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuOptimizeExecutor.scala @@ -0,0 +1,405 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from: + * 1. DoAutoCompaction.scala from PR#1156 at https://github.com/delta-io/delta/pull/1156, + * 2. OptimizeTableCommand.scala from the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.databricks.sql.transaction.tahoe.rapids + +import java.util.ConcurrentModificationException + +import scala.annotation.tailrec +import scala.collection.mutable.ArrayBuffer + +import com.databricks.sql.io.skipping.MultiDimClustering +import com.databricks.sql.transaction.tahoe._ +import com.databricks.sql.transaction.tahoe.DeltaOperations.Operation +import com.databricks.sql.transaction.tahoe.actions.{Action, AddFile, FileAction, RemoveFile} +import com.databricks.sql.transaction.tahoe.commands.DeltaCommand +import com.databricks.sql.transaction.tahoe.commands.optimize._ +import com.databricks.sql.transaction.tahoe.files.SQLMetricsReporting +import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf +import com.nvidia.spark.rapids.delta.RapidsDeltaSQLConf + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext.SPARK_JOB_GROUP_ID +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.metric.SQLMetrics.createMetric +import org.apache.spark.util.ThreadUtils + +class GpuOptimizeExecutor( + sparkSession: SparkSession, + txn: OptimisticTransaction, + partitionPredicate: Seq[Expression], + zOrderByColumns: Seq[String], + prevCommitActions: Seq[Action]) + extends DeltaCommand with SQLMetricsReporting with Serializable { + + /** Timestamp to use in [[FileAction]] */ + private val operationTimestamp = System.currentTimeMillis + + private val isMultiDimClustering = zOrderByColumns.nonEmpty + private val isAutoCompact = prevCommitActions.nonEmpty + private val optimizeType = GpuOptimizeType(isMultiDimClustering, isAutoCompact) + + def optimize(): Seq[Row] = { + recordDeltaOperation(txn.deltaLog, "delta.optimize") { + val maxFileSize = optimizeType.maxFileSize + require(maxFileSize > 0, "maxFileSize must be > 0") + + val minNumFilesInDir = optimizeType.minNumFiles + val (candidateFiles, filesToProcess) = optimizeType.targetFiles + val partitionSchema = txn.metadata.partitionSchema + + // select all files in case of multi-dimensional clustering + val partitionsToCompact = filesToProcess + .groupBy(_.partitionValues) + .filter { case (_, filesInPartition) => filesInPartition.size >= minNumFilesInDir } + .toSeq + + val groupedJobs = groupFilesIntoBins(partitionsToCompact, maxFileSize) + val jobs = optimizeType.targetBins(groupedJobs) + + val maxThreads = + sparkSession.sessionState.conf.getConf(DeltaSQLConf.DELTA_OPTIMIZE_MAX_THREADS) + val updates = ThreadUtils.parmap(jobs, "OptimizeJob", maxThreads) { partitionBinGroup => + runOptimizeBinJob(txn, partitionBinGroup._1, partitionBinGroup._2, maxFileSize) + }.flatten + + val addedFiles = updates.collect { case a: AddFile => a } + val removedFiles = updates.collect { case r: RemoveFile => r } + if (addedFiles.nonEmpty) { + val operation = DeltaOperations.Optimize(partitionPredicate, zOrderByColumns) + val metrics = createMetrics(sparkSession.sparkContext, addedFiles, removedFiles) + commitAndRetry(txn, operation, updates, metrics) { newTxn => + val newPartitionSchema = newTxn.metadata.partitionSchema + val candidateSetOld = candidateFiles.map(_.path).toSet + val candidateSetNew = newTxn.filterFiles(partitionPredicate).map(_.path).toSet + + // As long as all of the files that we compacted are still part of the table, + // and the partitioning has not changed it is valid to continue to try + // and commit this checkpoint. + if (candidateSetOld.subsetOf(candidateSetNew) && partitionSchema == newPartitionSchema) { + true + } else { + val deleted = candidateSetOld -- candidateSetNew + logWarning(s"The following compacted files were delete " + + s"during checkpoint ${deleted.mkString(",")}. Aborting the compaction.") + false + } + } + } + + val optimizeStats = OptimizeStats() + optimizeStats.addedFilesSizeStats.merge(addedFiles) + optimizeStats.removedFilesSizeStats.merge(removedFiles) + optimizeStats.numPartitionsOptimized = jobs.map(j => j._1).distinct.size + optimizeStats.numBatches = jobs.size + optimizeStats.totalConsideredFiles = candidateFiles.size + optimizeStats.totalFilesSkipped = optimizeStats.totalConsideredFiles - removedFiles.size + optimizeStats.totalClusterParallelism = sparkSession.sparkContext.defaultParallelism + + if (isMultiDimClustering) { + val inputFileStats = + ZOrderFileStats(removedFiles.size, removedFiles.map(_.size.getOrElse(0L)).sum) + optimizeStats.zOrderStats = Some(ZOrderStats( + strategyName = "all", // means process all files in a partition + inputCubeFiles = ZOrderFileStats(0, 0), + inputOtherFiles = inputFileStats, + inputNumCubes = 0, + mergedFiles = inputFileStats, + // There will one z-cube for each partition + numOutputCubes = optimizeStats.numPartitionsOptimized)) + } + + return Seq(Row(txn.deltaLog.dataPath.toString, optimizeStats.toOptimizeMetrics)) + } + } + + /** + * Utility methods to group files into bins for optimize. + * + * @param partitionsToCompact List of files to compact group by partition. + * Partition is defined by the partition values (partCol -> partValue) + * @param maxTargetFileSize Max size (in bytes) of the compaction output file. + * @return Sequence of bins. Each bin contains one or more files from the same + * partition and targeted for one output file. + */ + private def groupFilesIntoBins( + partitionsToCompact: Seq[(Map[String, String], Seq[AddFile])], + maxTargetFileSize: Long): Seq[(Map[String, String], Seq[AddFile])] = { + + partitionsToCompact.flatMap { + case (partition, files) => + val bins = new ArrayBuffer[Seq[AddFile]]() + + val currentBin = new ArrayBuffer[AddFile]() + var currentBinSize = 0L + + files.sortBy(_.size).foreach { file => + // Generally, a bin is a group of existing files, whose total size does not exceed the + // desired maxFileSize. They will be coalesced into a single output file. + // However, if isMultiDimClustering = true, all files in a partition will be read by the + // same job, the data will be range-partitioned and numFiles = totalFileSize / maxFileSize + // will be produced. See below. + if (file.size + currentBinSize > maxTargetFileSize && !isMultiDimClustering) { + bins += currentBin.toVector + currentBin.clear() + currentBin += file + currentBinSize = file.size + } else { + currentBin += file + currentBinSize += file.size + } + } + + if (currentBin.nonEmpty) { + bins += currentBin.toVector + } + + bins.map(b => (partition, b)) + // select bins that have at least two files or in case of multi-dim clustering + // select all bins + .filter(_._2.size > 1 || isMultiDimClustering) + } + } + + /** + * Utility method to run a Spark job to compact the files in given bin + * + * @param txn [[OptimisticTransaction]] instance in use to commit the changes to DeltaLog. + * @param partition Partition values of the partition that files in [[bin]] belongs to. + * @param bin List of files to compact into one large file. + * @param maxFileSize Targeted output file size in bytes + */ + private def runOptimizeBinJob( + txn: OptimisticTransaction, + partition: Map[String, String], + bin: Seq[AddFile], + maxFileSize: Long): Seq[FileAction] = { + val baseTablePath = txn.deltaLog.dataPath + + val input = txn.deltaLog.createDataFrame(txn.snapshot, bin, actionTypeOpt = Some("Optimize")) + val repartitionDF = if (isMultiDimClustering) { + val totalSize = bin.map(_.size).sum + val approxNumFiles = Math.max(1, totalSize / maxFileSize).toInt + MultiDimClustering.cluster( + input, + approxNumFiles, + zOrderByColumns) + } else { + val useRepartition = sparkSession.sessionState.conf.getConf( + DeltaSQLConf.DELTA_OPTIMIZE_REPARTITION_ENABLED) + if (useRepartition) { + input.repartition(numPartitions = 1) + } else { + input.coalesce(numPartitions = 1) + } + } + + val partitionDesc = partition.toSeq.map(entry => entry._1 + "=" + entry._2).mkString(",") + + val partitionName = if (partition.isEmpty) "" else s" in partition ($partitionDesc)" + val description = s"$baseTablePath
Optimizing ${bin.size} files" + partitionName + sparkSession.sparkContext.setJobGroup( + sparkSession.sparkContext.getLocalProperty(SPARK_JOB_GROUP_ID), + description) + + val addFiles = txn.writeFiles(repartitionDF).collect { + case a: AddFile => + a.copy(dataChange = false) + case other => + throw new IllegalStateException( + s"Unexpected action $other with type ${other.getClass}. File compaction job output" + + s"should only have AddFiles") + } + val removeFiles = bin.map(f => f.removeWithTimestamp(operationTimestamp, dataChange = false)) + val updates = addFiles ++ removeFiles + updates + } + + private type PartitionedBin = (Map[String, String], Seq[AddFile]) + + private trait GpuOptimizeType { + def minNumFiles: Long + + def maxFileSize: Long = + sparkSession.sessionState.conf.getConf(DeltaSQLConf.DELTA_OPTIMIZE_MAX_FILE_SIZE) + + def targetFiles: (Seq[AddFile], Seq[AddFile]) + + def targetBins(jobs: Seq[PartitionedBin]): Seq[PartitionedBin] = jobs + } + + private case class GpuCompaction() extends GpuOptimizeType { + def minNumFiles: Long = 2 + + def targetFiles: (Seq[AddFile], Seq[AddFile]) = { + val minFileSize = sparkSession.sessionState.conf.getConf( + DeltaSQLConf.DELTA_OPTIMIZE_MIN_FILE_SIZE) + require(minFileSize > 0, "minFileSize must be > 0") + val candidateFiles = txn.filterFiles(partitionPredicate) + val filesToProcess = candidateFiles.filter(_.size < minFileSize) + (candidateFiles, filesToProcess) + } + } + + private case class GpuMultiDimOrdering() extends GpuOptimizeType { + def minNumFiles: Long = 1 + + def targetFiles: (Seq[AddFile], Seq[AddFile]) = { + // select all files in case of multi-dimensional clustering + val candidateFiles = txn.filterFiles(partitionPredicate) + (candidateFiles, candidateFiles) + } + } + + private case class GpuAutoCompaction() extends GpuOptimizeType { + def minNumFiles: Long = { + val minNumFiles = + sparkSession.sessionState.conf.getConf(DeltaSQLConf.DELTA_AUTO_COMPACT_MIN_NUM_FILES) + require(minNumFiles > 0, "minNumFiles must be > 0") + minNumFiles + } + + override def maxFileSize: Long = + sparkSession.sessionState.conf.getConf(DeltaSQLConf.DELTA_AUTO_COMPACT_MAX_FILE_SIZE) + .getOrElse(128 * 1024 * 1024) + + override def targetFiles: (Seq[AddFile], Seq[AddFile]) = { + val autoCompactTarget = + sparkSession.sessionState.conf.getConf(RapidsDeltaSQLConf.AUTO_COMPACT_TARGET) + // Filter the candidate files according to autoCompact.target config. + lazy val addedFiles = prevCommitActions.collect { case a: AddFile => a } + val candidateFiles = autoCompactTarget match { + case "table" => + txn.filterFiles() + case "commit" => + addedFiles + case "partition" => + val eligiblePartitions = addedFiles.map(_.partitionValues).toSet + txn.filterFiles().filter(f => eligiblePartitions.contains(f.partitionValues)) + case _ => + logError(s"Invalid config for autoCompact.target: $autoCompactTarget. " + + s"Falling back to the default value 'table'.") + txn.filterFiles() + } + val filesToProcess = candidateFiles.filter(_.size < maxFileSize) + (candidateFiles, filesToProcess) + } + + override def targetBins(jobs: Seq[PartitionedBin]): Seq[PartitionedBin] = { + var acc = 0L + val maxCompactBytes = + sparkSession.sessionState.conf.getConf(RapidsDeltaSQLConf.AUTO_COMPACT_MAX_COMPACT_BYTES) + // bins with more files are prior to less files. + jobs + .sortBy { case (_, filesInBin) => -filesInBin.length } + .takeWhile { case (_, filesInBin) => + acc += filesInBin.map(_.size).sum + acc <= maxCompactBytes + } + } + } + + private object GpuOptimizeType { + + def apply(isMultiDimClustering: Boolean, isAutoCompact: Boolean): GpuOptimizeType = { + if (isMultiDimClustering) { + GpuMultiDimOrdering() + } else if (isAutoCompact) { + GpuAutoCompaction() + } else { + GpuCompaction() + } + } + } + + /** + * Attempts to commit the given actions to the log. In the case of a concurrent update, + * the given function will be invoked with a new transaction to allow custom conflict + * detection logic to indicate it is safe to try again, by returning `true`. + * + * This function will continue to try to commit to the log as long as `f` returns `true`, + * otherwise throws a subclass of [[ConcurrentModificationException]]. + */ + @tailrec + private def commitAndRetry( + txn: OptimisticTransaction, + optimizeOperation: Operation, + actions: Seq[Action], + metrics: Map[String, SQLMetric])(f: OptimisticTransaction => Boolean) + : Unit = { + try { + txn.registerSQLMetrics(sparkSession, metrics) + txn.commit(actions, optimizeOperation) + } catch { + case e: ConcurrentModificationException => + val newTxn = txn.deltaLog.startTransaction() + if (f(newTxn)) { + logInfo("Retrying commit after checking for semantic conflicts with concurrent updates.") + commitAndRetry(newTxn, optimizeOperation, actions, metrics)(f) + } else { + logWarning("Semantic conflicts detected. Aborting operation.") + throw e + } + } + } + + /** Create a map of SQL metrics for adding to the commit history. */ + private def createMetrics( + sparkContext: SparkContext, + addedFiles: Seq[AddFile], + removedFiles: Seq[RemoveFile]): Map[String, SQLMetric] = { + + def setAndReturnMetric(description: String, value: Long) = { + val metric = createMetric(sparkContext, description) + metric.set(value) + metric + } + + def totalSize(actions: Seq[FileAction]): Long = { + var totalSize = 0L + actions.foreach { file => + val fileSize = file match { + case addFile: AddFile => addFile.size + case removeFile: RemoveFile => removeFile.size.getOrElse(0L) + case default => + throw new IllegalArgumentException(s"Unknown FileAction type: ${default.getClass}") + } + totalSize += fileSize + } + totalSize + } + + val sizeStats = FileSizeStatsWithHistogram.create(addedFiles.map(_.size).sorted) + Map[String, SQLMetric]( + "minFileSize" -> setAndReturnMetric("minimum file size", sizeStats.get.min), + "p25FileSize" -> setAndReturnMetric("25th percentile file size", sizeStats.get.p25), + "p50FileSize" -> setAndReturnMetric("50th percentile file size", sizeStats.get.p50), + "p75FileSize" -> setAndReturnMetric("75th percentile file size", sizeStats.get.p75), + "maxFileSize" -> setAndReturnMetric("maximum file size", sizeStats.get.max), + "numAddedFiles" -> setAndReturnMetric("total number of files added.", addedFiles.size), + "numRemovedFiles" -> setAndReturnMetric("total number of files removed.", removedFiles.size), + "numAddedBytes" -> setAndReturnMetric("total number of bytes added", totalSize(addedFiles)), + "numRemovedBytes" -> + setAndReturnMetric("total number of bytes removed", totalSize(removedFiles))) + } +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuUpdateCommand.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuUpdateCommand.scala new file mode 100644 index 00000000000..531e34e9f3c --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuUpdateCommand.scala @@ -0,0 +1,276 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from UpdateCommand.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import com.databricks.sql.transaction.tahoe.{DeltaLog, DeltaOperations, DeltaTableUtils, DeltaUDF, OptimisticTransaction} +import com.databricks.sql.transaction.tahoe.DeltaCommitTag._ +import com.databricks.sql.transaction.tahoe.RowTracking +import com.databricks.sql.transaction.tahoe.actions.{AddCDCFile, AddFile, FileAction} +import com.databricks.sql.transaction.tahoe.commands.{DeltaCommand, DMLUtils, UpdateCommand, UpdateMetric} +import com.databricks.sql.transaction.tahoe.files.{TahoeBatchFileIndex, TahoeFileIndex} +import com.nvidia.spark.rapids.delta.GpuDeltaMetricUpdateUDF +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{Column, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetrics.{createMetric, createTimingMetric} +import org.apache.spark.sql.functions.input_file_name +import org.apache.spark.sql.types.LongType + +case class GpuUpdateCommand( + gpuDeltaLog: GpuDeltaLog, + tahoeFileIndex: TahoeFileIndex, + target: LogicalPlan, + updateExpressions: Seq[Expression], + condition: Option[Expression]) + extends LeafRunnableCommand with DeltaCommand { + + override val output: Seq[Attribute] = { + Seq(AttributeReference("num_affected_rows", LongType)()) + } + + override def innerChildren: Seq[QueryPlan[_]] = Seq(target) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + + override lazy val metrics = Map[String, SQLMetric]( + "numAddedFiles" -> createMetric(sc, "number of files added."), + "numRemovedFiles" -> createMetric(sc, "number of files removed."), + "numUpdatedRows" -> createMetric(sc, "number of rows updated."), + "numCopiedRows" -> createMetric(sc, "number of rows copied."), + "executionTimeMs" -> + createTimingMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createTimingMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createTimingMetric(sc, "time taken to rewrite the matched files"), + "numAddedChangeFiles" -> createMetric(sc, "number of change data capture files generated"), + "changeFileBytes" -> createMetric(sc, "total size of change data capture files generated"), + "numTouchedRows" -> createMetric(sc, "number of rows touched (copied + updated)"), + "numDeletionVectorsAdded" -> createMetric(sc, "number of deletion vectors added."), + "numDeletionVectorsRemoved" -> createMetric(sc, "number of deletion vectors removed.") + ) + + final override def run(sparkSession: SparkSession): Seq[Row] = { + recordDeltaOperation(tahoeFileIndex.deltaLog, "delta.dml.update") { + val deltaLog = tahoeFileIndex.deltaLog + gpuDeltaLog.withNewTransaction { txn => + DeltaLog.assertRemovable(txn.snapshot) + performUpdate(sparkSession, deltaLog, txn) + } + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to + // this data source relation. + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target) + } + Seq(Row(metrics("numUpdatedRows").value)) + } + + private def performUpdate( + sparkSession: SparkSession, deltaLog: DeltaLog, txn: OptimisticTransaction): Unit = { + import com.databricks.sql.transaction.tahoe.implicits._ + + var numTouchedFiles: Long = 0 + var numRewrittenFiles: Long = 0 + var numAddedChangeFiles: Long = 0 + var changeFileBytes: Long = 0 + var scanTimeMs: Long = 0 + var rewriteTimeMs: Long = 0 + + val startTime = System.nanoTime() + val numFilesTotal = txn.snapshot.numOfFiles + + val updateCondition = condition.getOrElse(Literal.TrueLiteral) + val (metadataPredicates, dataPredicates) = + DeltaTableUtils.splitMetadataAndDataPredicates( + updateCondition, txn.metadata.partitionColumns, sparkSession) + val candidateFiles = txn.filterFiles(metadataPredicates ++ dataPredicates) + val nameToAddFile = generateCandidateFileMap(deltaLog.dataPath, candidateFiles) + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + + val filesToRewrite: Seq[AddFile] = if (candidateFiles.isEmpty) { + // Case 1: Do nothing if no row qualifies the partition predicates + // that are part of Update condition + Nil + } else if (dataPredicates.isEmpty) { + // Case 2: Update all the rows from the files that are in the specified partitions + // when the data filter is empty + candidateFiles + } else { + // Case 3: Find all the affected files using the user-specified condition + val fileIndex = new TahoeBatchFileIndex( + sparkSession, "update", candidateFiles, deltaLog, tahoeFileIndex.path, txn.snapshot) + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex) + val data = Dataset.ofRows(sparkSession, newTarget) + val updatedRowCount = metrics("numUpdatedRows") + val updatedRowUdf = DeltaUDF.boolean { + new GpuDeltaMetricUpdateUDF(updatedRowCount) + }.asNondeterministic() + val pathsToRewrite = + withStatusCode("DELTA", UpdateCommand.FINDING_TOUCHED_FILES_MSG) { + data.filter(new Column(updateCondition)) + .select(input_file_name()) + .filter(updatedRowUdf()) + .distinct() + .as[String] + .collect() + } + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + + pathsToRewrite.map(getTouchedFile(deltaLog.dataPath, _, nameToAddFile)).toSeq + } + + numTouchedFiles = filesToRewrite.length + + val newActions = if (filesToRewrite.isEmpty) { + // Do nothing if no row qualifies the UPDATE condition + Nil + } else { + // Generate the new files containing the updated values + withStatusCode("DELTA", UpdateCommand.rewritingFilesMsg(filesToRewrite.size)) { + rewriteFiles(sparkSession, txn, tahoeFileIndex.path, + filesToRewrite.map(_.path), nameToAddFile, updateCondition) + } + } + + rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs + + val (changeActions, addActions) = newActions.partition(_.isInstanceOf[AddCDCFile]) + numRewrittenFiles = addActions.size + numAddedChangeFiles = changeActions.size + changeFileBytes = changeActions.collect { case f: AddCDCFile => f.size }.sum + + val totalActions = if (filesToRewrite.isEmpty) { + // Do nothing if no row qualifies the UPDATE condition + Nil + } else { + // Delete the old files and return those delete actions along with the new AddFile actions for + // files containing the updated values + val operationTimestamp = System.currentTimeMillis() + val deleteActions = filesToRewrite.map(_.removeWithTimestamp(operationTimestamp)) + + deleteActions ++ newActions + } + + if (totalActions.nonEmpty) { + metrics("numAddedFiles").set(numRewrittenFiles) + metrics("numAddedChangeFiles").set(numAddedChangeFiles) + metrics("changeFileBytes").set(changeFileBytes) + metrics("numRemovedFiles").set(numTouchedFiles) + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + metrics("scanTimeMs").set(scanTimeMs) + metrics("rewriteTimeMs").set(rewriteTimeMs) + // In the case where the numUpdatedRows is not captured, we can siphon out the metrics from + // the BasicWriteStatsTracker. This is for case 2 where the update condition contains only + // metadata predicates and so the entire partition is re-written. + val outputRows = txn.getMetric("numOutputRows").map(_.value).getOrElse(-1L) + if (metrics("numUpdatedRows").value == 0 && outputRows != 0 && + metrics("numCopiedRows").value == 0) { + // We know that numTouchedRows = numCopiedRows + numUpdatedRows. + // Since an entire partition was re-written, no rows were copied. + // So numTouchedRows == numUpdateRows + metrics("numUpdatedRows").set(metrics("numTouchedRows").value) + } else { + // This is for case 3 where the update condition contains both metadata and data predicates + // so relevant files will have some rows updated and some rows copied. We don't need to + // consider case 1 here, where no files match the update condition, as we know that + // `totalActions` is empty. + metrics("numCopiedRows").set( + metrics("numTouchedRows").value - metrics("numUpdatedRows").value) + } + metrics("numDeletionVectorsAdded").set(0) + metrics("numDeletionVectorsRemoved").set(0) + txn.registerSQLMetrics(sparkSession, metrics) + val tags = DMLUtils.TaggedCommitData.EMPTY + .withTag(PreservedRowTrackingTag, RowTracking.isEnabled(txn.protocol, txn.metadata)) + .withTag(NoRowsCopiedTag, metrics("numCopiedRows").value == 0) + txn.commitIfNeeded(totalActions, DeltaOperations.Update(condition), tags.stringTags) + // This is needed to make the SQL metrics visible in the Spark UI + val executionId = sparkSession.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkSession.sparkContext, executionId, metrics.values.toSeq) + } + + recordDeltaEvent( + deltaLog, + "delta.dml.update.stats", + data = UpdateMetric( + condition = condition.map(_.sql).getOrElse("true"), + numFilesTotal, + numTouchedFiles, + numRewrittenFiles, + numAddedChangeFiles, + changeFileBytes, + scanTimeMs, + rewriteTimeMs) + ) + } + + /** + * Scan all the affected files and write out the updated files. + * + * When CDF is enabled, includes the generation of CDC preimage and postimage columns for + * changed rows. + * + * @return the list of [[AddFile]]s and [[AddCDCFile]]s that have been written. + */ + private def rewriteFiles( + spark: SparkSession, + txn: OptimisticTransaction, + rootPath: Path, + inputLeafFiles: Seq[String], + nameToAddFileMap: Map[String, AddFile], + condition: Expression): Seq[FileAction] = { + // Containing the map from the relative file path to AddFile + val baseRelation = buildBaseRelation( + spark, txn, "update", rootPath, inputLeafFiles, nameToAddFileMap) + val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location) + val targetDf = Dataset.ofRows(spark, newTarget) + + // Number of total rows that we have seen, i.e. are either copying or updating (sum of both). + // This will be used later, along with numUpdatedRows, to determine numCopiedRows. + val numTouchedRows = metrics("numTouchedRows") + val numTouchedRowsUdf = DeltaUDF.boolean { + new GpuDeltaMetricUpdateUDF(numTouchedRows) + }.asNondeterministic() + + val updatedDataFrame = UpdateCommand.withUpdatedColumns( + target.output, + updateExpressions, + condition, + targetDf + .filter(numTouchedRowsUdf()) + .withColumn(UpdateCommand.CONDITION_COLUMN_NAME, new Column(condition)), + UpdateCommand.shouldOutputCdc(txn)) + + txn.writeFiles(updatedDataFrame) + } +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala new file mode 100644 index 00000000000..32b7656b45b --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProbe.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.delta + +/** + * Implements the Delta Probe interface for probing the Delta Lake provider on Databricks. + * @note This is instantiated via reflection from ShimLoader. + */ +class DeltaProbeImpl extends DeltaProbe { + // Delta Lake is built-in for Databricks instances, so no probing is necessary. + override def getDeltaProvider: DeltaProvider = DeltaSpark341DBProvider +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark341DBProvider.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark341DBProvider.scala new file mode 100644 index 00000000000..cd204fa0440 --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/DeltaSpark341DBProvider.scala @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta + +import com.databricks.sql.transaction.tahoe.rapids.GpuDeltaCatalog +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec} + +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} + +object DeltaSpark341DBProvider extends DatabricksDeltaProviderBase { + + override def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + GpuAtomicCreateTableAsSelectExec( + cpuExec.output, + new GpuDeltaCatalog(cpuExec.catalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.query, + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.ifNotExists) + } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + GpuAtomicReplaceTableAsSelectExec( + cpuExec.output, + new GpuDeltaCatalog(cpuExec.catalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.query, + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala new file mode 100644 index 00000000000..969d005b573 --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta + +import com.databricks.sql.transaction.tahoe.{DeltaColumnMappingMode, DeltaParquetFileFormat, IdMapping} +import com.databricks.sql.transaction.tahoe.DeltaParquetFileFormat.IS_ROW_DELETED_COLUMN_NAME +import com.nvidia.spark.rapids.SparkPlanMeta +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +case class GpuDeltaParquetFileFormat( + override val columnMappingMode: DeltaColumnMappingMode, + override val referenceSchema: StructType, + isSplittable: Boolean) extends GpuDeltaParquetFileFormatBase { + + if (columnMappingMode == IdMapping) { + val requiredReadConf = SQLConf.PARQUET_FIELD_ID_READ_ENABLED + require(SparkSession.getActiveSession.exists(_.sessionState.conf.getConf(requiredReadConf)), + s"${requiredReadConf.key} must be enabled to support Delta id column mapping mode") + val requiredWriteConf = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED + require(SparkSession.getActiveSession.exists(_.sessionState.conf.getConf(requiredWriteConf)), + s"${requiredWriteConf.key} must be enabled to support Delta id column mapping mode") + } + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = isSplittable +} + +object GpuDeltaParquetFileFormat { + def tagSupportForGpuFileSourceScan(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { + val format = meta.wrapped.relation.fileFormat.asInstanceOf[DeltaParquetFileFormat] + val requiredSchema = meta.wrapped.requiredSchema + if (requiredSchema.exists(_.name == IS_ROW_DELETED_COLUMN_NAME)) { + meta.willNotWorkOnGpu( + s"reading metadata column $IS_ROW_DELETED_COLUMN_NAME is not supported") + } + if (format.hasDeletionVectorMap()) { + meta.willNotWorkOnGpu("deletion vectors are not supported") + } + } + + def convertToGpu(fmt: DeltaParquetFileFormat): GpuDeltaParquetFileFormat = { + GpuDeltaParquetFileFormat(fmt.columnMappingMode, fmt.referenceSchema, fmt.isSplittable) + } +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeleteCommandMetaShim.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeleteCommandMetaShim.scala new file mode 100644 index 00000000000..96863c71ad0 --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeleteCommandMetaShim.scala @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta.shims + +import com.databricks.sql.transaction.tahoe.commands.DeletionVectorUtils +import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf +import com.nvidia.spark.rapids.delta.{DeleteCommandEdgeMeta, DeleteCommandMeta} + +object DeleteCommandMetaShim { + def tagForGpu(meta: DeleteCommandMeta): Unit = { + val dvFeatureEnabled = DeletionVectorUtils.deletionVectorsWritable( + meta.deleteCmd.deltaLog.unsafeVolatileSnapshot) + if (dvFeatureEnabled && meta.deleteCmd.conf.getConf( + DeltaSQLConf.DELETE_USE_PERSISTENT_DELETION_VECTORS)) { + // https://github.com/NVIDIA/spark-rapids/issues/8654 + meta.willNotWorkOnGpu("Deletion vector writes are not supported on GPU") + } + } + + def tagForGpu(meta: DeleteCommandEdgeMeta): Unit = { + val dvFeatureEnabled = DeletionVectorUtils.deletionVectorsWritable( + meta.deleteCmd.deltaLog.unsafeVolatileSnapshot) + if (dvFeatureEnabled && meta.deleteCmd.conf.getConf( + DeltaSQLConf.DELETE_USE_PERSISTENT_DELETION_VECTORS)) { + // https://github.com/NVIDIA/spark-rapids/issues/8654 + meta.willNotWorkOnGpu("Deletion vector writes are not supported on GPU") + } + } +} \ No newline at end of file diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeltaLogShim.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeltaLogShim.scala new file mode 100644 index 00000000000..0bd231e05a6 --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/DeltaLogShim.scala @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta.shims + +import com.databricks.sql.transaction.tahoe.DeltaLog +import com.databricks.sql.transaction.tahoe.actions.Metadata + +import org.apache.spark.sql.execution.datasources.FileFormat + +object DeltaLogShim { + def fileFormat(deltaLog: DeltaLog): FileFormat = { + deltaLog.fileFormat(deltaLog.unsafeVolatileSnapshot.protocol, + deltaLog.unsafeVolatileSnapshot.metadata) + } + def getMetadata(deltaLog: DeltaLog): Metadata = { + deltaLog.unsafeVolatileSnapshot.metadata + } +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/InvariantViolationExceptionShim.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/InvariantViolationExceptionShim.scala new file mode 100644 index 00000000000..58714a91fd4 --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/InvariantViolationExceptionShim.scala @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.delta.shims + +import com.databricks.sql.transaction.tahoe.constraints.Constraints._ +import com.databricks.sql.transaction.tahoe.schema.DeltaInvariantViolationException +import com.databricks.sql.transaction.tahoe.schema.InvariantViolationException + +object InvariantViolationExceptionShim { + def apply(c: Check, m: Map[String, Any]): InvariantViolationException = { + DeltaInvariantViolationException(c, m) + } + + def apply(c: NotNull): InvariantViolationException = { + DeltaInvariantViolationException(c) + } +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala new file mode 100644 index 00000000000..8e13a9e4b5a --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta.shims + +import com.databricks.sql.transaction.tahoe.commands.{MergeIntoCommand, MergeIntoCommandEdge} +import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaLog, GpuMergeIntoCommand} +import com.nvidia.spark.rapids.RapidsConf +import com.nvidia.spark.rapids.delta.{MergeIntoCommandEdgeMeta, MergeIntoCommandMeta} + +import org.apache.spark.sql.execution.command.RunnableCommand + +object MergeIntoCommandMetaShim { + def tagForGpu(meta: MergeIntoCommandMeta, mergeCmd: MergeIntoCommand): Unit = { + // see https://github.com/NVIDIA/spark-rapids/issues/8415 for more information + if (mergeCmd.notMatchedBySourceClauses.nonEmpty) { + meta.willNotWorkOnGpu("notMatchedBySourceClauses not supported on GPU") + } + } + + def tagForGpu(meta: MergeIntoCommandEdgeMeta, mergeCmd: MergeIntoCommandEdge): Unit = { + // see https://github.com/NVIDIA/spark-rapids/issues/8415 for more information + if (mergeCmd.notMatchedBySourceClauses.nonEmpty) { + meta.willNotWorkOnGpu("notMatchedBySourceClauses not supported on GPU") + } + } + + def convertToGpu(mergeCmd: MergeIntoCommand, conf: RapidsConf): RunnableCommand = { + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } + + def convertToGpu(mergeCmd: MergeIntoCommandEdge, conf: RapidsConf): RunnableCommand = { + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } +} diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala new file mode 100644 index 00000000000..8f5196d7c66 --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.delta.shims + +import com.databricks.sql.transaction.tahoe.stats.DeltaStatistics + +trait ShimUsesMetadataFields { + val NUM_RECORDS = DeltaStatistics.NUM_RECORDS + val MIN = DeltaStatistics.MIN + val MAX = DeltaStatistics.MAX + val NULL_COUNT = DeltaStatistics.NULL_COUNT +} \ No newline at end of file diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/ShimDeltaUDF.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/ShimDeltaUDF.scala new file mode 100644 index 00000000000..fd9052d9691 --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/ShimDeltaUDF.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta.shims + +import com.databricks.sql.transaction.tahoe.DeltaUDF + +import org.apache.spark.sql.expressions.UserDefinedFunction + +object ShimDeltaUDF { + def stringStringUdf(f: String => String): UserDefinedFunction = DeltaUDF.stringFromString(f) +} diff --git a/docs/compatibility.md b/docs/compatibility.md index 8d18d8b57ca..ac90d309fe1 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -346,8 +346,6 @@ with Spark, and can be enabled by setting `spark.rapids.sql.expression.StructsTo Known issues are: -- String escaping is not implemented, so strings containing quotes, newlines, and other special characters will - not produce valid JSON - There is no support for timestamp types - There can be rounding differences when formatting floating-point numbers as strings. For example, Spark may produce `-4.1243574E26` but the GPU may produce `-4.124357351E26`. diff --git a/integration_tests/src/main/python/delta_lake_utils.py b/integration_tests/src/main/python/delta_lake_utils.py index 61ed3d08a20..9a5545a6e3a 100644 --- a/integration_tests/src/main/python/delta_lake_utils.py +++ b/integration_tests/src/main/python/delta_lake_utils.py @@ -33,10 +33,10 @@ delta_writes_enabled_conf = {"spark.rapids.sql.format.delta.write.enabled": "true"} -delta_write_fallback_allow = "ExecutedCommandExec,DataWritingCommandExec" if is_databricks122_or_later() else "ExecutedCommandExec" +delta_write_fallback_allow = "ExecutedCommandExec,DataWritingCommandExec,WriteFilesExec" if is_databricks122_or_later() else "ExecutedCommandExec" delta_write_fallback_check = "DataWritingCommandExec" if is_databricks122_or_later() else "ExecutedCommandExec" -delta_optimized_write_fallback_allow = "ExecutedCommandExec,DataWritingCommandExec,DeltaOptimizedWriterExec" if is_databricks122_or_later() else "ExecutedCommandExec" +delta_optimized_write_fallback_allow = "ExecutedCommandExec,DataWritingCommandExec,DeltaOptimizedWriterExec,WriteFilesExec" if is_databricks122_or_later() else "ExecutedCommandExec" def _fixup_operation_metrics(opm): """Update the specified operationMetrics node to facilitate log comparisons""" diff --git a/integration_tests/src/main/python/delta_lake_write_test.py b/integration_tests/src/main/python/delta_lake_write_test.py index 41f782a0005..f6158624dbe 100644 --- a/integration_tests/src/main/python/delta_lake_write_test.py +++ b/integration_tests/src/main/python/delta_lake_write_test.py @@ -178,6 +178,7 @@ def do_write(spark, path): @delta_lake @ignore_order(local=True) @pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") +@pytest.mark.xfail(condition=is_spark_340_or_later() and is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/9676") def test_delta_atomic_create_table_as_select(spark_tmp_table_factory, spark_tmp_path): _atomic_write_table_as_select(delta_write_gens, spark_tmp_table_factory, spark_tmp_path, overwrite=False) @@ -185,6 +186,7 @@ def test_delta_atomic_create_table_as_select(spark_tmp_table_factory, spark_tmp_ @delta_lake @ignore_order(local=True) @pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") +@pytest.mark.xfail(condition=is_spark_340_or_later() and is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/9676") def test_delta_atomic_replace_table_as_select(spark_tmp_table_factory, spark_tmp_path): _atomic_write_table_as_select(delta_write_gens, spark_tmp_table_factory, spark_tmp_path, overwrite=True) diff --git a/integration_tests/src/main/python/json_test.py b/integration_tests/src/main/python/json_test.py index 043349ce54e..5b7cee85440 100644 --- a/integration_tests/src/main/python/json_test.py +++ b/integration_tests/src/main/python/json_test.py @@ -614,8 +614,14 @@ def test_read_case_col_name(spark_tmp_path, v1_enabled_list, col_name): pytest.param(double_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/9350')), pytest.param(date_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/9515')), pytest.param(timestamp_gen, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/9515')), - StringGen('[A-Za-z0-9]{0,10}', nullable=True), - pytest.param(StringGen(nullable=True), marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/9514')), + StringGen('[A-Za-z0-9\r\n\'"\\\\]{0,10}', nullable=True) \ + .with_special_case('\u1f600') \ + .with_special_case('"a"') \ + .with_special_case('\\"a\\"') \ + .with_special_case('\'a\'') \ + .with_special_case('\\\'a\\\''), + pytest.param(StringGen('\u001a', nullable=True), marks=pytest.mark.xfail( + reason='https://github.com/NVIDIA/spark-rapids/issues/9705')) ], ids=idfn) @pytest.mark.parametrize('ignore_null_fields', [True, False]) @pytest.mark.parametrize('pretty', [ diff --git a/scala2.13/delta-lake/delta-spark341db/pom.xml b/scala2.13/delta-lake/delta-spark341db/pom.xml new file mode 100644 index 00000000000..e8d7d0dd644 --- /dev/null +++ b/scala2.13/delta-lake/delta-spark341db/pom.xml @@ -0,0 +1,296 @@ + + + + 4.0.0 + + + com.nvidia + rapids-4-spark-jdk-profiles_2.13 + 23.12.0-SNAPSHOT + ../../jdk-profiles/pom.xml + + + rapids-4-spark-delta-spark341db_2.13 + RAPIDS Accelerator for Apache Spark Databricks 13.3 Delta Lake Support + Databricks 13.3 Delta Lake support for the RAPIDS Accelerator for Apache Spark + 23.12.0-SNAPSHOT + + + false + **/* + package + + + + + com.nvidia + rapids-4-spark-sql_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + provided + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-annotation_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-network-common_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-network-shuffle_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-launcher_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${spark.version} + provided + + + org.apache.avro + avro-mapred + ${spark.version} + provided + + + org.apache.avro + avro + ${spark.version} + provided + + + org.apache.hive + hive-exec + ${spark.version} + provided + + + org.apache.hive + hive-serde + ${spark.version} + provided + + + org.apache.spark + spark-hive_${scala.binary.version} + + + com.fasterxml.jackson.core + jackson-core + ${spark.version} + provided + + + com.fasterxml.jackson.core + jackson-annotations + ${spark.version} + provided + + + org.json4s + json4s-ast_${scala.binary.version} + ${spark.version} + provided + + + org.json4s + json4s-core_${scala.binary.version} + ${spark.version} + provided + + + org.apache.commons + commons-io + ${spark.version} + provided + + + org.scala-lang + scala-reflect + ${scala.version} + provided + + + org.apache.commons + commons-lang3 + ${spark.version} + provided + + + com.esotericsoftware.kryo + kryo-shaded-db + ${spark.version} + provided + + + org.apache.parquet + parquet-hadoop + ${spark.version} + provided + + + org.apache.parquet + parquet-common + ${spark.version} + provided + + + org.apache.parquet + parquet-column + ${spark.version} + provided + + + org.apache.parquet + parquet-format + ${spark.version} + provided + + + org.apache.arrow + arrow-memory + ${spark.version} + provided + + + org.apache.arrow + arrow-vector + ${spark.version} + provided + + + org.apache.hadoop + hadoop-client + ${hadoop.client.version} + provided + + + org.apache.orc + orc-core + ${spark.version} + provided + + + org.apache.orc + orc-shims + ${spark.version} + provided + + + org.apache.orc + orc-mapreduce + ${spark.version} + provided + + + org.apache.hive + hive-storage-api + ${spark.version} + provided + + + com.google.protobuf + protobuf-java + ${spark.version} + provided + + + org.apache.spark + spark-common-utils_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-sql-api_${scala.binary.version} + ${spark.version} + provided + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-common-sources + generate-sources + + add-source + + + + ${project.basedir}/../common/src/main/scala + ${project.basedir}/../common/src/main/databricks/scala + + + + + + + net.alchim31.maven + scala-maven-plugin + + + org.apache.rat + apache-rat-plugin + + + + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 31752d482c3..6634c946d47 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -899,7 +899,10 @@ object GpuCast { val numRows = input.getRowCount.toInt - /** Create a new column with quotes around the supplied string column */ + /** + * Create a new column with quotes around the supplied string column. Caller + * is responsible for closing `column`. + */ def addQuotes(column: ColumnVector, rowCount: Int): ColumnVector = { withResource(ArrayBuffer.empty[ColumnVector]) { columns => withResource(Scalar.fromString("\"")) { quote => @@ -922,7 +925,7 @@ object GpuCast { // keys must have quotes around them in JSON mode val strKey: ColumnVector = withResource(kvStructColumn.getChildColumnView(0)) { keyColumn => withResource(castToString(keyColumn, from.keyType, options)) { key => - addQuotes(key.incRefCount(), keyColumn.getRowCount.toInt) + addQuotes(key, keyColumn.getRowCount.toInt) } } // string values must have quotes around them in JSON mode, and null values need @@ -931,7 +934,7 @@ object GpuCast { withResource(kvStructColumn.getChildColumnView(1)) { valueColumn => val valueStr = if (valueColumn.getType == DType.STRING) { withResource(castToString(valueColumn, from.valueType, options)) { valueStr => - addQuotes(valueStr.incRefCount(), valueColumn.getRowCount.toInt) + addQuotes(valueStr, valueColumn.getRowCount.toInt) } } else { castToString(valueColumn, from.valueType, options) @@ -1136,7 +1139,7 @@ object GpuCast { attrValue => if (needsQuoting) { attrValues += quote.incRefCount() - attrValues += escapeJsonString(attrValue.incRefCount()) + attrValues += escapeJsonString(attrValue) attrValues += quote.incRefCount() withResource(Scalar.fromString("")) { emptyString => ColumnVector.stringConcatenate(emptyString, emptyString, attrValues.toArray) @@ -1199,10 +1202,17 @@ object GpuCast { } } + /** + * Escape quotes and newlines in a string column. Caller is responsible for closing `cv`. + */ private def escapeJsonString(cv: ColumnVector): ColumnVector = { - // this is a placeholder for implementing string escaping - // https://github.com/NVIDIA/spark-rapids/issues/9514 - cv + val chars = Seq("\r", "\n", "\\", "\"") + val escaped = chars.map(StringEscapeUtils.escapeJava) + withResource(ColumnVector.fromStrings(chars: _*)) { search => + withResource(ColumnVector.fromStrings(escaped: _*)) { replace => + cv.stringReplace(search, replace) + } + } } private[rapids] def castFloatingTypeToString(input: ColumnView): ColumnVector = { diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/PlanShimsImpl.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/PlanShimsImpl.scala index 05bcf9ca490..1ac7eeddf3b 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/PlanShimsImpl.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/PlanShimsImpl.scala @@ -32,7 +32,6 @@ {"spark": "333"} {"spark": "340"} {"spark": "341"} -{"spark": "341db"} {"spark": "350"} spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids.shims diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala index c5059046867..bbd43ee24b0 100644 --- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -68,7 +68,7 @@ case class GpuAtomicReplaceTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) override def supportsColumnar: Boolean = false - + override protected def run(): Seq[InternalRow] = { val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable if (catalog.tableExists(ident)) { diff --git a/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/PlanShimsImpl.scala b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/PlanShimsImpl.scala new file mode 100644 index 00000000000..2968081dff6 --- /dev/null +++ b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/PlanShimsImpl.scala @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "341db"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import com.nvidia.spark.rapids.{GpuAlias, PlanShims} + +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression} +import org.apache.spark.sql.execution.{CommandResultExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive.ResultQueryStageExec + +class PlanShimsImpl extends PlanShims { + def extractExecutedPlan(plan: SparkPlan): SparkPlan = plan match { + case p: CommandResultExec => p.commandPhysicalPlan + case q: ResultQueryStageExec => q.plan + case _ => plan + } + + def isAnsiCast(e: Expression): Boolean = AnsiCastShim.isAnsiCast(e) + + def isAnsiCastOptionallyAliased(e: Expression): Boolean = e match { + case Alias(e, _) => isAnsiCast(e) + case GpuAlias(e, _) => isAnsiCast(e) + case e => isAnsiCast(e) + } +} diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala new file mode 100644 index 00000000000..398bc8e187d --- /dev/null +++ b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "341db"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicCreateTableAsSelectExec. + * + * Physical plan node for v2 create table as select, when the catalog is determined to support + * staging table creation. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * The CTAS operation is atomic. The creation of the table is staged and the commit of the write + * should bundle the commitment of the metadata and the table contents in a single unit. If the + * write fails, the table is instructed to roll back all staged changes. + */ +case class GpuAtomicCreateTableAsSelectExec( + override val output: Seq[Attribute], + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + query: LogicalPlan, + tableSpec: TableSpec, + writeOptions: Map[String, String], + ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec with GpuExec { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + if (catalog.tableExists(ident)) { + if (ifNotExists) { + return Nil + } + + throw QueryCompilationErrors.tableAlreadyExistsError(ident) + } + val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable + val stagedTable = catalog.stageCreate( + ident, getV2Columns(schema, catalog.useNullableQuerySchema), + partitioning.toArray, properties.asJava) + + writeToTable(catalog, stagedTable, writeOptions, ident, query) + } + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala new file mode 100644 index 00000000000..d1380facb86 --- /dev/null +++ b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "341db"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicReplaceTableAsSelectExec. + * + * Physical plan node for v2 replace table as select when the catalog supports staging + * table replacement. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If the table exists, its contents and schema should be replaced with the schema and the contents + * of the query. This implementation is atomic. The table replacement is staged, and the commit + * operation at the end should perform the replacement of the table's metadata and contents. If the + * write fails, the table is instructed to roll back staged changes and any previously written table + * is left untouched. + */ +case class GpuAtomicReplaceTableAsSelectExec( + override val output: Seq[Attribute], + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + query: LogicalPlan, + tableSpec: TableSpec, + writeOptions: Map[String, String], + orCreate: Boolean, + invalidateCache: (TableCatalog, Table, Identifier) => Unit) + extends V2CreateTableAsSelectBaseExec with GpuExec { + + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + invalidateCache(catalog, table, ident) + } + val staged = if (orCreate) { + catalog.stageCreateOrReplace( + ident, schema, partitioning.toArray, properties.asJava) + } else if (catalog.tableExists(ident)) { + try { + catalog.stageReplace( + ident, schema, partitioning.toArray, properties.asJava) + } catch { + case e: NoSuchTableException => + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e)) + } + } else { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + writeToTable(catalog, staged, writeOptions, ident, query) + } + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +}