diff --git a/core/src/main/scala/org/apache/spark/util/TransientLazy.scala b/core/src/main/scala/org/apache/spark/util/BestEffortLazyVal.scala similarity index 50% rename from core/src/main/scala/org/apache/spark/util/TransientLazy.scala rename to core/src/main/scala/org/apache/spark/util/BestEffortLazyVal.scala index 2833ef93669a6..83044055fe404 100644 --- a/core/src/main/scala/org/apache/spark/util/TransientLazy.scala +++ b/core/src/main/scala/org/apache/spark/util/BestEffortLazyVal.scala @@ -16,15 +16,21 @@ */ package org.apache.spark.util +import java.util.concurrent.atomic.AtomicReference + /** - * Construct to lazily initialize a variable. - * This may be helpful for avoiding deadlocks in certain scenarios. For example, - * a) Thread 1 entered a synchronized method, grabbing a coarse lock on the parent object. - * b) Thread 2 gets spawned off, and tries to initialize a lazy value on the same parent object - * (in our case, this was the logger). This causes scala to also try to grab a coarse lock on - * the parent object. - * c) If thread 1 waits for thread 2 to join, a deadlock occurs. - * The main difference between this and [[LazyTry]] is that this does not cache failures. + * A lock-free implementation of a lazily-initialized variable. + * If there are concurrent initializations then the `compute()` function may be invoked + * multiple times. However, only a single `compute()` result will be stored and all readers + * will receive the same result object instance. + * + * This may be helpful for avoiding deadlocks in certain scenarios where exactly-once + * value computation is not a hard requirement. + * + * @note + * This helper class has additional requirements on the compute function: + * 1) The compute function MUST not return null; + * 2) The computation failure is not cached. * * @note * Scala 3 uses a different implementation of lazy vals which doesn't have this problem. @@ -32,12 +38,24 @@ package org.apache.spark.util * href="https://docs.scala-lang.org/scala3/reference/changed-features/lazy-vals-init.html">Lazy * Vals Initialization for more details. */ -private[spark] class TransientLazy[T](initializer: => T) extends Serializable { +private[spark] class BestEffortLazyVal[T <: AnyRef]( + @volatile private[this] var compute: () => T) extends Serializable { - @transient - private[this] lazy val value: T = initializer + private[this] val cached: AtomicReference[T] = new AtomicReference(null.asInstanceOf[T]) def apply(): T = { - value + val value = cached.get() + if (value != null) { + value + } else { + val f = compute + if (f != null) { + val newValue = f() + assert(newValue != null, "compute function cannot return null.") + cached.compareAndSet(null.asInstanceOf[T], newValue) + compute = null // allow closure to be GC'd + } + cached.get() + } } } diff --git a/core/src/main/scala/org/apache/spark/util/TransientBestEffortLazyVal.scala b/core/src/main/scala/org/apache/spark/util/TransientBestEffortLazyVal.scala new file mode 100644 index 0000000000000..033b783ede40b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TransientBestEffortLazyVal.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.atomic.AtomicReference + +/** + * A lock-free implementation of a lazily-initialized variable. + * If there are concurrent initializations then the `compute()` function may be invoked + * multiple times. However, only a single `compute()` result will be stored and all readers + * will receive the same result object instance. + * + * This may be helpful for avoiding deadlocks in certain scenarios where exactly-once + * value computation is not a hard requirement. + * + * The main difference between this and [[BestEffortLazyVal]] is that: + * [[BestEffortLazyVal]] serializes the cached value after computation, while + * [[TransientBestEffortLazyVal]] always serializes the compute function. + * + * @note + * This helper class has additional requirements on the compute function: + * 1) The compute function MUST not return null; + * 2) The computation failure is not cached. + * + * @note + * Scala 3 uses a different implementation of lazy vals which doesn't have this problem. + * Please refer to Lazy + * Vals Initialization for more details. + */ +private[spark] class TransientBestEffortLazyVal[T <: AnyRef]( + private[this] val compute: () => T) extends Serializable { + + @transient + private[this] var cached: AtomicReference[T] = new AtomicReference(null.asInstanceOf[T]) + + def apply(): T = { + val value = cached.get() + if (value != null) { + value + } else { + val newValue = compute() + assert(newValue != null, "compute function cannot return null.") + cached.compareAndSet(null.asInstanceOf[T], newValue) + cached.get() + } + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { + ois.defaultReadObject() + cached = new AtomicReference(null.asInstanceOf[T]) + } +} diff --git a/core/src/test/scala/org/apache/spark/SerializerTestUtils.scala b/core/src/test/scala/org/apache/spark/SerializerTestUtils.scala new file mode 100644 index 0000000000000..bd81003777317 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SerializerTestUtils.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + +trait SerializerTestUtils { + + protected def roundtripSerialize[T](obj: T): T = { + deserializeFromBytes(serializeToBytes(obj)) + } + + protected def serializeToBytes[T](o: T): Array[Byte] = { + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(baos) + try { + oos.writeObject(o) + baos.toByteArray + } finally { + oos.close() + } + } + + protected def deserializeFromBytes[T](bytes: Array[Byte]): T = { + val bais = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bais) + ois.readObject().asInstanceOf[T] + } +} diff --git a/core/src/test/scala/org/apache/spark/util/BestEffortLazyValSuite.scala b/core/src/test/scala/org/apache/spark/util/BestEffortLazyValSuite.scala new file mode 100644 index 0000000000000..a6555eca8b859 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/BestEffortLazyValSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.io.NotSerializableException +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration._ + +import org.apache.spark.{SerializerTestUtils, SparkFunSuite} + +class BestEffortLazyValSuite extends SparkFunSuite with SerializerTestUtils { + + test("BestEffortLazy works") { + val numInitializerCalls = new AtomicInteger(0) + // Simulate a race condition where two threads concurrently + // initialize the lazy value: + val latch = new CountDownLatch(2) + val lazyval = new BestEffortLazyVal(() => { + numInitializerCalls.incrementAndGet() + latch.countDown() + latch.await() + new Object() + }) + + // Ensure no initialization happened before the lazy value was invoked + assert(numInitializerCalls.get() === 0) + + // Two threads concurrently invoke the lazy value + implicit val ec: ExecutionContext = ExecutionContext.global + val future1 = Future { lazyval() } + val future2 = Future { lazyval() } + val value1 = ThreadUtils.awaitResult(future1, 10.seconds) + val value2 = ThreadUtils.awaitResult(future2, 10.seconds) + + // The initializer should have been invoked twice (due to how we set up the + // race condition via the latch): + assert(numInitializerCalls.get() === 2) + + // But the value should only have been computed once: + assert(value1 eq value2) + + // Ensure the subsequent invocation serves the same object + assert(lazyval() eq value1) + assert(numInitializerCalls.get() === 2) + } + + test("BestEffortLazyVal is serializable") { + val lazyval = new BestEffortLazyVal(() => "test") + + // serialize and deserialize before first invocation + val lazyval2 = roundtripSerialize(lazyval) + assert(lazyval2() === "test") + + // first invocation + assert(lazyval() === "test") + + // serialize and deserialize after first invocation + val lazyval3 = roundtripSerialize(lazyval) + assert(lazyval3() === "test") + } + + test("BestEffortLazyVal is serializable: unserializable value") { + val lazyval = new BestEffortLazyVal(() => new Object()) + + // serialize and deserialize before first invocation + val lazyval2 = roundtripSerialize(lazyval) + assert(lazyval2() != null) + + // first invocation + assert(lazyval() != null) + + // serialize and deserialize after first invocation + // try to serialize the cached value and cause NotSerializableException + val e = intercept[NotSerializableException] { + val lazyval3 = roundtripSerialize(lazyval) + } + assert(e.getMessage.contains("java.lang.Object")) + } + + test("BestEffortLazyVal is serializable: initialization failure") { + val lazyval = new BestEffortLazyVal[String](() => throw new RuntimeException("test")) + + // serialize and deserialize before first invocation + val lazyval2 = roundtripSerialize(lazyval) + val e2 = intercept[RuntimeException] { + val v = lazyval2() + } + assert(e2.getMessage.contains("test")) + + // initialization failure + val e = intercept[RuntimeException] { + val v = lazyval() + } + assert(e.getMessage.contains("test")) + + // serialize and deserialize after initialization failure + val lazyval3 = roundtripSerialize(lazyval) + val e3 = intercept[RuntimeException] { + val v = lazyval3() + } + assert(e3.getMessage.contains("test")) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/TransientBestEffortLazyValSuite.scala b/core/src/test/scala/org/apache/spark/util/TransientBestEffortLazyValSuite.scala new file mode 100644 index 0000000000000..3ed9f2958fd9c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/TransientBestEffortLazyValSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration._ + +import org.apache.spark.{SerializerTestUtils, SparkFunSuite} + +class TransientBestEffortLazyValSuite extends SparkFunSuite with SerializerTestUtils { + + test("TransientBestEffortLazyVal works") { + val numInitializerCalls = new AtomicInteger(0) + // Simulate a race condition where two threads concurrently + // initialize the lazy value: + val latch = new CountDownLatch(2) + val lazyval = new TransientBestEffortLazyVal(() => { + numInitializerCalls.incrementAndGet() + latch.countDown() + latch.await() + new Object() + }) + + // Ensure no initialization happened before the lazy value was invoked + assert(numInitializerCalls.get() === 0) + + // Two threads concurrently invoke the lazy value + implicit val ec: ExecutionContext = ExecutionContext.global + val future1 = Future { lazyval() } + val future2 = Future { lazyval() } + val value1 = ThreadUtils.awaitResult(future1, 10.seconds) + val value2 = ThreadUtils.awaitResult(future2, 10.seconds) + + // The initializer should have been invoked twice (due to how we set up the + // race condition via the latch): + assert(numInitializerCalls.get() === 2) + + // But the value should only have been computed once: + assert(value1 eq value2) + + // Ensure the subsequent invocation serves the same object + assert(lazyval() eq value1) + assert(numInitializerCalls.get() === 2) + } + + test("TransientBestEffortLazyVal is serializable") { + val lazyval = new TransientBestEffortLazyVal(() => "test") + + // serialize and deserialize before first invocation + val lazyval2 = roundtripSerialize(lazyval) + assert(lazyval2() === "test") + + // first invocation + assert(lazyval() === "test") + + // serialize and deserialize after first invocation + val lazyval3 = roundtripSerialize(lazyval) + assert(lazyval3() === "test") + } + + test("TransientBestEffortLazyVal is serializable: unserializable value") { + val lazyval = new TransientBestEffortLazyVal(() => new Object()) + + // serialize and deserialize before first invocation + val lazyval2 = roundtripSerialize(lazyval) + assert(lazyval2() != null) + + // first invocation + assert(lazyval() != null) + + // serialize and deserialize after first invocation + val lazyval3 = roundtripSerialize(lazyval) + assert(lazyval3() != null) + } + + test("TransientBestEffortLazyVal is serializable: failure in compute function") { + val lazyval = new TransientBestEffortLazyVal[String](() => throw new RuntimeException("test")) + + // serialize and deserialize before first invocation + val lazyval2 = roundtripSerialize(lazyval) + val e2 = intercept[RuntimeException] { + val v = lazyval2() + } + assert(e2.getMessage.contains("test")) + + // initialization failure + val e = intercept[RuntimeException] { + val v = lazyval() + } + assert(e.getMessage.contains("test")) + + // serialize and deserialize after initialization failure + val lazyval3 = roundtripSerialize(lazyval) + val e3 = intercept[RuntimeException] { + val v = lazyval3() + } + assert(e3.getMessage.contains("test")) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/TransientLazySuite.scala b/core/src/test/scala/org/apache/spark/util/TransientLazySuite.scala deleted file mode 100644 index c0754ee063d67..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/TransientLazySuite.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.util - -import java.io.{ByteArrayOutputStream, ObjectOutputStream} - -import org.apache.spark.SparkFunSuite - -class TransientLazySuite extends SparkFunSuite { - - test("TransientLazy val works") { - var test: Option[Object] = None - - val lazyval = new TransientLazy({ - test = Some(new Object()) - test - }) - - // Ensure no initialization happened before the lazy value was dereferenced - assert(test.isEmpty) - - // Ensure the first invocation creates a new object - assert(lazyval() == test && test.isDefined) - - // Ensure the subsequent invocation serves the same object - assert(lazyval() == test && test.isDefined) - } - - test("TransientLazy val is serializable") { - val lazyval = new TransientLazy({ - new Object() - }) - - // Ensure serializable before the dereference - val oos = new ObjectOutputStream(new ByteArrayOutputStream()) - oos.writeObject(lazyval) - - val dereferenced = lazyval() - - // Ensure serializable after the dereference - val oos2 = new ObjectOutputStream(new ByteArrayOutputStream()) - oos2.writeObject(lazyval) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 40244595da57f..9052f6228a9d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans +import java.lang.{Boolean => JBoolean} import java.util.IdentityHashMap import scala.collection.mutable @@ -32,7 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePatternBits import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.TransientLazy +import org.apache.spark.util.{BestEffortLazyVal, TransientBestEffortLazyVal} import org.apache.spark.util.collection.BitSet /** @@ -54,8 +55,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] /** * Returns the set of attributes that are output by this node. */ - @transient - lazy val outputSet: AttributeSet = AttributeSet(output) + def outputSet: AttributeSet = _outputSet() + + private val _outputSet = new TransientBestEffortLazyVal(() => AttributeSet(output)) /** * Returns the output ordering that this plan generates, although the semantics differ in logical @@ -97,16 +99,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] */ def references: AttributeSet = _references() - private val _references = new TransientLazy({ - AttributeSet(expressions) -- producedAttributes - }) + private val _references = new TransientBestEffortLazyVal(() => + AttributeSet(expressions) -- producedAttributes) /** * Returns true when the all the expressions in the current node as well as all of its children * are deterministic */ - lazy val deterministic: Boolean = expressions.forall(_.deterministic) && - children.forall(_.deterministic) + def deterministic: Boolean = _deterministic() + + private val _deterministic = new BestEffortLazyVal[JBoolean](() => + expressions.forall(_.deterministic) && children.forall(_.deterministic)) /** * Attributes that are referenced by expressions but not provided by this node's children. @@ -427,7 +430,10 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] } } - lazy val schema: StructType = DataTypeUtils.fromAttributes(output) + def schema: StructType = _schema() + + private val _schema = new BestEffortLazyVal[StructType](() => + DataTypeUtils.fromAttributes(output)) /** Returns the output schema in the tree format. */ def schemaString: String = schema.treeString @@ -480,11 +486,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] /** * All the top-level subqueries of the current plan node. Nested subqueries are not included. */ - @transient lazy val subqueries: Seq[PlanType] = { + def subqueries: Seq[PlanType] = _subqueries() + + private val _subqueries = new TransientBestEffortLazyVal(() => expressions.filter(_.containsPattern(PLAN_EXPRESSION)).flatMap(_.collect { case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType] }) - } + ) /** * All the subqueries of the current plan node and all its children. Nested subqueries are also @@ -620,7 +628,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] * Plan nodes that require special canonicalization should override [[doCanonicalize()]]. * They should remove expressions cosmetic variations themselves. */ - @transient final lazy val canonicalized: PlanType = { + def canonicalized: PlanType = _canonicalized() + + private val _canonicalized = new TransientBestEffortLazyVal(() => { var plan = doCanonicalize() // If the plan has not been changed due to canonicalization, make a copy of it so we don't // mutate the original plan's _isCanonicalizedPlan flag. @@ -629,7 +639,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] } plan._isCanonicalizedPlan = true plan - } + }) /** * Defines how the canonicalization should work for the current plan.