diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 901bdae9a05..420af625486 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -378,6 +378,7 @@ class SpillableHostBufferHandle private ( 0L } } + // Make sure to only set spilling to false if it was previously set to true setSpilling(false) releaseHostResource() spilled @@ -499,6 +500,7 @@ class SpillableDeviceBufferHandle private ( } } } + // Make sure to only set spilling to false if it was previously set to true setSpilling(false) spilled } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala index bfd58e33408..332cafa7abf 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala @@ -1102,7 +1102,7 @@ class SpillFrameworkSuite testBufferFileDeletion(canShareDiskPaths = true) } - test("handle cannot spill once marked as spilling by another thread") { + test("device handle cannot spill once marked as spilling by another thread") { val (ct, _) = buildContiguousTable() val buff = ct.getBuffer buff.incRefCount() @@ -1126,4 +1126,22 @@ class SpillFrameworkSuite } } + test("host handle cannot spill once marked as spilling by another thread") { + withResource(SpillableHostBufferHandle(HostMemoryBuffer.allocate(512))) { hostHandle => + assert(hostHandle.spillable) + + // we're just simulating the another thread coming in and spilling here + // so we don't have to worry about a race + assert(hostHandle.setSpilling(true)) + // the "other thread is spilling" so we cannot claim the spill lock + assert(!hostHandle.setSpilling(true)) + assertResult(0)(SpillFramework.stores.hostStore.spill(hostHandle.approxSizeInBytes)) + assert(hostHandle.setSpilling(false)) + + // now that nobody else is spilling (but the buffer is still not actually spilled), + // we will succeed + assertResult(512)(SpillFramework.stores.hostStore.spill(hostHandle.approxSizeInBytes)) + } + } + }