Commit
Implement stream prefetching and double-buffering.
Signed-off-by: Pascal Spörri <[email protected]>
pspoerri committed Aug 24, 2023
1 parent b37beff commit 13b94ef
Showing 4 changed files with 283 additions and 41 deletions.
README.md (2 changes: 0 additions & 2 deletions)
@@ -40,8 +40,6 @@ Changing these values might have an impact on performance.

- `spark.shuffle.s3.bufferSize`: Default size of the buffered output streams (default: `32768`,
  uses `spark.shuffle.file.buffer` as default)
-- `spark.shuffle.s3.bufferInputSize`: Maximum size of buffered input streams (default: `209715200`,
-  uses `spark.network.maxRemoteBlockSizeFetchToMem` as default)
- `spark.shuffle.s3.cachePartitionLengths`: Cache partition lengths in memory (default: `true`)
- `spark.shuffle.s3.cacheChecksums`: Cache checksums in memory (default: `true`)
- `spark.shuffle.s3.cleanup`: Cleanup the shuffle files (default: `true`)
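For reference (not part of the diff): the options above are ordinary Spark configuration keys, so they can be set on the `SparkConf` before the session is created. A minimal sketch using only keys shown above, with illustrative values:

```scala
import org.apache.spark.SparkConf

// Sketch only: values are illustrative, not recommendations.
val conf = new SparkConf()
  .set("spark.shuffle.s3.bufferSize", (32 * 1024).toString) // documented default: 32768
  .set("spark.shuffle.s3.cleanup", "true")                  // documented default: true
```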
src/main/scala/org/apache/spark/storage/PrefetchIterator.scala (36 changes: 36 additions & 0 deletions)
@@ -0,0 +1,36 @@
/**
 * Copyright 2023- IBM Inc. All rights reserved
 * SPDX-License-Identifier: Apache2.0
 */

package org.apache.spark.storage

import scala.collection.AbstractIterator

class PrefetchIterator[A](iter: Iterator[A]) extends AbstractIterator[A] {

  private var value: A = _
  private var valueDefined: Boolean = false

  override def hasNext: Boolean = {
    populateNext()
    valueDefined
  }

  private def populateNext(): Unit = {
    if (valueDefined) {
      return
    }
    if (iter.hasNext) {
      value = iter.next()
      valueDefined = true
    }
  }

  override def next(): A = {
    val result = value
    valueDefined = false
    populateNext()
    result
  }
}
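A short usage sketch (not part of the commit; `PrefetchIteratorExample` is a hypothetical name): `next()` hands out the current element and immediately pulls the following one from the wrapped iterator, so whatever work the source does in its `next()` happens one element ahead of the consumer.

```scala
object PrefetchIteratorExample {
  def main(args: Array[String]): Unit = {
    // Stand-in for an iterator whose next() is expensive, e.g. one that opens a block stream.
    val source = Iterator.tabulate(3) { i =>
      println(s"produced $i")
      i
    }
    val prefetched = new PrefetchIterator(source)
    while (prefetched.hasNext) {
      val v = prefetched.next() // at this point element v + 1 has already been produced
      println(s"consumed $v")
    }
  }
}
```

In the reader further below, the wrapped elements are (blockId, stream) pairs whose stream constructor already starts filling its first buffer, which is where the actual overlap with S3 comes from.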
src/main/scala/org/apache/spark/storage/S3DoubleBufferedStream.scala (228 changes: 228 additions & 0 deletions)
@@ -0,0 +1,228 @@
/**
 * Copyright 2023- IBM Inc. All rights reserved
 * SPDX-License-Identifier: Apache2.0
 */

package org.apache.spark.storage

import org.apache.hadoop.io.ElasticByteBufferPool
import org.apache.spark.SparkException
import org.apache.spark.storage.S3DoubleBufferedStream.{getBuffer, releaseBuffer}

import java.io.{EOFException, InputStream}
import java.nio.ByteBuffer
import scala.concurrent.Future
import scala.util.{Failure, Success, Try}

class S3DoubleBufferedStream(stream: S3ShuffleBlockStream, bufferSize: Int) extends InputStream {
  private var buffers: Array[ByteBuffer] = {
    val array = new Array[ByteBuffer](2)
    array(0) = getBuffer(bufferSize)
    array(1) = getBuffer(bufferSize)
    // Mark buffers as empty
    array.foreach(b => {
      b.clear().limit(0)
    })
    array
  }
  private var prefetching = Array.fill(2)(false)

  var streamClosed = false
  var pos: Long = 0
  val maxBytes: Long = stream.maxBytes

  private var bufIdx: Int = 0
  var dataAvailable: Boolean = false
  var error: Option[Throwable] = None

  doPrefetch(nextBuffer)

  private def currentBuffer: ByteBuffer = synchronized {
    buffers(bufIdx)
  }

  private def nextBuffer: ByteBuffer = synchronized {
    buffers((bufIdx + 1) % buffers.length)
  }

  private def swap() = synchronized {
    bufIdx = (bufIdx + 1) % buffers.length
  }

  private def eof: Boolean = synchronized {
    if (buffers == null) {
      throw new EOFException("Stream already closed")
    }
    pos >= maxBytes
  }

  private def prepareRead(): Unit = synchronized {
    if (!currentBuffer.hasRemaining && dataAvailable) {
      swap()
      dataAvailable = false
      doPrefetch(nextBuffer)
    }
  }

  private def doPrefetch(buffer: ByteBuffer): Unit = {
    if (stream.available() == 0) {
      // no data available
      return
    }
    val fut = Future[Int] {
      buffer.clear()
      var len: Int = 0
      do {
        len = writeTo(buffer, stream)
        if (len < 0) {
          throw new EOFException()
        }
      } while (len == 0)
      buffer.flip()
      len
    }(S3ShuffleReader.asyncExecutionContext)
    fut.onComplete(onCompletePrefetch)(S3ShuffleReader.asyncExecutionContext)
  }

  private def onCompletePrefetch(result: Try[Int]): Unit = synchronized {
    result match {
      case Failure(exception) => error = Some(exception)
      case Success(len) =>
        dataAvailable = true
    }
    notifyAll()
  }

  override def read(): Int = synchronized {
    if (eof) {
      return -1
    }

    while (error.isEmpty) {
      if (buffers == null) {
        throw new EOFException("Stream already closed")
      }
      prepareRead()
      if (currentBuffer.hasRemaining) {
        val l = readFrom(currentBuffer)
        if (l < 0) {
          throw new SparkException("Invalid state in shuffle read.")
        }
        pos += 1
        return l
      }
      try {
        wait()
      }
      catch {
        case _: InterruptedException =>
          Thread.currentThread.interrupt()
      }
    }
    throw error.get
  }

  override def read(b: Array[Byte], off: Int, len: Int): Int = synchronized {
    if (off < 0 || len < 0 || off + len < 0 || off + len > b.length) {
      throw new IndexOutOfBoundsException()
    }
    if (eof) {
      return -1
    }
    while (error.isEmpty) {
      if (buffers == null) {
        throw new EOFException("Stream already closed")
      }
      prepareRead()
      if (currentBuffer.hasRemaining) {
        val l = readFrom(currentBuffer, b, off, len)
        if (l < 0) {
          throw new SparkException("Invalid state in shuffle read(buf).")
        }
        pos += l
        return l
      }
      try {
        wait()
      }
      catch {
        case _: InterruptedException =>
          Thread.currentThread.interrupt()
      }
    }
    throw error.get
  }

  override def available(): Int = synchronized {
    if (buffers == null) {
      throw new EOFException("Stream already closed")
    }
    prepareRead()
    currentBuffer.remaining
  }

  override def skip(n: Long): Long = synchronized {
    if (eof) {
      throw new EOFException("Stream already closed")
    }
    if (n <= 0) {
      return 0
    }
    if (n <= currentBuffer.remaining) {
      val len = skipIn(currentBuffer, n.toInt)
      pos += len
      return len
    }
    val maxSkip = math.min(n, maxBytes - pos)
    val skippedFromBuffer = currentBuffer.remaining
    val skipFromStream = maxSkip - skippedFromBuffer
    currentBuffer.limit(0)
    val skipped = skippedFromBuffer + stream.skip(skipFromStream)
    pos += skipped
    skipped
  }

  override def close(): Unit = synchronized {
    if (buffers == null) {
      return
    }
    buffers.foreach(b => releaseBuffer(b))
    stream.close()
    // Release buffers
    buffers = null
  }

  private def skipIn(buf: ByteBuffer, n: Int): Int = {
    val l = math.min(n, buf.remaining())
    buf.position(buf.position() + l)
    l
  }

  private def readFrom(buf: ByteBuffer, dst: Array[Byte], off: Int, len: Int): Int = {
    val length = math.min(len, buf.remaining())
    System.arraycopy(buf.array(), buf.position() + buf.arrayOffset(), dst, off, length)
    buf.position(buf.position() + length)
    length
  }

  private def readFrom(buf: ByteBuffer): Int = {
    if (!buf.hasRemaining) {
      return -1
    }
    buf.get() & 0xFF
  }

  private def writeTo(buf: ByteBuffer, src: InputStream): Int = {
    val len = src.read(buf.array(), buf.position() + buf.arrayOffset(), buf.remaining())
    buf.position(buf.position() + len)
    len
  }
}

object S3DoubleBufferedStream {
  private lazy val pool = new ElasticByteBufferPool()

  private def getBuffer(size: Int) = pool.getBuffer(false, size)

  private def releaseBuffer(buf: ByteBuffer) = pool.putBuffer(buf)
}
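The per-stream buffers are drawn from Hadoop's `ElasticByteBufferPool`, so closing a stream returns its two buffers for reuse by the next block instead of allocating fresh arrays each time. A minimal sketch of that pooling behaviour (not part of the commit; `BufferPoolExample` is a hypothetical name and assumes `hadoop-common` is on the classpath):

```scala
import java.nio.ByteBuffer
import org.apache.hadoop.io.ElasticByteBufferPool

object BufferPoolExample {
  def main(args: Array[String]): Unit = {
    val pool = new ElasticByteBufferPool()
    val first: ByteBuffer = pool.getBuffer(false, 32 * 1024) // heap buffer of at least 32 KiB
    pool.putBuffer(first)                                    // return it to the pool
    val second: ByteBuffer = pool.getBuffer(false, 32 * 1024)
    // The pool hands back the previously returned buffer instead of allocating a new one.
    println(s"reused the same backing buffer: ${first eq second}")
  }
}
```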
src/main/scala/org/apache/spark/storage/S3ShuffleReader.scala (58 changes: 19 additions & 39 deletions)
@@ -30,12 +30,9 @@ import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter,
import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
import org.apache.spark.util.{CompletionIterator, ThreadUtils}
import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.{InterruptibleIterator, SparkConf, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.{InterruptibleIterator, SparkConf, SparkEnv, TaskContext}

-import java.io.{BufferedInputStream, InputStream}
-import java.util.zip.{CheckedInputStream, Checksum}
-import scala.concurrent.duration.Duration
-import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.{ExecutionContext}

/**
 * This class was adapted from Apache Spark: BlockStoreShuffleReader.
@@ -55,7 +52,6 @@ class S3ShuffleReader[K, C](

  private val dispatcher = S3ShuffleDispatcher.get
  private val dep = handle.dependency
-  private val bufferInputSize = dispatcher.bufferInputSize

  private val fetchContinousBlocksInBatch: Boolean = {
    val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
@@ -77,17 +73,6 @@
    doBatchFetch
  }

-  // Source: Cassandra connector for Apache Spark (https://github.com/datastax/spark-cassandra-connector)
-  // com.datastax.spark.connector.datasource.JoinHelper
-  // License: Apache 2.0
-  // See here for an explanation: http://www.russellspitzer.com/2017/02/27/Concurrency-In-Spark/
-  def slidingPrefetchIterator[T](it: Iterator[Future[T]], batchSize: Int): Iterator[T] = {
-    val (firstElements, lastElement) = it.grouped(batchSize)
-      .sliding(2)
-      .span(_ => it.hasNext)
-    (firstElements.map(_.head) ++ lastElement.flatten).flatten.map(Await.result(_, Duration.Inf))
-  }
-
  override def read(): Iterator[Product2[K, C]] = {
    val serializerInstance = dep.serializer.newInstance()
    val blocks = computeShuffleBlocks(handle.shuffleId,
@@ -97,36 +82,31 @@
      useBlockManager = dispatcher.useBlockManager)

    val wrappedStreams = new S3ShuffleBlockIterator(blocks)
+    val bufferSize = dispatcher.bufferSize.toInt

    // Create a key/value iterator for each stream
-    val recordIterPromise = wrappedStreams.filterNot(_._2.maxBytes == 0).map { case (blockId, wrappedStream) =>
+    val streamIter = wrappedStreams.filterNot(_._2.maxBytes == 0).map { case (blockId, wrappedStream) =>
      readMetrics.incRemoteBytesRead(wrappedStream.maxBytes) // increase byte count.
      readMetrics.incRemoteBlocksFetched(1)
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
-      Future {
-        val bufferSize = scala.math.min(wrappedStream.maxBytes, bufferInputSize).toInt
-        val stream = new BufferedInputStream(wrappedStream, bufferSize)
-
-        // Fill the buffered input stream by reading and then resetting the stream.
-        stream.mark(bufferSize)
-        stream.read()
-        stream.reset()
-
-        val checkedStream = if (dispatcher.checksumEnabled) {
-          new S3ChecksumValidationStream(blockId, stream, dispatcher.checksumAlgorithm)
-        } else {
-          stream
-        }
-
-        serializerInstance
-          .deserializeStream(serializerManager.wrapStream(blockId, checkedStream))
-          .asKeyValueIterator
-      }(S3ShuffleReader.asyncExecutionContext)
+
+      val stream = new S3DoubleBufferedStream(wrappedStream, bufferSize)
+      val checkedStream = if (dispatcher.checksumEnabled) {
+        new S3ChecksumValidationStream(blockId, stream, dispatcher.checksumAlgorithm)
+      } else {
+        stream
+      }
+
+      (blockId, checkedStream)
    }

-    val recordIter = slidingPrefetchIterator(recordIterPromise, dispatcher.prefetchBatchSize).flatten
+    val recordIter = new PrefetchIterator(streamIter).flatMap { case (blockId, stream) =>
+      serializerInstance
+        .deserializeStream(serializerManager.wrapStream(blockId, stream))
+        .asKeyValueIterator
+    }

    // Update the context task metrics for each record read.
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
@@ -204,5 +184,5 @@

object S3ShuffleReader {
  private lazy val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("s3-shuffle-reader-async-thread-pool", S3ShuffleDispatcher.get.prefetchThreadPoolSize)
-  private lazy implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)
+  lazy implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)
}
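Taken together this gives two levels of read-ahead: within a block, `S3DoubleBufferedStream` fills one buffer while the other is drained, and across blocks, `PrefetchIterator` constructs the next block's stream (whose constructor immediately issues its first `doPrefetch`) while the current block is still being deserialized. A simplified sketch of the cross-block level (not part of the commit; `BlockPrefetchSketch` and `FakeBlockStream` are illustrative stand-ins using plain Futures rather than the shuffle classes):

```scala
import org.apache.spark.storage.PrefetchIterator

import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}

object BlockPrefetchSketch {
  implicit val ec: ExecutionContext = ExecutionContext.global

  // Stand-in for S3DoubleBufferedStream: construction starts the first fill in the background.
  final class FakeBlockStream(blockId: Int) {
    private val firstFill: Future[Unit] = Future { Thread.sleep(50) /* simulate S3 latency */ }
    def records(): Iterator[String] = {
      Await.result(firstFill, Duration.Inf) // blocks only if the background fill is not done yet
      Iterator(s"$blockId-a", s"$blockId-b")
    }
  }

  def main(args: Array[String]): Unit = {
    val blockIds = Iterator(0, 1, 2)
    // PrefetchIterator stays one constructed stream ahead of the consumer,
    // so block n + 1 is already fetching while block n is being consumed.
    val records = new PrefetchIterator(blockIds.map(id => new FakeBlockStream(id))).flatMap(_.records())
    records.foreach(println)
  }
}
```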
