diff --git a/src/main/cpp/src/HashJni.cpp b/src/main/cpp/src/HashJni.cpp
index c0adf3868..520b6f24c 100644
--- a/src/main/cpp/src/HashJni.cpp
+++ b/src/main/cpp/src/HashJni.cpp
@@ -21,6 +21,11 @@
 
 extern "C" {
 
+JNIEXPORT jint JNICALL Java_com_nvidia_spark_rapids_jni_Hash_getMaxStackDepth(JNIEnv* env, jclass)
+{
+  return spark_rapids_jni::MAX_STACK_DEPTH;
+}
+
 JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_Hash_murmurHash32(
   JNIEnv* env, jclass, jint seed, jlongArray column_handles)
 {
diff --git a/src/main/cpp/src/hash.hpp b/src/main/cpp/src/hash.hpp
index 4021b9e75..9ec749603 100644
--- a/src/main/cpp/src/hash.hpp
+++ b/src/main/cpp/src/hash.hpp
@@ -25,6 +25,7 @@
 namespace spark_rapids_jni {
 
 constexpr int64_t DEFAULT_XXHASH64_SEED = 42;
+constexpr int MAX_STACK_DEPTH = 8;
 
 /**
  * @brief Computes the murmur32 hash value of each row in the input set of columns.
diff --git a/src/main/cpp/src/hive_hash.cu b/src/main/cpp/src/hive_hash.cu
index dcabd870a..89d08425f 100644
--- a/src/main/cpp/src/hive_hash.cu
+++ b/src/main/cpp/src/hive_hash.cu
@@ -15,6 +15,7 @@
  */
 
 #include "hash.cuh"
+#include "hash.hpp"
 
 #include
 #include
@@ -37,8 +38,6 @@ using hive_hash_value_t = int32_t;
 constexpr hive_hash_value_t HIVE_HASH_FACTOR = 31;
 constexpr hive_hash_value_t HIVE_INIT_HASH   = 0;
 
-constexpr int MAX_STACK_DEPTH = 8;
-
 hive_hash_value_t __device__ inline compute_int(int32_t key) { return key; }
 
 hive_hash_value_t __device__ inline compute_long(int64_t key)
diff --git a/src/main/cpp/src/xxhash64.cu b/src/main/cpp/src/xxhash64.cu
index 1aa1847be..f7a0d7cb3 100644
--- a/src/main/cpp/src/xxhash64.cu
+++ b/src/main/cpp/src/xxhash64.cu
@@ -15,6 +15,7 @@
  */
 
 #include "hash.cuh"
+#include "hash.hpp"
 
 #include
 #include
@@ -34,8 +35,6 @@ namespace {
 using hash_value_type = int64_t;
 using half_size_type  = int32_t;
 
-constexpr int MAX_STACK_DEPTH = 8;
-
 constexpr __device__ inline int64_t rotate_bits_left_signed(hash_value_type h, int8_t r)
 {
   return (h << r) | (h >> (64 - r)) & ~(-1 << r);
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Hash.java b/src/main/java/com/nvidia/spark/rapids/jni/Hash.java
index 2b8288286..96b66555a 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/Hash.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/Hash.java
@@ -25,6 +25,8 @@ public class Hash {
   // there doesn't appear to be a useful constant in spark to reference. this could break.
   static final long DEFAULT_XXHASH64_SEED = 42;
 
+  public static final int MAX_STACK_DEPTH = getMaxStackDepth();
+
   static {
     NativeDepsLoader.loadNativeDeps();
   }
@@ -100,6 +102,8 @@ public static ColumnVector hiveHash(ColumnView columns[]) {
     return new ColumnVector(hiveHash(columnViews));
   }
 
+  private static native int getMaxStackDepth();
+
   private static native long murmurHash32(int seed, long[] viewHandles) throws CudfException;
 
   private static native long xxhash64(long seed, long[] viewHandles) throws CudfException;
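
Note (not part of the patch): the patch moves MAX_STACK_DEPTH into the shared hash.hpp and surfaces it to Java as Hash.MAX_STACK_DEPTH via a new JNI call, so callers can see the limit used by the device-side column traversal. The sketch below is a hypothetical caller-side depth check, assuming cudf's Java ColumnView API (getNumChildren, getChildColumnView) and assuming the constant bounds the column nesting depth the native hash kernels can traverse; the class and method names are illustrative only.

import ai.rapids.cudf.ColumnView;
import com.nvidia.spark.rapids.jni.Hash;

public class HashDepthCheck {
  // Nesting depth of a column view: 1 for a flat column, 1 + deepest child otherwise.
  static int depthOf(ColumnView cv) {
    int deepestChild = 0;
    for (int i = 0; i < cv.getNumChildren(); i++) {
      try (ColumnView child = cv.getChildColumnView(i)) {
        deepestChild = Math.max(deepestChild, depthOf(child));
      }
    }
    return 1 + deepestChild;
  }

  // Reject inputs that a fixed-size traversal stack of Hash.MAX_STACK_DEPTH could not handle.
  static void requireSupportedDepth(ColumnView[] columns) {
    for (ColumnView cv : columns) {
      int depth = depthOf(cv);
      if (depth > Hash.MAX_STACK_DEPTH) {
        throw new IllegalArgumentException(
            "nested column depth " + depth +
            " exceeds Hash.MAX_STACK_DEPTH (" + Hash.MAX_STACK_DEPTH + ")");
      }
    }
  }
}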