Skip to content

Commit

Permalink
Improve S3BlockStream API.
Browse files Browse the repository at this point in the history
Signed-off-by: Pascal Spörri <[email protected]>
  • Loading branch information
pspoerri committed Aug 22, 2023
1 parent ed2404e commit b37beff
Showing 1 changed file with 53 additions and 42 deletions.
95 changes: 53 additions & 42 deletions src/main/scala/org/apache/spark/storage/S3ShuffleBlockStream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
import org.apache.spark.shuffle.helper.S3ShuffleDispatcher

import java.io.{IOException, InputStream}
import java.io.{EOFException, IOException, InputStream}

/**
* InputStream that reads data from a shuffleBlock, mapId and exposes an InputStream from startReduceId to endReduceId.
Expand Down Expand Up @@ -41,68 +41,79 @@ class S3ShuffleBlockStream(
private var streamClosed = startPosition == endPosition // automatically mark stream as closed if length is empty.

val maxBytes: Long = endPosition - startPosition
private var numBytes = 0
private var numBytes: Long = 0

private val singleByteBuffer = new Array[Byte](1)

override def close(): Unit = {
override def close(): Unit = synchronized {
if (streamClosed) {
return
}
this.synchronized {
if (dispatcher.supportsUnbuffer) {
stream.unbuffer()
streamClosed = true
} else {
stream.close()
streamClosed = true
}
if (dispatcher.supportsUnbuffer) {
stream.unbuffer()
streamClosed = true
} else {
stream.close()
streamClosed = true
}
super.close()
}

override def read(): Int = {
override def read(): Int = synchronized {
if (streamClosed || numBytes >= maxBytes) {
return -1
}
this.synchronized {
try {
stream.readFully(startPosition + numBytes, singleByteBuffer)
numBytes += 1
if (numBytes >= maxBytes) {
close()
}
return singleByteBuffer(0)
} catch {
case e: IOException =>
logError(f"Encountered an unexpected IOException: ${e.toString}")
close()
return -1
try {
stream.readFully(startPosition + numBytes, singleByteBuffer)
numBytes += 1
if (numBytes >= maxBytes) {
close()
}
return singleByteBuffer(0)
} catch {
case e: IOException =>
logError(f"Encountered an unexpected IOException: ${e.toString}")
close()
return -1
}
}

override def read(b: Array[Byte], off: Int, len: Int): Int = {
override def read(b: Array[Byte], off: Int, len: Int): Int = synchronized {
if (streamClosed || numBytes >= maxBytes) {
return -1
}
this.synchronized {
val maxLength = (maxBytes - numBytes).toInt
assert(maxLength >= 0)
val length = math.min(maxLength, len)
try {
stream.readFully(startPosition + numBytes, b, off, length)
numBytes += length
if (numBytes >= maxBytes) {
close()
}
return length
} catch {
case e: IOException =>
logError(f"Encountered an unexpected IOException: ${e.toString}")
close()
return -1
val maxLength = (maxBytes - numBytes).toInt
assert(maxLength >= 0)
val length = math.min(maxLength, len)
try {
stream.readFully(startPosition + numBytes, b, off, length)
numBytes += length
if (numBytes >= maxBytes) {
close()
}
return length
} catch {
case e: IOException =>
logError(f"Encountered an unexpected IOException: ${e.toString}")
close()
return -1
}
}

override def skip(n: Long): Long = synchronized {
if (streamClosed || numBytes >= maxBytes || n <= 0) {
return 0
}
val toSkip = math.min(maxBytes - numBytes, n)
val skipped = stream.skip(toSkip)
numBytes += skipped
skipped
}

override def available(): Int = synchronized {
if (streamClosed) {
return 0
}
return (maxBytes - numBytes).toInt
}
}

0 comments on commit b37beff

Please sign in to comment.