Skip to content

Commit

Permalink
Fix Authorization Failure While Reading Tables From Unity Catalog [databricks] (NVIDIA#10756)
Browse files Browse the repository at this point in the history

* Use cached ThreadPoolExecutor

* Revert "Fix Multithreaded Readers working with Unity Catalog on Databricks [databricks] (NVIDIA#8296)"

* Signing off

Signed-off-by: Raza Jafri <[email protected]>

* Removed spark311 version of ReaderUtils.scala

---------

Signed-off-by: Raza Jafri <[email protected]>
  • Loading branch information
razajafri authored May 1, 2024
1 parent db4f44a commit 8403941
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids

import java.io.{File, IOException}
import java.net.{URI, URISyntaxException}
import java.util.concurrent.{Callable, ConcurrentLinkedQueue, ExecutorCompletionService, Future, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.{Callable, ConcurrentLinkedQueue, ExecutorCompletionService, Future, ThreadPoolExecutor, TimeUnit}

import scala.annotation.tailrec
import scala.collection.JavaConverters._
Expand Down Expand Up @@ -123,20 +123,11 @@ object MultiFileReaderThreadPool extends Logging {

private def initThreadPool(
maxThreads: Int,
keepAliveSeconds: Long = 60): ThreadPoolExecutor = synchronized {
keepAliveSeconds: Int = 60): ThreadPoolExecutor = synchronized {
if (threadPool.isEmpty) {
val threadFactory = new ThreadFactoryBuilder()
.setNameFormat(s"multithreaded file reader worker-%d")
.setDaemon(true)
.build()

val threadPoolExecutor = new ThreadPoolExecutor(
maxThreads, // corePoolSize: max number of threads to create before queuing the tasks
maxThreads, // maximumPoolSize: because we use LinkedBlockingDeque, this is not used
keepAliveSeconds,
TimeUnit.SECONDS,
new LinkedBlockingQueue[Runnable],
threadFactory)
val threadPoolExecutor =
TrampolineUtil.newDaemonCachedThreadPool("multithreaded file reader worker", maxThreads,
keepAliveSeconds)
threadPoolExecutor.allowCoreThreadTimeOut(true)
logDebug(s"Using $maxThreads for the multithreaded reader thread pool")
threadPool = Some(threadPoolExecutor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import com.nvidia.spark.rapids.RapidsConf.ParquetFooterReaderType
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.filecache.FileCache
import com.nvidia.spark.rapids.jni.{DateTimeRebase, ParquetFooter}
import com.nvidia.spark.rapids.shims.{ColumnDefaultValuesShims, GpuParquetCrypto, GpuTypeShims, ParquetLegacyNanoAsLongShims, ParquetSchemaClipShims, ParquetStringPredShims, ReaderUtils, ShimFilePartitionReaderFactory, SparkShimImpl}
import com.nvidia.spark.rapids.shims.{ColumnDefaultValuesShims, GpuParquetCrypto, GpuTypeShims, ParquetLegacyNanoAsLongShims, ParquetSchemaClipShims, ParquetStringPredShims, ShimFilePartitionReaderFactory, SparkShimImpl}
import org.apache.commons.io.IOUtils
import org.apache.commons.io.output.{CountingOutputStream, NullOutputStream}
import org.apache.hadoop.conf.Configuration
Expand Down Expand Up @@ -683,12 +683,10 @@ private case class GpuParquetFileFilterHandler(
conf.unset(encryptConf)
}
}
val fileHadoopConf =
ReaderUtils.getHadoopConfForReaderThread(new Path(file.filePath.toString), conf)
val footer: ParquetMetadata = try {
footerReader match {
case ParquetFooterReaderType.NATIVE =>
val serialized = withResource(readAndFilterFooter(file, fileHadoopConf,
val serialized = withResource(readAndFilterFooter(file, conf,
readDataSchema, filePath)) { tableFooter =>
if (tableFooter.getNumColumns <= 0) {
// Special case because java parquet reader does not like having 0 columns.
Expand All @@ -712,7 +710,7 @@ private case class GpuParquetFileFilterHandler(
}
}
case _ =>
readAndSimpleFilterFooter(file, fileHadoopConf, filePath)
readAndSimpleFilterFooter(file, conf, filePath)
}
} catch {
case e if GpuParquetCrypto.isColumnarCryptoException(e) =>
Expand All @@ -739,9 +737,9 @@ private case class GpuParquetFileFilterHandler(
val blocks = if (pushedFilters.isDefined) {
withResource(new NvtxRange("getBlocksWithFilter", NvtxColor.CYAN)) { _ =>
// Use the ParquetFileReader to perform dictionary-level filtering
ParquetInputFormat.setFilterPredicate(fileHadoopConf, pushedFilters.get)
ParquetInputFormat.setFilterPredicate(conf, pushedFilters.get)
//noinspection ScalaDeprecation
withResource(new ParquetFileReader(fileHadoopConf, footer.getFileMetaData, filePath,
withResource(new ParquetFileReader(conf, footer.getFileMetaData, filePath,
footer.getBlocks, Collections.emptyList[ColumnDescriptor])) { parquetReader =>
parquetReader.getRowGroups
}
Expand Down Expand Up @@ -1551,14 +1549,13 @@ trait ParquetPartitionReaderBase extends Logging with ScanWithMetrics
val filePathString: String = filePath.toString
val remoteItems = new ArrayBuffer[CopyRange](blocks.length)
var totalBytesToCopy = 0L
val fileHadoopConf = ReaderUtils.getHadoopConfForReaderThread(filePath, conf)
withResource(new ArrayBuffer[LocalCopy](blocks.length)) { localItems =>
blocks.foreach { block =>
block.getColumns.asScala.foreach { column =>
val columnSize = column.getTotalSize
val outputOffset = totalBytesToCopy + startPos
val channel = FileCache.get.getDataRangeChannel(filePathString,
column.getStartingPos, columnSize, fileHadoopConf)
column.getStartingPos, columnSize, conf)
if (channel.isDefined) {
localItems += LocalCopy(channel.get, columnSize, outputOffset)
} else {
Expand Down Expand Up @@ -1589,14 +1586,13 @@ trait ParquetPartitionReaderBase extends Logging with ScanWithMetrics
return 0L
}

val fileHadoopConf = ReaderUtils.getHadoopConfForReaderThread(filePath, conf)
val coalescedRanges = coalesceReads(remoteCopies)

val totalBytesCopied = PerfIO.readToHostMemory(
fileHadoopConf, out.buffer, filePath.toUri,
conf, out.buffer, filePath.toUri,
coalescedRanges.map(r => IntRangeWithOffset(r.offset, r.length, r.outputOffset))
).getOrElse {
withResource(filePath.getFileSystem(fileHadoopConf).open(filePath)) { in =>
withResource(filePath.getFileSystem(conf).open(filePath)) { in =>
val copyBuffer: Array[Byte] = new Array[Byte](copyBufferSize)
coalescedRanges.foldLeft(0L) { (acc, blockCopy) =>
acc + copyDataRange(blockCopy, in, out, copyBuffer)
Expand All @@ -1608,7 +1604,7 @@ trait ParquetPartitionReaderBase extends Logging with ScanWithMetrics
metrics.getOrElse(GpuMetric.FILECACHE_DATA_RANGE_MISSES, NoopMetric) += 1
metrics.getOrElse(GpuMetric.FILECACHE_DATA_RANGE_MISSES_SIZE, NoopMetric) += range.length
val cacheToken = FileCache.get.startDataRangeCache(
filePathString, range.offset, range.length, fileHadoopConf)
filePathString, range.offset, range.length, conf)
// If we get a filecache token then we can complete the caching by providing the data.
// If we do not get a token then we should not cache this data.
cacheToken.foreach { token =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.apache.spark.sql.rapids.execution

import java.util.concurrent.ThreadPoolExecutor

import org.json4s.JsonAST

import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkMasterRegex, SparkUpgradeException, TaskContext}
Expand Down Expand Up @@ -219,6 +221,13 @@ object TrampolineUtil {
}
}

/**
 * Creates a thread pool by delegating to Spark's
 * `org.apache.spark.util.ThreadUtils.newDaemonCachedThreadPool`, forwarding all three
 * arguments unchanged.
 *
 * NOTE(review): presumably this trampoline exists because `ThreadUtils` is not accessible
 * from outside the `org.apache.spark` package hierarchy - confirm against the Spark sources.
 *
 * @param prefix passed through as the first argument; presumably the thread-name prefix
 *               for the pool's worker threads - verify against `ThreadUtils`
 * @param maxThreadNumber passed through as the second argument; presumably the upper bound
 *                        on the number of threads - verify against `ThreadUtils`
 * @param keepAliveSeconds passed through as the third argument; presumably the idle
 *                         keep-alive time in seconds - verify against `ThreadUtils`
 * @return the `ThreadPoolExecutor` returned by `ThreadUtils`
 */
def newDaemonCachedThreadPool(
    prefix: String,
    maxThreadNumber: Int,
    keepAliveSeconds: Int): ThreadPoolExecutor = {
  import org.apache.spark.util.ThreadUtils
  ThreadUtils.newDaemonCachedThreadPool(prefix, maxThreadNumber, keepAliveSeconds)
}

def postEvent(sc: SparkContext, sparkEvent: SparkListenerEvent): Unit = {
sc.listenerBus.post(sparkEvent)
Expand Down

This file was deleted.

This file was deleted.

0 comments on commit 8403941

Please sign in to comment.