From 0c74d384096cc24e6532bff44679c45af645de38 Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Wed, 5 Jun 2024 18:28:38 +0800 Subject: [PATCH] works, ugly Signed-off-by: Hongbin Ma (Mahone) refine code, broken Signed-off-by: Hongbin Ma (Mahone) fix Signed-off-by: Hongbin Ma (Mahone) complete version 1 Signed-off-by: Hongbin Ma (Mahone) Introduce lore id Introduce lore id Fix type Fix type with loreid Signed-off-by: Hongbin Ma (Mahone) clean ut Signed-off-by: Hongbin Ma (Mahone) refine log Signed-off-by: Hongbin Ma (Mahone) clean Signed-off-by: Hongbin Ma (Mahone) fix ut Signed-off-by: Hongbin Ma (Mahone) fix idgen Signed-off-by: Hongbin Ma (Mahone) --- .../com/nvidia/spark/rapids/DumpUtils.scala | 19 ++- .../com/nvidia/spark/rapids/GpuExec.scala | 118 +++++++++++++++++- .../nvidia/spark/rapids/GpuOverrides.scala | 8 +- .../spark/rapids/GpuTransitionOverrides.scala | 5 + .../com/nvidia/spark/rapids/RapidsConf.scala | 31 +++++ .../com/nvidia/spark/rapids/lore/IdGen.scala | 61 +++++++++ .../rapids/profiling/ReplayDumpRDD.scala | 75 +++++++++++ .../rapids/shims/Spark320PlusShims.scala | 3 +- .../rapids/profiling/DumpedExecReplayer.scala | 77 ++++++++++++ 9 files changed, 391 insertions(+), 6 deletions(-) create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/IdGen.scala create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/profiling/ReplayDumpRDD.scala create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/profiling/DumpedExecReplayer.scala diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala index bf949897c78..ba19619633b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala @@ -15,11 +15,12 @@ */ package com.nvidia.spark.rapids -import java.io.{File, FileOutputStream} +import java.io.{ByteArrayOutputStream, File, FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream} import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag import ai.rapids.cudf._ import ai.rapids.cudf.ColumnWriterOptions._ @@ -99,6 +100,22 @@ object DumpUtils extends Logging { } } + def deserializeObject[T: ClassTag](readPath: String): T = { + val fileIn: FileInputStream = new FileInputStream(readPath) + val in: ObjectInputStream = new ObjectInputStream(fileIn) + val ret = in.readObject().asInstanceOf[T] + in.close() + ret + } + + def serializeObject(obj: Any): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val oos = new ObjectOutputStream(bos) + oos.writeObject(obj) + oos.close() + bos.toByteArray + } + private def dumpToParquetFileImp(table: Table, filePrefix: String): String = { val path = genPath(filePrefix) withResource(new ParquetDumper(path, table)) { dumper => diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala index ec87dd62d6c..a3fc1bff0c9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala @@ -18,20 +18,26 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.NvtxColor import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.DumpUtils.serializeObject import com.nvidia.spark.rapids.filecache.FileCacheConf +import com.nvidia.spark.rapids.lore.IdGen.loreIdOf +import com.nvidia.spark.rapids.profiling.ReplayDumpRDD import com.nvidia.spark.rapids.shims.SparkShimImpl +import org.apache.hadoop.fs.{FSDataOutputStream, Path} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.rapids.LocationPreservingMapPartitionsRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, ExprId} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, UnaryLike} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.rapids.GpuTaskMetrics import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration sealed class MetricsLevel(val num: Integer) extends Serializable { def >=(other: MetricsLevel): Boolean = @@ -212,6 +218,27 @@ object GpuExec { } trait GpuExec extends SparkPlan { + + // For LORE replay + @transient var loreReplayInputDir: String = null // null is better than None considering ser/der + @transient var loreIsReplayingOperator: Boolean = false + // For LORE dump + @transient var shouldDumpOutput: Boolean = false + @transient var dumpForLOREId: String = "" + @transient lazy val loreDumpOperator: Option[String] = RapidsConf.LORE_DUMP_OPERATOR.get(conf) + @transient lazy val loreDumpLOREIds: String = RapidsConf.LORE_DUMP_LORE_IDS.get(conf) + @transient lazy val loreDumpPartitions: String = RapidsConf.LORE_DUMP_PARTITIONS.get(conf) + + // For LORE DumpedExecReplayer, the spark plan is deserialized from the plan.meta file, so + // some of the transient fields will be null, and we need to workaround this + override protected def sparkContext = SparkSession.getActiveSession.get.sparkContext + override protected def waitForSubqueries(): Unit = synchronized { + // only do it when it's not doing LORE replaying + if (!loreIsReplayingOperator && loreReplayInputDir == null) { + super.waitForSubqueries() + } + } + import GpuMetric._ def sparkSession: SparkSession = { SparkShimImpl.sessionFromPlan(this) @@ -363,16 +390,101 @@ trait GpuExec extends SparkPlan { this.getTagValue(GpuExec.TASK_METRICS_TAG) final override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val orig = internalDoExecuteColumnar() + val hadoopConf = new SerializableConfiguration(sparkSession.sparkContext.hadoopConfiguration) + def getOutputStream(filePath: String): FSDataOutputStream = { + val hadoopPath = new Path(filePath) + val fs = hadoopPath.getFileSystem(hadoopConf.value) + fs.create(hadoopPath, true) + } + + if (loreReplayInputDir != null) { + return new ReplayDumpRDD(sparkSession, loreReplayInputDir) + } + val className = this.getClass.getSimpleName + val myLoreId = loreIdOf(this).getOrElse("unknown") + if (loreDumpOperator.exists(o => o.equals(className)) || + loreDumpLOREIds.split(',').contains(myLoreId) + ) { + val childAsGpuExec = this.asInstanceOf[UnaryLike[SparkPlan]].child.asInstanceOf[GpuExec] + childAsGpuExec.shouldDumpOutput = true + childAsGpuExec.dumpForLOREId = myLoreId + val childPlanId = childAsGpuExec.id + // dump plan node + val planBytes = serializeObject(this) + val fos = getOutputStream( + s"file:/tmp/lore/lore_id=${myLoreId}_plan_id=${childPlanId}/plan.meta") + fos.write(planBytes) + fos.close() + } + val shouldDumpOutputToBroadcast = shouldDumpOutput + val dumpForLOREIdToBroadcast = dumpForLOREId + val loreDumpPartitionsToBroadcast = loreDumpPartitions + + var orig: RDD[ColumnarBatch] = null + orig = internalDoExecuteColumnar() + + val planId = id val metrics = getTaskMetrics + metrics.map { gpuMetrics => // This is ugly, but it reduces the need to change all exec nodes, so we are doing it here LocationPreservingMapPartitionsRDD(orig) { iter => gpuMetrics.makeSureRegistered() - iter + + var batchId = 0 + iter.map(cb => { + + val tc = TaskContext.get() + if (shouldDumpOutputToBroadcast && + loreDumpPartitionsToBroadcast.split(',').map(_.toInt).contains(tc.partitionId())) { + + println(s"LORE dump activated, for the operator to dump output: " + + s"className: ${className}, " + + s"stage id: ${tc.stageId()}, " + + s"for LORE id: ${dumpForLOREIdToBroadcast}, " + + s"plan id: ${planId}, " + + s"partition id: ${tc.partitionId()}, " + + s"batch id: ${batchId}, " + + s"batch size: ${cb.numRows()} rows.") + + val partitionId = TaskContext.get().partitionId() + + { + // dump col types for column batch to remote storage + val cbTypes = GpuColumnVector.extractTypes(cb) + val bytes = serializeObject(cbTypes) + val fos = getOutputStream( + s"file:/tmp/lore/lore_id=${dumpForLOREIdToBroadcast}_plan_id=${planId}/" + + s"partition_id=${partitionId}/" + + s"batch_id=${batchId}/col_types.meta") + fos.write(bytes) + fos.close() + } + + // dump data for column batch to /tmp dir + withResource(GpuColumnVector.from(cb)) { table => + val path = s"/tmp/lore/lore_id=${dumpForLOREIdToBroadcast}_plan_id=${planId}/" + + s"partition_id=${partitionId}/" + + s"batch_id=${batchId}/cb_data.parquet" + withResource(new ParquetDumper(path, table)) { dumper => + dumper.writeTable(table) + path + } + } + } + batchId = batchId + 1 + cb + }) } }.getOrElse(orig) } + override def nodeName: String = { + loreIdOf(this) match { + case Some(loreId) => s"${super.nodeName} [LOREID=$loreId]" + case None => s"${super.nodeName} [LOREID=unknown]" + } + } + protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 295480d24cc..491271e33ea 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -25,6 +25,7 @@ import scala.util.control.NonFatal import ai.rapids.cudf.DType import com.nvidia.spark.rapids.RapidsConf.{SUPPRESS_PLANNING_FAILURE, TEST_CONF} import com.nvidia.spark.rapids.jni.GpuTimeZoneDB +import com.nvidia.spark.rapids.lore.IdGen import com.nvidia.spark.rapids.shims._ import com.nvidia.spark.rapids.window.{GpuDenseRank, GpuLag, GpuLead, GpuPercentRank, GpuRank, GpuRowNumber, GpuSpecialFrameBoundary, GpuWindowExecMeta, GpuWindowSpecDefinitionMeta} import org.apache.hadoop.fs.Path @@ -4707,7 +4708,12 @@ case class GpuOverrides() extends Rule[SparkPlan] with Logging { } } } - GpuOverrides.doConvertPlan(wrap, conf, optimizations) + val convertedPlan = GpuOverrides.doConvertPlan(wrap, conf, optimizations) + if (conf.get(RapidsConf.TAG_LORE_ID_ENABLED)) { + IdGen.tagLoreId(convertedPlan) + } else { + convertedPlan + } } } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala index 48f9de5a61a..14c6cf582ec 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.annotation.tailrec import scala.collection.mutable +import com.nvidia.spark.rapids.lore.IdGen import com.nvidia.spark.rapids.shims.{GpuBatchScanExec, SparkShimImpl} import org.apache.spark.SparkContext @@ -823,6 +824,10 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { updatedPlan = fixupAdaptiveExchangeReuse(updatedPlan) } + if (rapidsConf.get(RapidsConf.TAG_LORE_ID_ENABLED)) { + updatedPlan = IdGen.tagLoreId(updatedPlan) + } + if (rapidsConf.logQueryTransformations) { logWarning(s"Transformed query:" + s"\nOriginal Plan:\n$plan\nTransformed Plan:\n$updatedPlan") diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 8ea1641fb4a..5883421d343 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -708,6 +708,33 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") .checkValues(Set("DEBUG", "MODERATE", "ESSENTIAL")) .createWithDefault("MODERATE") + val TAG_LORE_ID_ENABLED = conf("spark.rapids.LORE.tagging") + .doc("Enable tagging a LORE id to each gpu plan node") + .booleanConf + .createWithDefault(true) + + val LORE_DUMP_OPERATOR = conf("spark.rapids.LORE.operatorToDump") + .doc("The name of SparkPlan to dump, e.g. GpuHashAggregateExec") + .internal() + .stringConf + .createOptional + + val LORE_DUMP_LORE_IDS = conf("spark.rapids.LORE.idsToDump") + .doc("Specify which operator(s) to dump by LORE ID. " + + "Specify multiple partitions by using comma, e.g. \"12,31\"") + .internal() + .stringConf + .createWithDefault("") + + val LORE_DUMP_PARTITIONS = conf("spark.rapids.LORE.partitionsToDump") + .doc("Which partition of the operator(the operator relates to a fixed stage, " + + "each stage is divided into many tasks by partition id) to dump. User can " + + "specify multiple partitions by using comma, e.g. \"0,3,5\"." + + "By default it will dump the first partitions.") + .internal() + .stringConf + .createWithDefault("0") + val PROFILE_PATH = conf("spark.rapids.profile.pathPrefix") .doc("Enables profiling and specifies a URI path to use when writing profile data") .internal() @@ -2462,6 +2489,10 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val metricsLevel: String = get(METRICS_LEVEL) + lazy val loreDumpOperator: Option[String] = get(LORE_DUMP_OPERATOR) + + lazy val loreDumpPartitions: String = get(LORE_DUMP_PARTITIONS) + lazy val profilePath: Option[String] = get(PROFILE_PATH) lazy val profileExecutors: String = get(PROFILE_EXECUTORS) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/IdGen.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/IdGen.scala new file mode 100644 index 00000000000..dc0cbf1fc7a --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/IdGen.scala @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.lore + +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} +import java.util.concurrent.atomic.AtomicInteger + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} + +object IdGen { + val LORE_ID_TAG: TreeNodeTag[String] = new TreeNodeTag[String]("rapids.gpu.lore.id") + + /** + * LORE id generator. Key is [[SQLExecution.EXECUTION_ID_KEY]]. + */ + private val idGen: ConcurrentMap[String, AtomicInteger] = + new ConcurrentHashMap[String, AtomicInteger]() + + private def nextLoreIdOfSparkPlan(plan: SparkPlan): Option[Int] = { + // When the execution id is not set, it means there is no actual execution happening, in this + // case we don't need to generate lore id. + Option(plan.session.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)) + .map { executionId => + idGen.computeIfAbsent(executionId, _ => new AtomicInteger(0)).getAndIncrement() + } + } + + def tagLoreId(sparkPlan: SparkPlan): SparkPlan = { + sparkPlan.foreachUp { + case g: GpuExec => { + nextLoreIdOfSparkPlan(g).foreach { id => + g.setTagValue(LORE_ID_TAG, id.toString) + } + } + case _ => + } + + sparkPlan + } + + def loreIdOf(node: SparkPlan): Option[String] = { + node.getTagValue(LORE_ID_TAG) + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/profiling/ReplayDumpRDD.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/profiling/ReplayDumpRDD.scala new file mode 100644 index 00000000000..15546fe21e4 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/profiling/ReplayDumpRDD.scala @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.profiling + +import java.io.File + +import ai.rapids.cudf.Table +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.DumpUtils.deserializeObject +import com.nvidia.spark.rapids.GpuColumnVector + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch + +class SimplePartition extends Partition { + override def index: Int = 0 +} + +object ReplayDumpRDD { + +} + +class ReplayDumpRDD( + @transient private val sparkSession: SparkSession, + val path: String) + extends RDD[ColumnarBatch](sparkSession.sparkContext, Nil) { + + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val rootFolder = new File(path) + val subFolderIter = rootFolder.listFiles().filter(_.isDirectory).map(_.getPath).iterator + val cbIter: Iterator[ColumnarBatch] = subFolderIter.map(replayDir => { + + val cbTypesPath = replayDir + "/col_types.meta" + if (!(new File(cbTypesPath).exists() && new File(cbTypesPath).isFile)) { + throw new IllegalStateException(s"There is no col_types.meta file in $replayDir") + } + + val parquets = new File(replayDir).listFiles(f => f.getName.equals(s"cb_data.parquet")) + if (parquets.size != 1) { + throw new IllegalStateException( + s"missing cb_data.parquet file in $replayDir") + } + val cbPath = parquets(0).getAbsolutePath + + // restore column types + val restoredCbTypes = deserializeObject[Array[DataType]](cbTypesPath) + + // construct a column batch + withResource(Table.readParquet(new File(cbPath))) { restoredTable => + println("a input batch with size " + restoredTable.getRowCount) + GpuColumnVector.from(restoredTable, restoredCbTypes) + } + }) + cbIter + } + + override protected def getPartitions: Array[Partition] = Seq(new SimplePartition()).toArray +} + diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index 6e25ab6162a..59cb5423c46 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -89,7 +89,8 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { (exec, conf, p, r) => new GpuCustomShuffleReaderMeta(exec, conf, p, r)) override final def sessionFromPlan(plan: SparkPlan): SparkSession = { - plan.session +// plan.session + SparkSession.getActiveSession.get } override def isEmptyRelation(relation: Any): Boolean = relation match { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/profiling/DumpedExecReplayer.scala b/tests/src/test/scala/com/nvidia/spark/rapids/profiling/DumpedExecReplayer.scala new file mode 100644 index 00000000000..4a756608bd9 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/profiling/DumpedExecReplayer.scala @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.profiling + +import java.io.File + +import com.nvidia.spark.rapids.DumpUtils.deserializeObject +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.trees.UnaryLike + +object DumpedExecReplayer extends Logging { + + def main(args: Array[String]): Unit = { + // check arguments and get paths + if (args.length != 1) { + throw new IllegalStateException("Specify 1 args: , " + + "e.g. /tmp/lore/lore_id=41_plan_id=240/partition_id=0") + } + + // start a Spark session with Spark-Rapids initialization + SparkSession.builder() + .master("local[*]") + .config("spark.plugins", "com.nvidia.spark.SQLPlugin") + .appName("Replaying Dumped Exec") + .getOrCreate() + + val rootFolder: File = new File(args(0)) + val subFolders = rootFolder.listFiles().filter(_.isDirectory) + if (subFolders.length < 1) { + throw new IllegalStateException("There is no subfolder in the replay dir") + } + + + val planMetaPath = rootFolder.getParent + s"/plan.meta" + if (!(new File(planMetaPath).exists() && new File(planMetaPath).isFile)) { + throw new IllegalStateException(s"There is no plan.meta file in ${rootFolder.getParent}") + } + + // restore SparkPlan + val restoredExec = deserializeObject[GpuExec](planMetaPath) + + if (!restoredExec.isInstanceOf[UnaryLike[_]]) throw new IllegalStateException( + s"For now, restored exec only supports UnaryLike: ${restoredExec.getClass}") + val unaryLike = restoredExec.asInstanceOf[UnaryLike[_]] + + if (!unaryLike.child.isInstanceOf[GpuExec]) throw new IllegalStateException( + s"For now, restored exec's child only supports GpuExec: " + + s"${unaryLike.child.getClass}") + val child = unaryLike.child.asInstanceOf[GpuExec] + child.loreReplayInputDir = rootFolder.getPath + restoredExec.loreIsReplayingOperator = true + + restoredExec.doExecuteColumnar().foreach( + cb => { + println(s"return ColumnarBatch with size: ${cb.numRows()}") + cb.close() + } + ) + } +} \ No newline at end of file