-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement stream prefetching and double-buffering.
Signed-off-by: Pascal Spörri <[email protected]>
- Loading branch information
Showing
4 changed files
with
283 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
src/main/scala/org/apache/spark/storage/PrefetchIterator.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
/** | ||
* Copyright 2023- IBM Inc. All rights reserved | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.apache.spark.storage | ||
|
||
import scala.collection.AbstractIterator | ||
|
||
/**
 * Iterator adapter that keeps one element buffered ahead of the consumer.
 *
 * After each element is handed out, the next one is eagerly pulled from the
 * underlying iterator, so the cost of producing it overlaps with the caller's
 * processing of the previous element.
 *
 * Fix over the original: `next()` now populates the buffer before reading it
 * and throws `NoSuchElementException` when the iterator is exhausted. The
 * previous version returned the uninitialized default value when `next()` was
 * called without a preceding `hasNext`, or after exhaustion, violating the
 * `Iterator` contract.
 */
class PrefetchIterator[A](iter: Iterator[A]) extends AbstractIterator[A] {

  // Buffered element; only meaningful while `valueDefined` is true.
  private var value: A = _
  private var valueDefined: Boolean = false

  override def hasNext: Boolean = {
    populateNext()
    valueDefined
  }

  // Pull the next element from the source iterator unless one is already buffered.
  private def populateNext(): Unit = {
    if (valueDefined) {
      return
    }
    if (iter.hasNext) {
      value = iter.next()
      valueDefined = true
    }
  }

  override def next(): A = {
    // Ensure the buffer is filled even if the caller skipped hasNext.
    populateNext()
    if (!valueDefined) {
      throw new NoSuchElementException("next on empty iterator")
    }
    val result = value
    valueDefined = false
    // Prefetch the following element so it is ready before the next call.
    populateNext()
    result
  }
}
228 changes: 228 additions & 0 deletions
228
src/main/scala/org/apache/spark/storage/S3DoubleBufferedStream.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
/** | ||
* Copyright 2023- IBM Inc. All rights reserved | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.apache.spark.storage | ||
|
||
import org.apache.hadoop.io.ElasticByteBufferPool | ||
import org.apache.spark.SparkException | ||
import org.apache.spark.storage.S3DoubleBufferedStream.{getBuffer, releaseBuffer} | ||
|
||
import java.io.{EOFException, InputStream} | ||
import java.nio.ByteBuffer | ||
import scala.concurrent.Future | ||
import scala.util.{Failure, Success, Try} | ||
|
||
/**
 * Double-buffered InputStream over an [[S3ShuffleBlockStream]]: while the
 * consumer drains one buffer, the other is filled asynchronously on
 * `S3ShuffleReader.asyncExecutionContext`, overlapping I/O with computation.
 *
 * Concurrency: consumer-facing methods and the prefetch completion callback
 * synchronize on `this`; readers block in `wait()` until `notifyAll()` signals
 * that a prefetch completed (successfully or with an error).
 */
class S3DoubleBufferedStream(stream: S3ShuffleBlockStream, bufferSize: Int) extends InputStream {
  // Two pooled buffers: index `bufIdx` is the consumer side, the other is the
  // prefetch target. Set to null by close() — all methods treat null as "closed".
  private var buffers: Array[ByteBuffer] = {
    val array = new Array[ByteBuffer](2)
    array(0) = getBuffer(bufferSize)
    array(1) = getBuffer(bufferSize)
    // Mark buffers as empty
    array.foreach(b => {
      b.clear().limit(0)
    })
    array
  }
  // NOTE(review): `prefetching` is never read or written in this diff — dead state?
  private var prefetching = Array.fill(2)(false)

  // NOTE(review): `streamClosed` is also never used in the visible code.
  var streamClosed = false
  // Logical position in the block; compared against maxBytes to detect EOF.
  var pos: Long = 0
  val maxBytes: Long = stream.maxBytes

