Skip to content

Commit

Permalink
[SPARK-50705][SQL] Make QueryPlan lock-free
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Replace a group of `lazy val` in `QueryPlan` with new lock-free helper classes.
Not all `lazy val`s are replaced in this PR, we will need to handle remaining `lazy val`s together with its subclasses to make it take effect.

### Why are the changes needed?
for the deadlock issues on query plan nodes:

- sometimes we want the plan node methods to use a coarse lock (just lock the plan node itself), as these methods (expressions , output , references , deterministic , schema , canonicalized , etc.) may call each other, so using a coarse lock can prevent deadlocks.
- sometimes we want these methods to use fine-grained locks, because these methods may call each other of a parent/child plan node. If you traverse the tree with different directions at the same time, it's likely to hit deadlock using coarse lock.

the only solution is to not use locks if possible

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added unit tests
Manually test against a deadlock case

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49212 from zhengruifeng/query_plan_atom_refs.

Lead-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Josh Rosen <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
2 people authored and cloud-fan committed Jan 9, 2025
1 parent 0123a5e commit f8e8fcc
Show file tree
Hide file tree
Showing 7 changed files with 402 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,46 @@
*/
package org.apache.spark.util

import java.util.concurrent.atomic.AtomicReference

/**
* Construct to lazily initialize a variable.
* This may be helpful for avoiding deadlocks in certain scenarios. For example,
* a) Thread 1 entered a synchronized method, grabbing a coarse lock on the parent object.
* b) Thread 2 gets spawned off, and tries to initialize a lazy value on the same parent object
* (in our case, this was the logger). This causes scala to also try to grab a coarse lock on
* the parent object.
* c) If thread 1 waits for thread 2 to join, a deadlock occurs.
* The main difference between this and [[LazyTry]] is that this does not cache failures.
* A lock-free implementation of a lazily-initialized variable.
* If there are concurrent initializations then the `compute()` function may be invoked
* multiple times. However, only a single `compute()` result will be stored and all readers
* will receive the same result object instance.
*
* This may be helpful for avoiding deadlocks in certain scenarios where exactly-once
* value computation is not a hard requirement.
*
* @note
* This helper class has additional requirements on the compute function:
* 1) The compute function MUST not return null;
* 2) The computation failure is not cached.
*
* @note
* Scala 3 uses a different implementation of lazy vals which doesn't have this problem.
* Please refer to <a
* href="https://docs.scala-lang.org/scala3/reference/changed-features/lazy-vals-init.html">Lazy
* Vals Initialization</a> for more details.
*/
private[spark] class TransientLazy[T](initializer: => T) extends Serializable {
private[spark] class BestEffortLazyVal[T <: AnyRef](
@volatile private[this] var compute: () => T) extends Serializable {

@transient
private[this] lazy val value: T = initializer
private[this] val cached: AtomicReference[T] = new AtomicReference(null.asInstanceOf[T])

def apply(): T = {
value
val value = cached.get()
if (value != null) {
value
} else {
val f = compute
if (f != null) {
val newValue = f()
assert(newValue != null, "compute function cannot return null.")
cached.compareAndSet(null.asInstanceOf[T], newValue)
compute = null // allow closure to be GC'd
}
cached.get()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util

import java.io.{IOException, ObjectInputStream}
import java.util.concurrent.atomic.AtomicReference

/**
* A lock-free implementation of a lazily-initialized variable.
* If there are concurrent initializations then the `compute()` function may be invoked
* multiple times. However, only a single `compute()` result will be stored and all readers
* will receive the same result object instance.
*
* This may be helpful for avoiding deadlocks in certain scenarios where exactly-once
* value computation is not a hard requirement.
*
* The main difference between this and [[BestEffortLazyVal]] is that:
* [[BestEffortLazyVal]] serializes the cached value after computation, while
* [[TransientBestEffortLazyVal]] always serializes the compute function.
*
* @note
* This helper class has additional requirements on the compute function:
* 1) The compute function MUST not return null;
* 2) The computation failure is not cached.
*
* @note
* Scala 3 uses a different implementation of lazy vals which doesn't have this problem.
* Please refer to <a
* href="https://docs.scala-lang.org/scala3/reference/changed-features/lazy-vals-init.html">Lazy
* Vals Initialization</a> for more details.
*/
private[spark] class TransientBestEffortLazyVal[T <: AnyRef](
private[this] val compute: () => T) extends Serializable {

@transient
private[this] var cached: AtomicReference[T] = new AtomicReference(null.asInstanceOf[T])

def apply(): T = {
val value = cached.get()
if (value != null) {
value
} else {
val newValue = compute()
assert(newValue != null, "compute function cannot return null.")
cached.compareAndSet(null.asInstanceOf[T], newValue)
cached.get()
}
}

@throws(classOf[IOException])
private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
ois.defaultReadObject()
cached = new AtomicReference(null.asInstanceOf[T])
}
}
44 changes: 44 additions & 0 deletions core/src/test/scala/org/apache/spark/SerializerTestUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

trait SerializerTestUtils {

protected def roundtripSerialize[T](obj: T): T = {
deserializeFromBytes(serializeToBytes(obj))
}

protected def serializeToBytes[T](o: T): Array[Byte] = {
val baos = new ByteArrayOutputStream
val oos = new ObjectOutputStream(baos)
try {
oos.writeObject(o)
baos.toByteArray
} finally {
oos.close()
}
}

protected def deserializeFromBytes[T](bytes: Array[Byte]): T = {
val bais = new ByteArrayInputStream(bytes)
val ois = new ObjectInputStream(bais)
ois.readObject().asInstanceOf[T]
}
}
120 changes: 120 additions & 0 deletions core/src/test/scala/org/apache/spark/util/BestEffortLazyValSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util

import java.io.NotSerializableException
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._

import org.apache.spark.{SerializerTestUtils, SparkFunSuite}

class BestEffortLazyValSuite extends SparkFunSuite with SerializerTestUtils {

test("BestEffortLazy works") {
val numInitializerCalls = new AtomicInteger(0)
// Simulate a race condition where two threads concurrently
// initialize the lazy value:
val latch = new CountDownLatch(2)
val lazyval = new BestEffortLazyVal(() => {
numInitializerCalls.incrementAndGet()
latch.countDown()
latch.await()
new Object()
})

// Ensure no initialization happened before the lazy value was invoked
assert(numInitializerCalls.get() === 0)

// Two threads concurrently invoke the lazy value
implicit val ec: ExecutionContext = ExecutionContext.global
val future1 = Future { lazyval() }
val future2 = Future { lazyval() }
val value1 = ThreadUtils.awaitResult(future1, 10.seconds)
val value2 = ThreadUtils.awaitResult(future2, 10.seconds)

// The initializer should have been invoked twice (due to how we set up the
// race condition via the latch):
assert(numInitializerCalls.get() === 2)

// But the value should only have been computed once:
assert(value1 eq value2)

// Ensure the subsequent invocation serves the same object
assert(lazyval() eq value1)
assert(numInitializerCalls.get() === 2)
}

test("BestEffortLazyVal is serializable") {
val lazyval = new BestEffortLazyVal(() => "test")

// serialize and deserialize before first invocation
val lazyval2 = roundtripSerialize(lazyval)
assert(lazyval2() === "test")

// first invocation
assert(lazyval() === "test")

// serialize and deserialize after first invocation
val lazyval3 = roundtripSerialize(lazyval)
assert(lazyval3() === "test")
}

test("BestEffortLazyVal is serializable: unserializable value") {
val lazyval = new BestEffortLazyVal(() => new Object())

// serialize and deserialize before first invocation
val lazyval2 = roundtripSerialize(lazyval)
assert(lazyval2() != null)

// first invocation
assert(lazyval() != null)

// serialize and deserialize after first invocation
// try to serialize the cached value and cause NotSerializableException
val e = intercept[NotSerializableException] {
val lazyval3 = roundtripSerialize(lazyval)
}
assert(e.getMessage.contains("java.lang.Object"))
}

test("BestEffortLazyVal is serializable: initialization failure") {
val lazyval = new BestEffortLazyVal[String](() => throw new RuntimeException("test"))

// serialize and deserialize before first invocation
val lazyval2 = roundtripSerialize(lazyval)
val e2 = intercept[RuntimeException] {
val v = lazyval2()
}
assert(e2.getMessage.contains("test"))

// initialization failure
val e = intercept[RuntimeException] {
val v = lazyval()
}
assert(e.getMessage.contains("test"))

// serialize and deserialize after initialization failure
val lazyval3 = roundtripSerialize(lazyval)
val e3 = intercept[RuntimeException] {
val v = lazyval3()
}
assert(e3.getMessage.contains("test"))
}
}
Loading

0 comments on commit f8e8fcc

Please sign in to comment.