Skip to content

Commit

Permalink
Check Hybrid jar in executor
Browse files Browse the repository at this point in the history
  • Loading branch information
Chong Gao committed Dec 24, 2024
1 parent 6149589 commit 114b93a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

object HybridExecutionUtils {

private val HYBRID_JAR_PLUGIN_CLASS_NAME = "com.nvidia.spark.rapids.hybrid.HybridPluginWrapper"

/**
* Check if the Hybrid jar is in the classpath,
* report error if not
*/
def checkHybridJarInClassPath(): Unit = {
try {
Class.forName(HYBRID_JAR_PLUGIN_CLASS_NAME)
} catch {
case e: ClassNotFoundException => throw new RuntimeException(
"Hybrid jar is not in the classpath, Please add Hybrid jar into the class path, or " +
"Please disable Hybrid feature by setting " +
"spark.rapids.sql.parquet.useHybridReader=false", e)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,11 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
// Fail if there are multiple plugin jars in the classpath.
RapidsPluginUtils.detectMultipleJars(conf)

// Check Hybrid jar if needed.
if (conf.useHybridParquetReader) {
HybridExecutionUtils.checkHybridJarInClassPath()
}

// Compare if the cudf version mentioned in the classpath is equal to the version which
// plugin expects. If there is a version mismatch, throw error. This check can be disabled
// by setting this config spark.rapids.cudfVersionOverride=true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ class HybridFileSourceScanExecMeta(plan: FileSourceScanExec,
}

object HybridFileSourceScanExecMeta {
private val HYBRID_JAR_PLUGIN_CLASS_NAME = "com.nvidia.spark.rapids.hybrid.HybridPluginWrapper"

// Determines whether using HybridScan or GpuScan
def useHybridScan(conf: RapidsConf, fsse: FileSourceScanExec): Boolean = {
val isEnabled = if (conf.useHybridParquetReader) {
Expand Down Expand Up @@ -148,7 +146,7 @@ object HybridFileSourceScanExecMeta {
*/
def checkRuntimes(v1DataSourceList: String): Unit = {
checkNotRunningCDHorDatabricks()
checkHybridJarInClassPath()
HybridExecutionUtils.checkHybridJarInClassPath()
checkJavaVersion()
checkScalaVersion()
checkV1Datasource(v1DataSourceList)
Expand All @@ -166,21 +164,6 @@ object HybridFileSourceScanExecMeta {
}
}

/**
* Check if the Hybrid jar is in the classpath,
* report error if not
*/
private def checkHybridJarInClassPath(): Unit = {
try {
Class.forName(HYBRID_JAR_PLUGIN_CLASS_NAME)
} catch {
case e: ClassNotFoundException => throw new RuntimeException(
"Hybrid jar is not in the classpath, Please add Hybrid jar into the class path, or " +
"Please disable Hybrid feature by setting " +
"spark.rapids.sql.parquet.useHybridReader=false", e)
}
}

/**
* Hybrid feature only supports 1.8 Java version,
* report error if not
Expand Down

0 comments on commit 114b93a

Please sign in to comment.