Skip to content

Commit

Permalink
no longer rely on setSpilling
Browse files Browse the repository at this point in the history
Signed-off-by: Zach Puller <[email protected]>
  • Loading branch information
zpuller committed Dec 19, 2024
1 parent 47e55ad commit d568860
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import java.nio.file.StandardOpenOption
import java.util
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable

Expand Down Expand Up @@ -178,44 +177,14 @@ trait SpillableHandle extends StoreHandle {
*/
def spill(): Long

private val spilling = new AtomicBoolean(false)

/**
* Method used to atomically check and set the spilling state, so that anyone who wants to
* actually perform a spill can ensure they are the only one spilling, without having to block
* on the actual spill operation (IO). Only someone who has set spilling to true to perform their
* spill may set it back to false when they are done. (Visible for tests)
*
* This is a separate check from spillable, which actually checks the state of the buffer handle
*
* @param s whether the caller is trying to spill or not (ie finished)
* @return whether the caller is allowed to spill (or true if s is false)
*/
def setSpilling(s: Boolean): Boolean = {
if (!s) {
if (!spilling.getAndSet(false)) {
throw new IllegalStateException("tried to setSpilling to false while not spilling!")
}
true
} else {
!spilling.getAndSet(true)
}
}

/**
* Method used to determine whether a handle tracks an object that could be spilled
* @note At the level of `SpillableHandle`, the only requirement of spillability
* is that the size of the handle is > 0. `approxSizeInBytes` is known at
* construction, and is immutable.
* @return true if currently spillable, false otherwise
*/
private[spill] def spillable: Boolean = {
if (approxSizeInBytes > 0) {
!spilling.get()
} else {
false
}
}
private[spill] def spillable: Boolean = approxSizeInBytes > 0
}

/**
Expand Down Expand Up @@ -349,34 +318,37 @@ class SpillableHostBufferHandle private (
materialized
}

private var toSpill: HostMemoryBuffer = _
private var toSpill: Option[HostMemoryBuffer] = None
override def releaseHostResource(): Unit = {
super.releaseHostResource()
synchronized {
if (toSpill != null) {
toSpill.close()
toSpill = null
}
toSpill.foreach(_.close())
toSpill = None
}
}

override def spill(): Long = {
if (!spillable || !setSpilling(true)) {
if (!spillable) {
0L
} else {
synchronized {
if (disk.isEmpty && host.isDefined) {
toSpill = host.get
toSpill.incRefCount()
val thisThreadSpills = synchronized {
if (disk.isEmpty && host.isDefined && toSpill.isEmpty) {
toSpill = host
toSpill.get.incRefCount()
true
} else {
false
}
}
val spilled = if (toSpill != null) {
withResource(toSpill) { _ =>
val spilled = if (thisThreadSpills) {
val buf = toSpill.get
withResource(buf) { _ =>
withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder =>
val outputChannel = diskHandleBuilder.getChannel
// the spill IO is non-blocking as it won't impact dev or host directly
// instead we "atomically" swap the buffers below once they are ready
GpuTaskMetrics.get.spillToDiskTime {
val iter = new HostByteBufferIterator(host.get)
val iter = new HostByteBufferIterator(buf)
iter.foreach { bb =>
try {
while (bb.hasRemaining) {
Expand All @@ -398,8 +370,6 @@ class SpillableHostBufferHandle private (
} else {
0
}
// Make sure to only set spilling to false if it was previously set to true
setSpilling(false)
releaseHostResource()
spilled
}
Expand Down Expand Up @@ -507,33 +477,35 @@ class SpillableDeviceBufferHandle private (
materialized
}

private var toSpill: DeviceMemoryBuffer = _
private var toSpill: Option[DeviceMemoryBuffer] = None
override def releaseDeviceResource(): Unit = {
super.releaseDeviceResource()
synchronized {
if (toSpill != null) {
toSpill.close()
toSpill = null
}
toSpill.foreach(_.close())
toSpill = None
}
}

override def spill(): Long = {
if (!spillable || !setSpilling(true)) {
if (!spillable) {
0L
} else {
synchronized {
if (host.isEmpty && dev.isDefined) {
toSpill = dev.get
toSpill.incRefCount()
val thisThreadSpills = synchronized {
if (host.isEmpty && dev.isDefined && toSpill.isEmpty) {
toSpill = dev
toSpill.get.incRefCount()
true
} else {
false
}
}
val spilled = if (toSpill != null) {
withResource(toSpill) { _ =>
if (thisThreadSpills) {
val buf = toSpill.get
withResource(buf) { _ =>
// the spill IO is non-blocking as it won't impact dev or host directly
// instead we "atomically" swap the buffers below once they are ready
val stagingHost =
Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(dev.get))
Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(buf))
synchronized {
host = stagingHost
dev = None
Expand All @@ -543,9 +515,6 @@ class SpillableDeviceBufferHandle private (
} else {
0
}
// Make sure to only set spilling to false if it was previously set to true
setSpilling(false)
spilled
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1102,46 +1102,4 @@ class SpillFrameworkSuite
testBufferFileDeletion(canShareDiskPaths = true)
}

test("device handle cannot spill once marked as spilling by another thread") {
val (ct, _) = buildContiguousTable()
val buff = ct.getBuffer
buff.incRefCount()
withResource(SpillableDeviceBufferHandle(buff)) { handle =>
withResource(ct) { _ =>
assert(!handle.spillable)
}
assert(handle.spillable)

// we're just simulating the another thread coming in and spilling here
// so we don't have to worry about a race
assert(handle.setSpilling(true))
// the "other thread is spilling" so we cannot claim the spill lock
assert(!handle.setSpilling(true))
assertResult(0)(SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes))
assert(handle.setSpilling(false))

// now that nobody else is spilling (but the buffer is still not actually spilled),
// we will succeed
assertResult(512)(SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes))
}
}

test("host handle cannot spill once marked as spilling by another thread") {
withResource(SpillableHostBufferHandle(HostMemoryBuffer.allocate(512))) { hostHandle =>
assert(hostHandle.spillable)

// we're just simulating the another thread coming in and spilling here
// so we don't have to worry about a race
assert(hostHandle.setSpilling(true))
// the "other thread is spilling" so we cannot claim the spill lock
assert(!hostHandle.setSpilling(true))
assertResult(0)(SpillFramework.stores.hostStore.spill(hostHandle.approxSizeInBytes))
assert(hostHandle.setSpilling(false))

// now that nobody else is spilling (but the buffer is still not actually spilled),
// we will succeed
assertResult(512)(SpillFramework.stores.hostStore.spill(hostHandle.approxSizeInBytes))
}
}

}

0 comments on commit d568860

Please sign in to comment.