From b37beff059c37a1288208a62a0c52e05e88058cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20Spo=CC=88rri?= Date: Tue, 22 Aug 2023 10:17:47 +0200 Subject: [PATCH] Improve S3BlockStream API. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Pascal SpoĢˆrri --- .../spark/storage/S3ShuffleBlockStream.scala | 95 +++++++++++-------- 1 file changed, 53 insertions(+), 42 deletions(-) diff --git a/src/main/scala/org/apache/spark/storage/S3ShuffleBlockStream.scala b/src/main/scala/org/apache/spark/storage/S3ShuffleBlockStream.scala index ccc6c37..30dd3e4 100644 --- a/src/main/scala/org/apache/spark/storage/S3ShuffleBlockStream.scala +++ b/src/main/scala/org/apache/spark/storage/S3ShuffleBlockStream.scala @@ -9,7 +9,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.shuffle.helper.S3ShuffleDispatcher -import java.io.{IOException, InputStream} +import java.io.{EOFException, IOException, InputStream} /** * InputStream that reads data from a shuffleBlock, mapId and exposes an InputStream from startReduceId to endReduceId. @@ -41,68 +41,79 @@ class S3ShuffleBlockStream( private var streamClosed = startPosition == endPosition // automatically mark stream as closed if length is empty. val maxBytes: Long = endPosition - startPosition - private var numBytes = 0 + private var numBytes: Long = 0 private val singleByteBuffer = new Array[Byte](1) - override def close(): Unit = { + override def close(): Unit = synchronized { if (streamClosed) { return } - this.synchronized { - if (dispatcher.supportsUnbuffer) { - stream.unbuffer() - streamClosed = true - } else { - stream.close() - streamClosed = true - } + if (dispatcher.supportsUnbuffer) { + stream.unbuffer() + streamClosed = true + } else { + stream.close() + streamClosed = true } super.close() } - override def read(): Int = { + override def read(): Int = synchronized { if (streamClosed || numBytes >= maxBytes) { return -1 } - this.synchronized { - try { - stream.readFully(startPosition + numBytes, singleByteBuffer) - numBytes += 1 - if (numBytes >= maxBytes) { - close() - } - return singleByteBuffer(0) - } catch { - case e: IOException => - logError(f"Encountered an unexpected IOException: ${e.toString}") - close() - return -1 + try { + stream.readFully(startPosition + numBytes, singleByteBuffer) + numBytes += 1 + if (numBytes >= maxBytes) { + close() } + return singleByteBuffer(0) + } catch { + case e: IOException => + logError(f"Encountered an unexpected IOException: ${e.toString}") + close() + return -1 } } - override def read(b: Array[Byte], off: Int, len: Int): Int = { + override def read(b: Array[Byte], off: Int, len: Int): Int = synchronized { if (streamClosed || numBytes >= maxBytes) { return -1 } - this.synchronized { - val maxLength = (maxBytes - numBytes).toInt - assert(maxLength >= 0) - val length = math.min(maxLength, len) - try { - stream.readFully(startPosition + numBytes, b, off, length) - numBytes += length - if (numBytes >= maxBytes) { - close() - } - return length - } catch { - case e: IOException => - logError(f"Encountered an unexpected IOException: ${e.toString}") - close() - return -1 + val maxLength = (maxBytes - numBytes).toInt + assert(maxLength >= 0) + val length = math.min(maxLength, len) + try { + stream.readFully(startPosition + numBytes, b, off, length) + numBytes += length + if (numBytes >= maxBytes) { + close() } + return length + } catch { + case e: IOException => + logError(f"Encountered an unexpected IOException: ${e.toString}") + close() + return -1 + } + } + + override def skip(n: Long): Long = synchronized { + if (streamClosed || numBytes >= maxBytes || n <= 0) { + return 0 + } + val toSkip = math.min(maxBytes - numBytes, n) + val skipped = stream.skip(toSkip) + numBytes += skipped + skipped + } + + override def available(): Int = synchronized { + if (streamClosed) { + return 0 } + return (maxBytes - numBytes).toInt } }