From 4e3624e683309815b1d3187e9d60064ea61b5354 Mon Sep 17 00:00:00 2001 From: Zach Puller Date: Tue, 17 Dec 2024 10:01:59 -0800 Subject: [PATCH] atomic bool Signed-off-by: Zach Puller --- .../spark/rapids/spill/SpillFramework.scala | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 8cc27686efa..d4fdf2be3f9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -23,6 +23,7 @@ import java.nio.file.StandardOpenOption import java.util import java.util.UUID import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable @@ -177,8 +178,7 @@ trait SpillableHandle extends StoreHandle { */ def spill(): Long - private var spilling = false - private val spillLock = new Object() + private val spilling = new AtomicBoolean(false) /** * Method used to atomically check and set the spilling state, so that anyone who wants to @@ -191,20 +191,14 @@ trait SpillableHandle extends StoreHandle { * @param s whether the caller is trying to spill or not (ie finished) * @return whether the caller is allowed to spill (or true if s is false) */ - def setSpilling(s: Boolean): Boolean = spillLock.synchronized { + def setSpilling(s: Boolean): Boolean = { if (!s) { - // done spilling, nothing to check - spilling = false + if (!spilling.getAndSet(false)) { + throw new IllegalStateException("tried to setSpilling to false while not spilling!") + } true } else { - if (!spilling) { - // we may spill - spilling = true - true - } else { - // someone else is already spilling - false - } + !spilling.getAndSet(true) } } @@ -217,9 +211,7 @@ trait SpillableHandle extends StoreHandle { */ private[spill] def spillable: Boolean = { if (approxSizeInBytes > 0) { - spillLock.synchronized { - !spilling - } + !spilling.get() } else { false }