  private var bufIdx: Int = 0
  // Set by onCompletePrefetch when the standby buffer holds fresh data.
  var dataAvailable: Boolean = false
  // First prefetch failure, rethrown to the reader from read().
  var error: Option[Throwable] = None

  // Kick off the initial prefetch into the standby buffer.
  doPrefetch(nextBuffer)

  // Buffer currently being drained by the consumer.
  private def currentBuffer: ByteBuffer = synchronized {
    buffers(bufIdx)
  }

  // Standby buffer that prefetch fills.
  private def nextBuffer: ByteBuffer = synchronized {
    buffers((bufIdx + 1) % buffers.length)
  }

  // Exchange the roles of the two buffers.
  private def swap() = synchronized {
    bufIdx = (bufIdx + 1) % buffers.length
  }

  // True once `pos` reached the block size; throws if the stream was closed.
  private def eof: Boolean = synchronized {
    if (buffers == null) {
      throw new EOFException("Stream already closed")
    }
    pos >= maxBytes
  }

  // If the consumer buffer is drained and prefetched data is ready, swap the
  // buffers and start refilling the (now empty) standby buffer.
  private def prepareRead(): Unit = synchronized {
    if (!currentBuffer.hasRemaining && dataAvailable) {
      swap()
      dataAvailable = false
      doPrefetch(nextBuffer)
    }
  }

  // Fill `buffer` asynchronously from the underlying stream.
  // NOTE(review): if available() == 0 here but the stream is not yet at EOF,
  // no prefetch is scheduled and no notify will arrive — confirm readers
  // cannot end up waiting forever in that state.
  private def doPrefetch(buffer: ByteBuffer): Unit = {
    if (stream.available() == 0) {
      // no data available
      return
    }
    val fut = Future[Int] {
      buffer.clear()
      var len: Int = 0
      do {
        // Retry zero-length reads until at least one byte arrives or EOF.
        len = writeTo(buffer, stream)
        if (len < 0) {
          throw new EOFException()
        }
      } while (len == 0)
      // Make the written bytes readable by the consumer.
      buffer.flip()
      len
    }(S3ShuffleReader.asyncExecutionContext)
    fut.onComplete(onCompletePrefetch)(S3ShuffleReader.asyncExecutionContext)
  }

  // Runs on the async executor: publish the prefetch outcome and wake readers.
  private def onCompletePrefetch(result: Try[Int]): Unit = synchronized {
    result match {
      case Failure(exception) => error = Some(exception)
      case Success(len) =>
        dataAvailable = true
    }
    notifyAll()
  }

  /** Read a single byte (0-255), or -1 at EOF. Blocks until data or error. */
  override def read(): Int = synchronized {
    if (eof) {
      return -1
    }

    while (error.isEmpty) {
      if (buffers == null) {
        throw new EOFException("Stream already closed")
      }
      prepareRead()
      if (currentBuffer.hasRemaining) {
        val l = readFrom(currentBuffer)
        if (l < 0) {
          throw new SparkException("Invalid state in shuffle read.")
        }
        // One byte consumed.
        pos += 1
        return l
      }
      try {
        // Wait for onCompletePrefetch's notifyAll.
        wait()
      }
      catch {
        // NOTE(review): restoring the interrupt flag and looping back into
        // wait() makes the next wait() throw immediately — potential tight
        // spin while interrupted; confirm intended handling.
        case _: InterruptedException =>
          Thread.currentThread.interrupt()
      }
    }
    // A prefetch failed; surface its exception to the caller.
    throw error.get
  }

  /** Bulk read into b[off, off+len); returns bytes read or -1 at EOF. */
  override def read(b: Array[Byte], off: Int, len: Int): Int = synchronized {
    // `off + len < 0` guards against Int overflow of off + len.
    if (off < 0 || len < 0 || off + len < 0 || off + len > b.length) {
      throw new IndexOutOfBoundsException()
    }
    if (eof) {
      return -1
    }
    while (error.isEmpty) {
      if (buffers == null) {
        throw new EOFException("Stream already closed")
      }
      prepareRead()
      if (currentBuffer.hasRemaining) {
        // May return fewer than `len` bytes: only what the buffer holds.
        val l = readFrom(currentBuffer, b, off, len)
        if (l < 0) {
          throw new SparkException("Invalid state in shuffle read(buf).")
        }
        pos += l
        return l
      }
      try {
        wait()
      }
      catch {
        case _: InterruptedException =>
          Thread.currentThread.interrupt()
      }
    }
    throw error.get
  }

