From 32834e62835efeee4174e00a347236ccfe9c3154 Mon Sep 17 00:00:00 2001 From: ustcfy <96854327+ustcfy@users.noreply.github.com> Date: Thu, 12 Dec 2024 15:56:09 +0800 Subject: [PATCH] Expose max nesting depth in hash function to plugin (#2680) * Expose max nesting depth in hash function to plugin Signed-off-by: Yan Feng * Format C++ code Signed-off-by: Yan Feng * Remove MAX_NESTED_DEPTH definition from xxhash64 Signed-off-by: Yan Feng * Update src/main/cpp/src/HashJni.cpp Co-authored-by: Chong Gao * Update src/main/cpp/src/HashJni.cpp Co-authored-by: Nghia Truong <7416935+ttnghia@users.noreply.github.com> * Rename MAX_NESTED_DEPTH to MAX_STACK_DEPTH Signed-off-by: Yan Feng --------- Signed-off-by: Yan Feng Co-authored-by: Chong Gao Co-authored-by: Nghia Truong <7416935+ttnghia@users.noreply.github.com> --- src/main/cpp/src/HashJni.cpp | 5 +++++ src/main/cpp/src/hash.hpp | 1 + src/main/cpp/src/hive_hash.cu | 3 +-- src/main/cpp/src/xxhash64.cu | 3 +-- src/main/java/com/nvidia/spark/rapids/jni/Hash.java | 4 ++++ 5 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/main/cpp/src/HashJni.cpp b/src/main/cpp/src/HashJni.cpp index c0adf38686..520b6f24c0 100644 --- a/src/main/cpp/src/HashJni.cpp +++ b/src/main/cpp/src/HashJni.cpp @@ -21,6 +21,11 @@ extern "C" { +JNIEXPORT jint JNICALL Java_com_nvidia_spark_rapids_jni_Hash_getMaxStackDepth(JNIEnv* env, jclass) +{ + return spark_rapids_jni::MAX_STACK_DEPTH; +} + JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_Hash_murmurHash32( JNIEnv* env, jclass, jint seed, jlongArray column_handles) { diff --git a/src/main/cpp/src/hash.hpp b/src/main/cpp/src/hash.hpp index 4021b9e75c..9ec7496031 100644 --- a/src/main/cpp/src/hash.hpp +++ b/src/main/cpp/src/hash.hpp @@ -25,6 +25,7 @@ namespace spark_rapids_jni { constexpr int64_t DEFAULT_XXHASH64_SEED = 42; +constexpr int MAX_STACK_DEPTH = 8; /** * @brief Computes the murmur32 hash value of each row in the input set of columns. diff --git a/src/main/cpp/src/hive_hash.cu b/src/main/cpp/src/hive_hash.cu index dcabd870af..89d08425fd 100644 --- a/src/main/cpp/src/hive_hash.cu +++ b/src/main/cpp/src/hive_hash.cu @@ -15,6 +15,7 @@ */ #include "hash.cuh" +#include "hash.hpp" #include #include @@ -37,8 +38,6 @@ using hive_hash_value_t = int32_t; constexpr hive_hash_value_t HIVE_HASH_FACTOR = 31; constexpr hive_hash_value_t HIVE_INIT_HASH = 0; -constexpr int MAX_STACK_DEPTH = 8; - hive_hash_value_t __device__ inline compute_int(int32_t key) { return key; } hive_hash_value_t __device__ inline compute_long(int64_t key) diff --git a/src/main/cpp/src/xxhash64.cu b/src/main/cpp/src/xxhash64.cu index 1aa1847be7..f7a0d7cb35 100644 --- a/src/main/cpp/src/xxhash64.cu +++ b/src/main/cpp/src/xxhash64.cu @@ -15,6 +15,7 @@ */ #include "hash.cuh" +#include "hash.hpp" #include #include @@ -34,8 +35,6 @@ namespace { using hash_value_type = int64_t; using half_size_type = int32_t; -constexpr int MAX_STACK_DEPTH = 8; - constexpr __device__ inline int64_t rotate_bits_left_signed(hash_value_type h, int8_t r) { return (h << r) | (h >> (64 - r)) & ~(-1 << r); diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Hash.java b/src/main/java/com/nvidia/spark/rapids/jni/Hash.java index 2b82882868..96b66555a7 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Hash.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Hash.java @@ -25,6 +25,8 @@ public class Hash { // there doesn't appear to be a useful constant in spark to reference. this could break. static final long DEFAULT_XXHASH64_SEED = 42; + public static final int MAX_STACK_DEPTH = getMaxStackDepth(); + static { NativeDepsLoader.loadNativeDeps(); } @@ -100,6 +102,8 @@ public static ColumnVector hiveHash(ColumnView columns[]) { return new ColumnVector(hiveHash(columnViews)); } + private static native int getMaxStackDepth(); + private static native long murmurHash32(int seed, long[] viewHandles) throws CudfException; private static native long xxhash64(long seed, long[] viewHandles) throws CudfException;