[SPARK-50705][SQL] Make `QueryPlan` lock-free
### What changes were proposed in this pull request?
Replace a group of `lazy val`s in `QueryPlan` with new lock-free helper classes. Not all `lazy val`s are replaced in this PR; the remaining ones need to be handled together with `QueryPlan`'s subclasses for the change to take full effect.

### Why are the changes needed?
To address deadlock issues on query plan nodes:
- Sometimes we want the plan node methods to use a coarse lock (just lock the plan node itself), because these methods (`expressions`, `output`, `references`, `deterministic`, `schema`, `canonicalized`, etc.) may call each other, so a coarse lock prevents deadlocks among them.
- Sometimes we want these methods to use fine-grained locks, because they may also call the corresponding methods of a parent or child plan node. If two threads traverse the tree in different directions at the same time, a coarse lock is likely to deadlock.

The only approach that satisfies both cases is to avoid locks altogether where possible.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added unit tests. Manually tested against a deadlock case.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49212 from zhengruifeng/query_plan_atom_refs.

Lead-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Josh Rosen <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
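To make the intent concrete, here is a minimal sketch of the kind of replacement described above, using the `TransientBestEffortLazyVal` helper added in this commit. The surrounding `ExamplePlanNode` class and its `output` field are purely illustrative; the actual changes happen inside `QueryPlan` and are not shown in the visible portion of this diff.

```scala
// Any package under org.apache.spark works, since the helper is private[spark].
package org.apache.spark.example

import org.apache.spark.util.TransientBestEffortLazyVal

// Hypothetical plan node, for illustration only.
class ExamplePlanNode(children: Seq[ExamplePlanNode]) extends Serializable {

  // Before: a `lazy val`, whose initialization holds this node's monitor and
  // can deadlock when two threads traverse the tree in opposite directions.
  //
  //   lazy val output: Seq[String] = children.flatMap(_.output)

  // After: a lock-free holder. Under a race the compute function may run more
  // than once, but no monitor is held while it runs and every caller observes
  // the same cached instance.
  private val _output = new TransientBestEffortLazyVal[Seq[String]](
    () => children.flatMap(_.output))

  def output: Seq[String] = _output()
}
```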
Showing 7 changed files with 402 additions and 83 deletions.
core/src/main/scala/org/apache/spark/util/TransientBestEffortLazyVal.scala (69 additions, 0 deletions)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.util

import java.io.{IOException, ObjectInputStream}
import java.util.concurrent.atomic.AtomicReference

/**
 * A lock-free implementation of a lazily-initialized variable.
 * If there are concurrent initializations then the `compute()` function may be invoked
 * multiple times. However, only a single `compute()` result will be stored and all readers
 * will receive the same result object instance.
 *
 * This may be helpful for avoiding deadlocks in certain scenarios where exactly-once
 * value computation is not a hard requirement.
 *
 * The main difference between this and [[BestEffortLazyVal]] is that:
 * [[BestEffortLazyVal]] serializes the cached value after computation, while
 * [[TransientBestEffortLazyVal]] always serializes the compute function.
 *
 * @note
 *   This helper class has additional requirements on the compute function:
 *   1) The compute function MUST not return null;
 *   2) The computation failure is not cached.
 *
 * @note
 *   Scala 3 uses a different implementation of lazy vals which doesn't have this problem.
 *   Please refer to <a
 *   href="https://docs.scala-lang.org/scala3/reference/changed-features/lazy-vals-init.html">Lazy
 *   Vals Initialization</a> for more details.
 */
private[spark] class TransientBestEffortLazyVal[T <: AnyRef](
    private[this] val compute: () => T) extends Serializable {

  @transient
  private[this] var cached: AtomicReference[T] = new AtomicReference(null.asInstanceOf[T])

  def apply(): T = {
    val value = cached.get()
    if (value != null) {
      value
    } else {
      val newValue = compute()
      assert(newValue != null, "compute function cannot return null.")
      cached.compareAndSet(null.asInstanceOf[T], newValue)
      cached.get()
    }
  }

  @throws(classOf[IOException])
  private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
    ois.defaultReadObject()
    cached = new AtomicReference(null.asInstanceOf[T])
  }
}
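As a quick illustration of the serialization behavior documented above (this snippet is not part of the commit; the holder class is hypothetical): because `cached` is `@transient` and `readObject` resets it, a deserialized copy starts with an empty cache and recomputes the value from the serialized compute function on first use.

```scala
package org.apache.spark.example  // any package under org.apache.spark works

import org.apache.spark.util.TransientBestEffortLazyVal

// Hypothetical holder demonstrating the transient-caching behavior.
class SchemaHolder(fields: Seq[String]) extends Serializable {

  private val schemaString = new TransientBestEffortLazyVal[String](
    () => fields.mkString("struct<", ",", ">"))

  // First call runs the compute function; later calls return the cached string.
  // After a Java serialization round trip, the cache starts out empty again,
  // so the first call on the copy recomputes the string rather than reading a
  // serialized cached value.
  def schema: String = schemaString()
}
```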
core/src/test/scala/org/apache/spark/SerializerTestUtils.scala (44 additions, 0 deletions)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

trait SerializerTestUtils {

  protected def roundtripSerialize[T](obj: T): T = {
    deserializeFromBytes(serializeToBytes(obj))
  }

  protected def serializeToBytes[T](o: T): Array[Byte] = {
    val baos = new ByteArrayOutputStream
    val oos = new ObjectOutputStream(baos)
    try {
      oos.writeObject(o)
      baos.toByteArray
    } finally {
      oos.close()
    }
  }

  protected def deserializeFromBytes[T](bytes: Array[Byte]): T = {
    val bais = new ByteArrayInputStream(bytes)
    val ois = new ObjectInputStream(bais)
    ois.readObject().asInstanceOf[T]
  }
}
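A brief usage note (illustrative; not part of the commit): mixing the trait into a suite or object gives a one-call Java-serialization round trip, which is how the suite below exercises the lazy-val holders.

```scala
import org.apache.spark.SerializerTestUtils

// Hypothetical example of the round-trip helper.
object RoundTripExample extends SerializerTestUtils {
  def main(args: Array[String]): Unit = {
    // Serialize to bytes and deserialize back in one call.
    val copy = roundtripSerialize(List("a", "b", "c"))
    assert(copy == List("a", "b", "c"))
  }
}
```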
core/src/test/scala/org/apache/spark/util/BestEffortLazyValSuite.scala (120 additions, 0 deletions)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.util

import java.io.NotSerializableException
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._

import org.apache.spark.{SerializerTestUtils, SparkFunSuite}

class BestEffortLazyValSuite extends SparkFunSuite with SerializerTestUtils {

  test("BestEffortLazy works") {
    val numInitializerCalls = new AtomicInteger(0)
    // Simulate a race condition where two threads concurrently
    // initialize the lazy value:
    val latch = new CountDownLatch(2)
    val lazyval = new BestEffortLazyVal(() => {
      numInitializerCalls.incrementAndGet()
      latch.countDown()
      latch.await()
      new Object()
    })

    // Ensure no initialization happened before the lazy value was invoked
    assert(numInitializerCalls.get() === 0)

    // Two threads concurrently invoke the lazy value
    implicit val ec: ExecutionContext = ExecutionContext.global
    val future1 = Future { lazyval() }
    val future2 = Future { lazyval() }
    val value1 = ThreadUtils.awaitResult(future1, 10.seconds)
    val value2 = ThreadUtils.awaitResult(future2, 10.seconds)

    // The initializer should have been invoked twice (due to how we set up the
    // race condition via the latch):
    assert(numInitializerCalls.get() === 2)

    // But the value should only have been computed once:
    assert(value1 eq value2)

    // Ensure the subsequent invocation serves the same object
    assert(lazyval() eq value1)
    assert(numInitializerCalls.get() === 2)
  }

  test("BestEffortLazyVal is serializable") {
    val lazyval = new BestEffortLazyVal(() => "test")

    // serialize and deserialize before first invocation
    val lazyval2 = roundtripSerialize(lazyval)
    assert(lazyval2() === "test")

    // first invocation
    assert(lazyval() === "test")

    // serialize and deserialize after first invocation
    val lazyval3 = roundtripSerialize(lazyval)
    assert(lazyval3() === "test")
  }

  test("BestEffortLazyVal is serializable: unserializable value") {
    val lazyval = new BestEffortLazyVal(() => new Object())

    // serialize and deserialize before first invocation
    val lazyval2 = roundtripSerialize(lazyval)
    assert(lazyval2() != null)

    // first invocation
    assert(lazyval() != null)

    // serialize and deserialize after first invocation
    // try to serialize the cached value and cause NotSerializableException
    val e = intercept[NotSerializableException] {
      val lazyval3 = roundtripSerialize(lazyval)
    }
    assert(e.getMessage.contains("java.lang.Object"))
  }

  test("BestEffortLazyVal is serializable: initialization failure") {
    val lazyval = new BestEffortLazyVal[String](() => throw new RuntimeException("test"))

    // serialize and deserialize before first invocation
    val lazyval2 = roundtripSerialize(lazyval)
    val e2 = intercept[RuntimeException] {
      val v = lazyval2()
    }
    assert(e2.getMessage.contains("test"))

    // initialization failure
    val e = intercept[RuntimeException] {
      val v = lazyval()
    }
    assert(e.getMessage.contains("test"))

    // serialize and deserialize after initialization failure
    val lazyval3 = roundtripSerialize(lazyval)
    val e3 = intercept[RuntimeException] {
      val v = lazyval3()
    }
    assert(e3.getMessage.contains("test"))
  }
}
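The suite above targets `BestEffortLazyVal`, whose source is part of this commit but not shown in the visible portion of the diff. Based on the contract documented in `TransientBestEffortLazyVal` (the cached value, rather than only the compute function, is serialized once computed) and the behaviors the tests check, a sketch could look like the following. This is an assumption for illustration, not the actual Spark implementation; the class name, fields, and structure are guesses.

```scala
package org.apache.spark.util

import java.util.concurrent.atomic.AtomicReference

// Sketch only: a lock-free lazy holder whose cached value survives serialization.
private[spark] class BestEffortLazyValSketch[T <: AnyRef](
    private[this] val compute: () => T) extends Serializable {

  // Not @transient: once computed, the value itself is serialized, which is why
  // caching an unserializable value makes a later round trip throw
  // NotSerializableException, as the suite above expects.
  private[this] val cached: AtomicReference[T] = new AtomicReference(null.asInstanceOf[T])

  def apply(): T = {
    val value = cached.get()
    if (value != null) {
      value
    } else {
      // A race may run compute() more than once, but compareAndSet stores only
      // the first result, so every caller observes the same instance. A failure
      // here is simply rethrown and nothing is cached.
      val newValue = compute()
      assert(newValue != null, "compute function cannot return null.")
      cached.compareAndSet(null.asInstanceOf[T], newValue)
      cached.get()
    }
  }
}
```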