  /** Bytes readable without blocking: the remainder of the consumer buffer. */
  override def available(): Int = synchronized {
    if (buffers == null) {
      throw new EOFException("Stream already closed")
    }
    prepareRead()
    currentBuffer.remaining
  }

  /**
   * Skip up to n bytes: first from the consumer buffer, then (if needed)
   * directly on the underlying stream.
   * NOTE(review): skipping on the underlying stream does not invalidate an
   * in-flight/completed prefetch into the standby buffer — verify the standby
   * buffer cannot later serve bytes that precede the skipped region.
   */
  override def skip(n: Long): Long = synchronized {
    if (eof) {
      throw new EOFException("Stream already closed")
    }
    if (n <= 0) {
      return 0
    }
    // Fast path: the request fits inside the current buffer.
    if (n <= currentBuffer.remaining) {
      val len = skipIn(currentBuffer, n.toInt)
      pos += len
      return len
    }
    // Clamp to the bytes actually left in the block.
    val maxSkip = math.min(n, maxBytes - pos)
    val skippedFromBuffer = currentBuffer.remaining
    val skipFromStream = maxSkip - skippedFromBuffer
    // Drop everything still in the consumer buffer.
    currentBuffer.limit(0)
    val skipped = skippedFromBuffer + stream.skip(skipFromStream)
    pos += skipped
    skipped
  }

  /** Idempotent close: return buffers to the pool and close the source stream. */
  override def close(): Unit = synchronized {
    if (buffers == null) {
      return
    }
    buffers.foreach(b => releaseBuffer(b))
    stream.close()
    // Release buffers
    buffers = null
  }

  // Advance the buffer position by at most n; returns bytes skipped.
  private def skipIn(buf: ByteBuffer, n: Int): Int = {
    val l = math.min(n, buf.remaining())
    buf.position(buf.position() + l)
    l
  }

  // Copy up to len bytes from the (array-backed) buffer into dst.
  private def readFrom(buf: ByteBuffer, dst: Array[Byte], off: Int, len: Int): Int = {
    val length = math.min(len, buf.remaining())
    System.arraycopy(buf.array(), buf.position() + buf.arrayOffset(), dst, off, length)
    buf.position(buf.position() + length)
    length
  }

  // Read one unsigned byte from the buffer, or -1 if it is empty.
  private def readFrom(buf: ByteBuffer): Int = {
    if (!buf.hasRemaining) {
      return -1
    }
    buf.get() & 0xFF
  }

  // Fill the buffer from src; returns src.read's result.
  // NOTE(review): when src.read returns -1 (EOF), position is advanced by -1
  // before the caller throws — harmless only because the buffer is then
  // discarded; consider checking len before moving the position.
  private def writeTo(buf: ByteBuffer, src: InputStream): Int = {
    val len = src.read(buf.array(), buf.position() + buf.arrayOffset(), buf.remaining())
    buf.position(buf.position() + len)
    len
  }
}
|
||
/** Companion: shared byte-buffer pooling used by all stream instances. */
object S3DoubleBufferedStream {
  // Pool is created on first use; buffers are heap-backed (direct = false),
  // which is required because the stream reads via ByteBuffer.array().
  private lazy val pool = new ElasticByteBufferPool()

  /** Borrow a heap buffer of at least `length` bytes from the shared pool. */
  private def getBuffer(length: Int) = pool.getBuffer(false, length)

  /** Hand a buffer back to the shared pool for reuse. */
  private def releaseBuffer(buffer: ByteBuffer) = pool.putBuffer(buffer)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters