Fix Delta Lake atomic table operations on spark341db (#9729)
Signed-off-by: Jason Lowe <[email protected]>
jlowe authored Nov 17, 2023
1 parent 4f0e4fa commit 94b25db
Showing 3 changed files with 35 additions and 23 deletions.
First changed file (Scala source):

@@ -44,6 +44,12 @@ import org.apache.spark.sql.sources.InsertableRelation
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
+/** A trait used to identify Delta tables that are GPU-aware. */
+trait GpuDeltaSupportsWrite extends SupportsWrite
+
+/** A trait used to identify Delta V1Write that is GPU-aware */
+trait GpuDeltaV1Write extends V1Write
+
 trait GpuDeltaCatalogBase extends StagingTableCatalog {
   val spark: SparkSession = SparkSession.active
 
@@ -294,7 +300,7 @@ trait GpuDeltaCatalogBase extends StagingTableCatalog {
       val partitions: Array[Transform],
       override val properties: util.Map[String, String],
       operation: TableCreationModes.CreationMode
-  ) extends StagedTable with SupportsWrite {
+  ) extends StagedTable with GpuDeltaSupportsWrite {
 
     private var asSelectQuery: Option[DataFrame] = None
     private var writeOptions: Map[String, String] = Map.empty
@@ -358,7 +364,7 @@ trait GpuDeltaCatalogBase extends StagingTableCatalog {
    * WriteBuilder for creating a Delta table.
    */
   private class DeltaV1WriteBuilder extends WriteBuilder {
-    override def build(): V1Write = new V1Write {
+    override def build(): V1Write = new GpuDeltaV1Write {
       override def toInsertableRelation(): InsertableRelation = {
        new InsertableRelation {
          override def insert(data: DataFrame, overwrite: Boolean): Unit = {
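The two traits added above are pure markers: they declare no members, but they let downstream code recognize GPU-aware Delta tables and writes by type alone. The following is a minimal, self-contained Scala sketch of that pattern; the simplified SupportsWrite trait and the two example table classes are hypothetical stand-ins, not code from this commit.

// Sketch of the marker-trait pattern introduced above. SupportsWrite here is a
// simplified stand-in for the Spark DSv2 interface, and the two table classes
// are hypothetical examples.
object MarkerTraitSketch {
  trait SupportsWrite
  trait GpuDeltaSupportsWrite extends SupportsWrite // marker: table is GPU-aware

  class GpuStagedDeltaTable extends GpuDeltaSupportsWrite
  class CpuDeltaTable extends SupportsWrite

  // Mirrors the isAssignableFrom check the provider gains further below:
  // any class that mixes in the marker trait is recognized as GPU-aware.
  def isSupportedWrite(write: Class[_ <: SupportsWrite]): Boolean =
    classOf[GpuDeltaSupportsWrite].isAssignableFrom(write)

  def main(args: Array[String]): Unit = {
    println(isSupportedWrite(classOf[GpuStagedDeltaTable])) // true
    println(isSupportedWrite(classOf[CpuDeltaTable]))       // false
  }
}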
Second changed file (Scala source):

@@ -25,7 +25,7 @@ import com.databricks.sql.managedcatalog.UnityCatalogV2Proxy
 import com.databricks.sql.transaction.tahoe.{DeltaLog, DeltaOptions, DeltaParquetFileFormat}
 import com.databricks.sql.transaction.tahoe.catalog.{DeltaCatalog, DeltaTableV2}
 import com.databricks.sql.transaction.tahoe.commands.{DeleteCommand, DeleteCommandEdge, MergeIntoCommand, MergeIntoCommandEdge, UpdateCommand, UpdateCommandEdge, WriteIntoDelta}
-import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaLog, GpuWriteIntoDelta}
+import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaLog, GpuDeltaSupportsWrite, GpuDeltaV1Write, GpuWriteIntoDelta}
 import com.databricks.sql.transaction.tahoe.sources.{DeltaDataSource, DeltaSourceUtils}
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.delta.shims.DeltaLogShim
@@ -94,7 +94,7 @@ trait DatabricksDeltaProviderBase extends DeltaProviderImplBase {
   }
 
   override def isSupportedWrite(write: Class[_ <: SupportsWrite]): Boolean = {
-    write == classOf[DeltaTableV2]
+    write == classOf[DeltaTableV2] || classOf[GpuDeltaSupportsWrite].isAssignableFrom(write)
   }
 
   override def tagSupportForGpuFileSourceScan(meta: SparkPlanMeta[FileSourceScanExec]): Unit = {
@@ -219,29 +219,37 @@ trait DatabricksDeltaProviderBase extends DeltaProviderImplBase {
       meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " +
         s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
     }
-    val deltaTable = cpuExec.table.asInstanceOf[DeltaTableV2]
-    val tablePath = if (deltaTable.catalogTable.isDefined) {
-      new Path(deltaTable.catalogTable.get.location)
-    } else {
-      DeltaDataSource.parsePathIdentifier(cpuExec.session, deltaTable.path.toString,
-        deltaTable.options)._1
-    }
-    val deltaLog = DeltaLog.forTable(cpuExec.session, tablePath, deltaTable.options)
-    RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.plan.schema, Some(deltaLog),
-      deltaTable.options, cpuExec.session)
-    extractWriteV1Config(meta, deltaLog, cpuExec.write).foreach { writeConfig =>
-      meta.setCustomTaggingData(writeConfig)
-    }
+    cpuExec.write match {
+      case _: GpuDeltaV1Write => // write is already using GPU, nothing more to do
+      case write =>
+        val deltaTable = cpuExec.table.asInstanceOf[DeltaTableV2]
+        val tablePath = if (deltaTable.catalogTable.isDefined) {
+          new Path(deltaTable.catalogTable.get.location)
+        } else {
+          DeltaDataSource.parsePathIdentifier(cpuExec.session, deltaTable.path.toString,
+            deltaTable.options)._1
+        }
+        val deltaLog = DeltaLog.forTable(cpuExec.session, tablePath, deltaTable.options)
+        RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.plan.schema, Some(deltaLog),
+          deltaTable.options, cpuExec.session)
+        extractWriteV1Config(meta, deltaLog, write).foreach { writeConfig =>
+          meta.setCustomTaggingData(writeConfig)
+        }
+    }
   }
 
   override def convertToGpu(
       cpuExec: AppendDataExecV1,
       meta: AppendDataExecV1Meta): GpuExec = {
-    val writeConfig = meta.getCustomTaggingData match {
-      case Some(c: DeltaWriteV1Config) => c
-      case _ => throw new IllegalStateException("Missing Delta write config from tagging pass")
-    }
-    val gpuWrite = toGpuWrite(writeConfig, meta.conf)
+    val gpuWrite = cpuExec.write match {
+      case write: GpuDeltaV1Write => write
+      case _ =>
+        val writeConfig = meta.getCustomTaggingData match {
+          case Some(c: DeltaWriteV1Config) => c
+          case _ => throw new IllegalStateException("Missing Delta write config from tagging pass")
+        }
+        toGpuWrite(writeConfig, meta.conf)
+    }
     GpuAppendDataExecV1(cpuExec.table, cpuExec.plan, cpuExec.refreshCache, gpuWrite)
   }
 
@@ -280,7 +288,7 @@ trait DatabricksDeltaProviderBase extends DeltaProviderImplBase {
 
   private def toGpuWrite(
       writeConfig: DeltaWriteV1Config,
-      rapidsConf: RapidsConf): V1Write = new V1Write {
+      rapidsConf: RapidsConf): V1Write = new GpuDeltaV1Write {
     override def toInsertableRelation(): InsertableRelation = {
       new InsertableRelation {
         override def insert(data: DataFrame, overwrite: Boolean): Unit = {
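Taken together, the provider changes above implement one dispatch rule: if the write attached to the exec is already a GpuDeltaV1Write (for example, one built by the GPU catalog's DeltaV1WriteBuilder), tagging and conversion reuse it unchanged; otherwise the CPU write is tagged and replaced via toGpuWrite. The sketch below condenses that rule using simplified, hypothetical stand-in types rather than the real spark-rapids classes.

// Condensed sketch of the convertToGpu dispatch above. V1Write, the config
// class, and toGpuWrite are simplified stand-ins, not the real types.
object GpuWriteDispatchSketch {
  trait V1Write
  trait GpuDeltaV1Write extends V1Write

  final case class DeltaWriteV1Config(description: String)

  private def toGpuWrite(config: DeltaWriteV1Config): V1Write =
    new GpuDeltaV1Write {} // stands in for the real GPU-backed write construction

  // Reuse a write that is already GPU-aware; otherwise fall back to the
  // config captured during the tagging pass and convert it.
  def resolveGpuWrite(cpuWrite: V1Write, tagged: Option[DeltaWriteV1Config]): V1Write =
    cpuWrite match {
      case write: GpuDeltaV1Write => write
      case _ =>
        val config = tagged.getOrElse(
          throw new IllegalStateException("Missing Delta write config from tagging pass"))
        toGpuWrite(config)
    }
}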
Third changed file: integration_tests/src/main/python/delta_lake_write_test.py (2 changes: 0 additions and 2 deletions)
@@ -186,15 +186,13 @@ def do_write(spark, path):
 @delta_lake
 @ignore_order(local=True)
 @pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
-@pytest.mark.xfail(condition=is_spark_340_or_later() and is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/9676")
 def test_delta_atomic_create_table_as_select(spark_tmp_table_factory, spark_tmp_path):
     _atomic_write_table_as_select(delta_write_gens, spark_tmp_table_factory, spark_tmp_path, overwrite=False)
 
 @allow_non_gpu(*delta_meta_allow)
 @delta_lake
 @ignore_order(local=True)
 @pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
-@pytest.mark.xfail(condition=is_spark_340_or_later() and is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/9676")
 def test_delta_atomic_replace_table_as_select(spark_tmp_table_factory, spark_tmp_path):
     _atomic_write_table_as_select(delta_write_gens, spark_tmp_table_factory, spark_tmp_path, overwrite=True)
 
