Add integration test for async writer
jihoonson committed Dec 10, 2024
1 parent 017fdef commit 24a525f
Showing 4 changed files with 37 additions and 11 deletions.
integration_tests/src/main/python/parquet_write_test.py (21 additions, 0 deletions)
@@ -676,6 +676,27 @@ def test_write_daytime_interval(spark_tmp_path):
         data_path,
         conf=writer_confs)

+
+hold_gpu_configs = [True, False]
+@pytest.mark.parametrize('hold_gpu', hold_gpu_configs, ids=idfn)
+def test_async_writer(spark_tmp_path, hold_gpu):
+    data_path = spark_tmp_path + '/PARQUET_DATA'
+    num_rows = 2048
+    num_cols = 10
+    parquet_gen = [int_gen for _ in range(num_cols)]
+    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gen)]
+    assert_gpu_and_cpu_writes_are_equal_collect(
+        lambda spark, path: gen_df(spark, gen_list, length=num_rows).coalesce(1).write.parquet(path),
+        lambda spark, path: spark.read.parquet(path),
+        data_path,
+        copy_and_update(
+            writer_confs,
+            {"spark.rapids.sql.asyncWrite.queryOutput.enabled": "true",
+             "spark.rapids.sql.batchSizeBytes": 4 * num_cols * 100,  # 100 rows per batch
+             "spark.rapids.sql.queryOutput.holdGpuInTask": hold_gpu}
+        ))
+
+
 @ignore_order
 @pytest.mark.skipif(is_before_spark_320(), reason="is only supported in Spark 320+")
 def test_concurrent_writer(spark_tmp_path):
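
The new test above drives the async query-output write path through three configs. As a rough usage sketch outside the test harness (not part of this commit; the session settings, generated data, and output path below are illustrative assumptions), the same configs could be set on a plain Spark session running with the RAPIDS Accelerator:

import org.apache.spark.sql.SparkSession

// Sketch only: assumes the RAPIDS Accelerator jars are on the classpath and a GPU
// is available. The data and the output path are placeholders, not from the commit.
object AsyncParquetWriteDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("async-parquet-write-demo")
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      .config("spark.rapids.sql.asyncWrite.queryOutput.enabled", "true")
      .config("spark.rapids.sql.queryOutput.holdGpuInTask", "false")
      // 4 bytes per int column * 10 columns * 100 rows, mirroring the test's batch sizing
      .config("spark.rapids.sql.batchSizeBytes", (4 * 10 * 100).toString)
      .getOrCreate()

    import spark.implicits._
    val df = (0 until 2048).map(i => (i, i * 2)).toDF("_c0", "_c1")
    df.coalesce(1).write.mode("overwrite").parquet("/tmp/async_write_demo")
    spark.stop()
  }
}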
@@ -73,13 +73,14 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
     dataSchema: StructType,
     rangeName: String,
     includeRetry: Boolean,
-    holdGpuBetweenBatches: Boolean = false) extends HostBufferConsumer with Logging {
+    holdGpuBetweenBatches: Boolean = false,
+    useAsyncWrite: Boolean = false) extends HostBufferConsumer with Logging {

   protected val tableWriter: TableWriter

   protected val conf: Configuration = context.getConfiguration

-  private val trafficController: Option[TrafficController] = TrafficController.getInstance
+  private val trafficController: TrafficController = TrafficController.getInstance

   private def openOutputStream(): OutputStream = {
     val hadoopPath = new Path(path)
@@ -90,10 +91,12 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
   // This is implemented as a method to make it easier to subclass
   // ColumnarOutputWriter in the tests, and override this behavior.
   protected def getOutputStream: OutputStream = {
-    trafficController.map(controller => {
+    if (useAsyncWrite) {
       logWarning("Async output write enabled")
-      new AsyncOutputStream(() => openOutputStream(), controller)
-    }).getOrElse(openOutputStream())
+      new AsyncOutputStream(() => openOutputStream(), trafficController)
+    } else {
+      openOutputStream()
+    }
   }

   protected val outputStream: OutputStream = getOutputStream
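
This hunk replaces the Option-gated stream selection with an explicit useAsyncWrite flag: when it is set, the file stream returned by openOutputStream() is wrapped in an AsyncOutputStream that is handed the shared TrafficController. The commit does not show AsyncOutputStream itself; as a generic illustration only (not the plugin's implementation, and without throttling or error propagation), an asynchronous output stream of this kind copies each write and hands it to a background thread:

import java.io.OutputStream
import java.util.concurrent.{Executors, TimeUnit}

// Illustration only: writes are copied and executed on a single background thread,
// so the caller returns before the bytes reach the underlying stream.
class BackgroundOutputStream(open: () => OutputStream) extends OutputStream {
  private val executor = Executors.newSingleThreadExecutor()
  private val delegate = open()

  override def write(b: Int): Unit = write(Array(b.toByte), 0, 1)

  override def write(buf: Array[Byte], off: Int, len: Int): Unit = {
    val copy = java.util.Arrays.copyOfRange(buf, off, off + len)
    executor.submit(new Runnable { override def run(): Unit = delegate.write(copy) })
  }

  override def flush(): Unit = {
    executor.submit(new Runnable { override def run(): Unit = delegate.flush() })
  }

  override def close(): Unit = {
    executor.shutdown()                              // stop accepting new writes
    executor.awaitTermination(10, TimeUnit.MINUTES)  // drain pending writes
    delegate.close()
  }
}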
@@ -283,7 +283,7 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging {
       context: TaskAttemptContext): ColumnarOutputWriter = {
     new GpuParquetWriter(path, dataSchema, compressionType, outputTimestampType.toString,
       dateTimeRebaseMode, timestampRebaseMode, context, parquetFieldIdWriteEnabled,
-      holdGpuBetweenBatches)
+      holdGpuBetweenBatches, asyncOutputWriteEnabled)
   }

   override def getFileExtension(context: TaskAttemptContext): String = {
@@ -306,8 +306,10 @@ class GpuParquetWriter(
     timestampRebaseMode: DateTimeRebaseMode,
     context: TaskAttemptContext,
     parquetFieldIdEnabled: Boolean,
-    holdGpuBetweenBatches: Boolean)
-  extends ColumnarOutputWriter(context, dataSchema, "Parquet", true, holdGpuBetweenBatches) {
+    holdGpuBetweenBatches: Boolean,
+    useAsyncWrite: Boolean)
+  extends ColumnarOutputWriter(context, dataSchema, "Parquet", true, holdGpuBetweenBatches,
+    useAsyncWrite) {
   override def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = {
     val cols = GpuColumnVector.extractBases(batch)
     cols.foreach { col =>
@@ -124,14 +124,14 @@ object TrafficController {
    * This is called once per executor.
    */
   def initialize(conf: RapidsConf): Unit = synchronized {
-    if (conf.isAsyncOutputWriteEnabled && instance == null) {
+    if (instance == null) {
       instance = new TrafficController(
         new HostMemoryThrottle(conf.asyncWriteMaxInFlightHostMemoryBytes))
     }
   }

-  def getInstance: Option[TrafficController] = synchronized {
-    Option(instance)
+  def getInstance: TrafficController = synchronized {
+    instance
   }

   def shutdown(): Unit = synchronized {
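
With this change the TrafficController is created on every executor regardless of the async-write setting, and getInstance returns the instance directly instead of an Option; the decision to write asynchronously now lives in the writer's useAsyncWrite flag. The throttle internals are not part of this diff; as a loose, self-contained illustration of what an in-flight host-memory throttle of this general kind does (not the plugin's HostMemoryThrottle), based only on the config name seen above:

// Illustration only: callers block in acquire() while too many buffered host bytes
// are already waiting to be written, and release() wakes them as writes complete.
class SimpleHostMemoryThrottle(maxInFlightBytes: Long) {
  private var inFlight = 0L

  def acquire(bytes: Long): Unit = synchronized {
    // Always admit at least one request so a single oversized buffer cannot deadlock.
    while (inFlight > 0 && inFlight + bytes > maxInFlightBytes) {
      wait()
    }
    inFlight += bytes
  }

  def release(bytes: Long): Unit = synchronized {
    inFlight -= bytes
    notifyAll()
  }
}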
