From 13b94ef2db379c7832e97c1bb2b86bd7246271dd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pascal=20Spo=CC=88rri?=
Date: Wed, 23 Aug 2023 10:11:32 +0200
Subject: [PATCH] Implement stream prefetching and double-buffering.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Pascal Spörri
---
 README.md                                     |   2 -
 .../spark/storage/PrefetchIterator.scala      |  36 +++
 .../storage/S3DoubleBufferedStream.scala      | 228 ++++++++++++++++++
 .../spark/storage/S3ShuffleReader.scala       |  58 ++---
 4 files changed, 283 insertions(+), 41 deletions(-)
 create mode 100644 src/main/scala/org/apache/spark/storage/PrefetchIterator.scala
 create mode 100644 src/main/scala/org/apache/spark/storage/S3DoubleBufferedStream.scala

diff --git a/README.md b/README.md
index d759376..b917941 100644
--- a/README.md
+++ b/README.md
@@ -40,8 +40,6 @@ Changing these values might have an impact on performance.
 - `spark.shuffle.s3.bufferSize`: Default size of the buffered output streams (default: `32768`,
   uses `spark.shuffle.file.buffer` as default)
-- `spark.shuffle.s3.bufferInputSize`: Maximum size of buffered input streams (default: `209715200`,
-  uses `spark.network.maxRemoteBlockSizeFetchToMem` as default)
 - `spark.shuffle.s3.cachePartitionLengths`: Cache partition lengths in memory (default: `true`)
 - `spark.shuffle.s3.cacheChecksums`: Cache checksums in memory (default: `true`)
 - `spark.shuffle.s3.cleanup`: Cleanup the shuffle files (default: `true`)
diff --git a/src/main/scala/org/apache/spark/storage/PrefetchIterator.scala b/src/main/scala/org/apache/spark/storage/PrefetchIterator.scala
new file mode 100644
index 0000000..72d006f
--- /dev/null
+++ b/src/main/scala/org/apache/spark/storage/PrefetchIterator.scala
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2023- IBM Inc. All rights reserved
+ * SPDX-License-Identifier: Apache2.0
+ */
+
+package org.apache.spark.storage
+
+import scala.collection.AbstractIterator
+
+class PrefetchIterator[A](iter: Iterator[A]) extends AbstractIterator[A] {
+
+  private var value: A = _
+  private var valueDefined: Boolean = false
+
+  override def hasNext: Boolean = {
+    populateNext()
+    valueDefined
+  }
+
+  private def populateNext(): Unit = {
+    if (valueDefined) {
+      return
+    }
+    if (iter.hasNext) {
+      value = iter.next()
+      valueDefined = true
+    }
+  }
+
+  override def next(): A = {
+    val result = value
+    valueDefined = false
+    populateNext()
+    result
+  }
+}
diff --git a/src/main/scala/org/apache/spark/storage/S3DoubleBufferedStream.scala b/src/main/scala/org/apache/spark/storage/S3DoubleBufferedStream.scala
new file mode 100644
index 0000000..06fd8f9
--- /dev/null
+++ b/src/main/scala/org/apache/spark/storage/S3DoubleBufferedStream.scala
@@ -0,0 +1,228 @@
+/**
+ * Copyright 2023- IBM Inc. All rights reserved
+ * SPDX-License-Identifier: Apache2.0
+ */
+
+package org.apache.spark.storage
+
+import org.apache.hadoop.io.ElasticByteBufferPool
+import org.apache.spark.SparkException
+import org.apache.spark.storage.S3DoubleBufferedStream.{getBuffer, releaseBuffer}
+
+import java.io.{EOFException, InputStream}
+import java.nio.ByteBuffer
+import scala.concurrent.Future
+import scala.util.{Failure, Success, Try}
+
+class S3DoubleBufferedStream(stream: S3ShuffleBlockStream, bufferSize: Int) extends InputStream {
+  private var buffers: Array[ByteBuffer] = {
+    val array = new Array[ByteBuffer](2)
+    array(0) = getBuffer(bufferSize)
+    array(1) = getBuffer(bufferSize)
+    // Mark buffers as empty
+    array.foreach(b => {
+      b.clear().limit(0)
+    })
+    array
+  }
+  private var prefetching = Array.fill(2)(false)
+
+  var streamClosed = false
+  var pos: Long = 0
+  val maxBytes: Long = stream.maxBytes
+
+  private var bufIdx: Int = 0
+  var dataAvailable: Boolean = false
+  var error: Option[Throwable] = None
+
+  doPrefetch(nextBuffer)
+
+  private def currentBuffer: ByteBuffer = synchronized {
+    buffers(bufIdx)
+  }
+
+  private def nextBuffer: ByteBuffer = synchronized {
+    buffers((bufIdx + 1) % buffers.length)
+  }
+
+  private def swap() = synchronized {
+    bufIdx = (bufIdx + 1) % buffers.length
+  }
+
+  private def eof: Boolean = synchronized {
+    if (buffers == null) {
+      throw new EOFException("Stream already closed")
+    }
+    pos >= maxBytes
+  }
+
+  private def prepareRead(): Unit = synchronized {
+    if (!currentBuffer.hasRemaining && dataAvailable) {
+      swap()
+      dataAvailable = false
+      doPrefetch(nextBuffer)
+    }
+  }
+
+  private def doPrefetch(buffer: ByteBuffer): Unit = {
+    if (stream.available() == 0) {
+      // no data available
+      return
+    }
+    val fut = Future[Int] {
+      buffer.clear()
+      var len: Int = 0
+      do {
+        len = writeTo(buffer, stream)
+        if (len < 0) {
+          throw new EOFException()
+        }
+      } while (len == 0)
+      buffer.flip()
+      len
+    }(S3ShuffleReader.asyncExecutionContext)
+    fut.onComplete(onCompletePrefetch)(S3ShuffleReader.asyncExecutionContext)
+  }
+
+  private def onCompletePrefetch(result: Try[Int]): Unit = synchronized {
+    result match {
+      case Failure(exception) => error = Some(exception)
+      case Success(len) =>
+        dataAvailable = true
+    }
+    notifyAll()
+  }
+
+  override def read(): Int = synchronized {
+    if (eof) {
+      return -1
+    }
+
+    while (error.isEmpty) {
+      if (buffers == null) {
+        throw new EOFException("Stream already closed")
+      }
+      prepareRead()
+      if (currentBuffer.hasRemaining) {
+        val l = readFrom(currentBuffer)
+        if (l < 0) {
+          throw new SparkException("Invalid state in shuffle read.")
+        }
+        pos += 1
+        return l
+      }
+      try {
+        wait()
+      }
+      catch {
+        case _: InterruptedException =>
+          Thread.currentThread.interrupt()
+      }
+    }
+    throw error.get
+  }
+
+  override def read(b: Array[Byte], off: Int, len: Int): Int = synchronized {
+    if (off < 0 || len < 0 || off + len < 0 || off + len > b.length) {
+      throw new IndexOutOfBoundsException()
+    }
+    if (eof) {
+      return -1
+    }
+    while (error.isEmpty) {
+      if (buffers == null) {
+        throw new EOFException("Stream already closed")
+      }
+      prepareRead()
+      if (currentBuffer.hasRemaining) {
+        val l = readFrom(currentBuffer, b, off, len)
+        if (l < 0) {
+          throw new SparkException("Invalid state in shuffle read(buf).")
+        }
+        pos += l
+        return l
+      }
+      try {
+        wait()
+      }
+      catch {
+        case _: InterruptedException =>
+          Thread.currentThread.interrupt()
+      }
+    }
+    throw error.get
+  }
+
+  override def available(): Int = synchronized {
+    if (buffers == null) {
+      throw new EOFException("Stream already closed")
+    }
+    prepareRead()
+    currentBuffer.remaining
+  }
+
+  override def skip(n: Long): Long = synchronized {
+    if (eof) {
+      throw new EOFException("Stream already closed")
+    }
+    if (n <= 0) {
+      return 0
+    }
+    if (n <= currentBuffer.remaining) {
+      val len = skipIn(currentBuffer, n.toInt)
+      pos += len
+      return len
+    }
+    val maxSkip = math.min(n, maxBytes - pos)
+    val skippedFromBuffer = currentBuffer.remaining
+    val skipFromStream = maxSkip - skippedFromBuffer
+    currentBuffer.limit(0)
+    val skipped = skippedFromBuffer + stream.skip(skipFromStream)
+    pos += skipped
+    skipped
+  }
+
+  override def close(): Unit = synchronized {
+    if (buffers == null) {
+      return
+    }
+    buffers.foreach(b => releaseBuffer(b))
+    stream.close()
+    // Release buffers
+    buffers = null
+  }
+
+  private def skipIn(buf: ByteBuffer, n: Int): Int = {
+    val l = math.min(n, buf.remaining())
+    buf.position(buf.position() + l)
+    l
+  }
+
+  private def readFrom(buf: ByteBuffer, dst: Array[Byte], off: Int, len: Int): Int = {
+    val length = math.min(len, buf.remaining())
+    System.arraycopy(buf.array(), buf.position() + buf.arrayOffset(), dst, off, length)
+    buf.position(buf.position() + length)
+    length
+  }
+
+  private def readFrom(buf: ByteBuffer): Int = {
+    if (!buf.hasRemaining) {
+      return -1
+    }
+    buf.get() & 0xFF
+  }
+
+  private def writeTo(buf: ByteBuffer, src: InputStream): Int = {
+    val len = src.read(buf.array(), buf.position() + buf.arrayOffset(), buf.remaining())
+    if (len > 0) buf.position(buf.position() + len) // read() returns -1 at EOF; do not move the position in that case
+    len
+  }
+}
+
+object S3DoubleBufferedStream {
+  private lazy val pool = new ElasticByteBufferPool()
+
+  private def getBuffer(size: Int) = pool.getBuffer(false, size)
+
+  private def releaseBuffer(buf: ByteBuffer) = pool.putBuffer(buf)
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala b/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala
index 684d6c8..a34be6c 100644
--- a/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala
+++ b/src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala
@@ -30,12 +30,9 @@ import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter,
 import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
 import org.apache.spark.util.{CompletionIterator, ThreadUtils}
 import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.{InterruptibleIterator, SparkConf, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.{InterruptibleIterator, SparkConf, SparkEnv, TaskContext}
-import java.io.{BufferedInputStream, InputStream}
-import java.util.zip.{CheckedInputStream, Checksum}
-import scala.concurrent.duration.Duration
-import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.{ExecutionContext}
 
 /**
  * This class was adapted from Apache Spark: BlockStoreShuffleReader.
@@ -55,7 +52,6 @@ class S3ShuffleReader[K, C](
   private val dispatcher = S3ShuffleDispatcher.get
   private val dep = handle.dependency
-  private val bufferInputSize = dispatcher.bufferInputSize
 
   private val fetchContinousBlocksInBatch: Boolean = {
     val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
@@ -77,17 +73,6 @@ class S3ShuffleReader[K, C](
     doBatchFetch
   }
 
-  // Source: Cassandra connector for Apache Spark (https://github.com/datastax/spark-cassandra-connector)
-  // com.datastax.spark.connector.datasource.JoinHelper
-  // License: Apache 2.0
-  // See here for an explanation: http://www.russellspitzer.com/2017/02/27/Concurrency-In-Spark/
-  def slidingPrefetchIterator[T](it: Iterator[Future[T]], batchSize: Int): Iterator[T] = {
-    val (firstElements, lastElement) = it.grouped(batchSize)
-                                         .sliding(2)
-                                         .span(_ => it.hasNext)
-    (firstElements.map(_.head) ++ lastElement.flatten).flatten.map(Await.result(_, Duration.Inf))
-  }
-
   override def read(): Iterator[Product2[K, C]] = {
     val serializerInstance = dep.serializer.newInstance()
     val blocks = computeShuffleBlocks(handle.shuffleId,
@@ -97,36 +82,31 @@ class S3ShuffleReader[K, C](
                                       useBlockManager = dispatcher.useBlockManager)
 
     val wrappedStreams = new S3ShuffleBlockIterator(blocks)
+    val bufferSize = dispatcher.bufferSize.toInt
 
     // Create a key/value iterator for each stream
-    val recordIterPromise = wrappedStreams.filterNot(_._2.maxBytes == 0).map { case (blockId, wrappedStream) =>
+    val streamIter = wrappedStreams.filterNot(_._2.maxBytes == 0).map { case (blockId, wrappedStream) =>
       readMetrics.incRemoteBytesRead(wrappedStream.maxBytes) // increase byte count.
       readMetrics.incRemoteBlocksFetched(1)
       // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
-      Future {
-        val bufferSize = scala.math.min(wrappedStream.maxBytes, bufferInputSize).toInt
-        val stream = new BufferedInputStream(wrappedStream, bufferSize)
-
-        // Fill the buffered input stream by reading and then resetting the stream.
-        stream.mark(bufferSize)
-        stream.read()
-        stream.reset()
-
-        val checkedStream = if (dispatcher.checksumEnabled) {
-          new S3ChecksumValidationStream(blockId, stream, dispatcher.checksumAlgorithm)
-        } else {
-          stream
-        }
-
-        serializerInstance
-          .deserializeStream(serializerManager.wrapStream(blockId, checkedStream))
-          .asKeyValueIterator
-      }(S3ShuffleReader.asyncExecutionContext)
+
+      val stream = new S3DoubleBufferedStream(wrappedStream, bufferSize)
+      val checkedStream = if (dispatcher.checksumEnabled) {
+        new S3ChecksumValidationStream(blockId, stream, dispatcher.checksumAlgorithm)
+      } else {
+        stream
+      }
+
+      (blockId, checkedStream)
     }
 
-    val recordIter = slidingPrefetchIterator(recordIterPromise, dispatcher.prefetchBatchSize).flatten
+    val recordIter = new PrefetchIterator(streamIter).flatMap { case (blockId, stream) =>
+      serializerInstance
+        .deserializeStream(serializerManager.wrapStream(blockId, stream))
+        .asKeyValueIterator
+    }
 
     // Update the context task metrics for each record read.
     val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
@@ -204,5 +184,5 @@ class S3ShuffleReader[K, C](
 object S3ShuffleReader {
   private lazy val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("s3-shuffle-reader-async-thread-pool",
                                                                            S3ShuffleDispatcher.get.prefetchThreadPoolSize)
-  private lazy implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)
+  lazy implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)
 }
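
Note (not part of the patch): a minimal sketch of how the one-element look-ahead in PrefetchIterator drives prefetching. Consuming element k via next() immediately pulls element k+1 from the wrapped iterator, so when S3ShuffleReader maps each block to an S3DoubleBufferedStream, the next block's first buffer starts filling while the current block is still being deserialized. Only the PrefetchIterator class added above is assumed; the openBlock helper and block names are illustrative.

import org.apache.spark.storage.PrefetchIterator

object PrefetchIteratorExample {
  def main(args: Array[String]): Unit = {
    // Stand-in for "open a shuffle block stream and kick off its first buffer fill".
    def openBlock(i: Int): String = {
      println(s"opening block $i")
      s"block-$i"
    }

    // Iterator.map is lazy: openBlock only runs when an element is pulled.
    val blocks = new PrefetchIterator(Iterator(1, 2, 3).map(openBlock))
    while (blocks.hasNext) {
      val b = blocks.next() // pulling block k also opens block k+1
      println(s"processing $b")
    }
    // Prints "opening block 2" before "processing block-1": the next block is
    // always opened one step ahead of the consumer.
  }
}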