Xxhash64 supports nested types [databricks] #11859

Merged 11 commits on Dec 19, 2024
Changes from 8 commits
6 changes: 3 additions & 3 deletions docs/supported_ops.md
@@ -18553,9 +18553,9 @@ are limited.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
-<td><b>NS</b></td>
-<td><b>NS</b></td>
-<td><b>NS</b></td>
+<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH</em></td>
+<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH</em></td>
+<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
56 changes: 38 additions & 18 deletions integration_tests/src/main/python/hashing_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,9 +17,7 @@
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect
from data_gen import *
from marks import allow_non_gpu, ignore_order
-from spark_session import is_before_spark_320

-# Spark 3.1.x does not normalize -0.0 and 0.0 but GPU version does
_xxhash_gens = [
null_gen,
boolean_gen,
@@ -31,36 +29,58 @@
timestamp_gen,
decimal_gen_32bit,
decimal_gen_64bit,
-    decimal_gen_128bit]
-if not is_before_spark_320():
-    _xxhash_gens += [float_gen, double_gen]
+    decimal_gen_128bit,
+    float_gen,
+    double_gen
+]

_struct_of_xxhash_gens = StructGen([(f"c{i}", g) for i, g in enumerate(_xxhash_gens)])

-_xxhash_fallback_gens = single_level_array_gens + nested_array_gens_sample + [
-    all_basic_struct_gen,
-    struct_array_gen,
-    _struct_of_xxhash_gens]
-if is_before_spark_320():
-    _xxhash_fallback_gens += [float_gen, double_gen]
+_xxhash_gens = (_xxhash_gens + [_struct_of_xxhash_gens] + single_level_array_gens
+               + nested_array_gens_sample + [
+                   all_basic_struct_gen,
+                   struct_array_gen,
+                   _struct_of_xxhash_gens
+               ] + map_gens_sample)

@ignore_order(local=True)
@pytest.mark.parametrize("gen", _xxhash_gens, ids=idfn)
def test_xxhash64_single_column(gen):
    assert_gpu_and_cpu_are_equal_collect(
-        lambda spark : unary_op_df(spark, gen).selectExpr("a", "xxhash64(a)"))
+        lambda spark : unary_op_df(spark, gen).selectExpr("a", "xxhash64(a)"),
+        {"spark.sql.legacy.allowHashOnMapType" : True})

@ignore_order(local=True)
def test_xxhash64_multi_column():
    gen = StructGen(_struct_of_xxhash_gens.children, nullable=False)
    col_list = ",".join(gen.data_type.fieldNames())
    assert_gpu_and_cpu_are_equal_collect(
-        lambda spark : gen_df(spark, gen).selectExpr("c0", f"xxhash64({col_list})"))
+        lambda spark : gen_df(spark, gen).selectExpr("c0", f"xxhash64({col_list})"),
+        {"spark.sql.legacy.allowHashOnMapType" : True})

def test_xxhash64_8_depth():
    gen_8_depth = StructGen([('l1', # level 1
        StructGen([('l2',
        StructGen([('l3',
        StructGen([('l4',
        StructGen([('l5',
        StructGen([('l6',
        StructGen([('l7',
        int_gen)]))]))]))]))]))]))]) # level 8
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, gen_8_depth).selectExpr("a", "xxhash64(a)"))

@allow_non_gpu("ProjectExec")
@ignore_order(local=True)
-@pytest.mark.parametrize("gen", _xxhash_fallback_gens, ids=idfn)
-def test_xxhash64_fallback(gen):
+def test_xxhash64_fallback_exceeds_stack_size():
    gen_9_depth = StructGen([('l1', # level 1
        StructGen([('l2',
        StructGen([('l3',
        StructGen([('l4',
        StructGen([('l5',
        StructGen([('l6',
        StructGen([('l7',
        StructGen([('l8',
        int_gen)]))]))]))]))]))]))]))]) # level 9
    assert_gpu_fallback_collect(
-        lambda spark : unary_op_df(spark, gen).selectExpr("a", "xxhash64(a)"),
+        lambda spark : unary_op_df(spark, gen_9_depth).selectExpr("a", "xxhash64(a)"),
        "ProjectExec")
@@ -3328,6 +3328,14 @@ object GpuOverrides extends Logging {
override val childExprs: Seq[BaseExprMeta[_]] = a.children
.map(GpuOverrides.wrapExpr(_, this.conf, Some(this)))

override def tagExprForGpu(): Unit = {
val maxDepth = a.children.map(
c => XxHash64Utils.computeMaxStackSize(c.dataType)).max
if (maxDepth > Hash.MAX_STACK_DEPTH) {
willNotWorkOnGpu(s"The data type requires a stack size of $maxDepth, " +
s"which exceeds the GPU limit of ${Hash.MAX_STACK_DEPTH}")
}
}

Collaborator: This is not very clear because Maps count as a depth of 2. Users who see this will be confused, because they will try to add things up manually and it will not work out. At a minimum we need to mention that Maps count as a depth of 2 here.

@res-life (Author), Dec 19, 2024: Updated.
def convertToGpu(): GpuExpression =
GpuXxHash64(childExprs.map(_.convertToGpu()), a.seed)
}),
@@ -108,6 +108,59 @@ case class GpuXxHash64(children: Seq[Expression], seed: Long) extends GpuHashExp
}
}

object XxHash64Utils {
/**
* Converts a map to a list of structs, recursively.
* Note: does not support UserDefinedType or other irregular types,
* and does not retain nullability or other metadata.
*/
private def flatMap(inputType: DataType): DataType = {
inputType match {
case mapType: MapType =>
ArrayType(StructType(Array(
StructField("key", flatMap(mapType.keyType)),
StructField("value", flatMap(mapType.valueType))
)))
case arrayType: ArrayType => ArrayType(flatMap(arrayType.elementType))
case structType: StructType =>
StructType(structType.map(f => StructField(f.name, flatMap(f.dataType))).toArray)
case nullType: NullType => nullType
case atomicType: AtomicType => atomicType
case other => throw new RuntimeException(s"Unsupported type: $other")
}
}

/**
* Compute the max stack size that the flattenType will use.
* @param flattenType should be a flattened type generated by the function `flatMap`
* @return max stack size
*/
private def computeMaxStackSizeForFlatten(flattenType: DataType): Int = {
flattenType match {
case ArrayType(c: StructType, _) => 1 + computeMaxStackSizeForFlatten(c)
case ArrayType(c: DataType, _) => computeMaxStackSizeForFlatten(c)
case st: StructType =>
1 + st.map(f => computeMaxStackSizeForFlatten(f.dataType)).max
case _ => 1 // primitive types
}
}

Collaborator: Perhaps we don't need the flatMap function and can directly compute the depth of the map type using

case mt: MapType =>
    2 + math.max(computeMaxStackSizeForFlatten(mt.keyType), computeMaxStackSizeForFlatten(mt.valueType))

Collaborator: +1 for not rewriting the data type to check the depth.

@res-life (Author): Thanks, let's post a follow-up PR to do the improvement.

Collaborator: @res-life why? It takes 5 mins, and I want the fallback message updated too.

@res-life (Author), Dec 19, 2024: Updated.

/**
* Compute the max stack size that `inputType` will use; see the function
* `check_nested_depth` in src/main/cpp/src/xxhash64.cu in the
* spark-rapids-jni repo.
* Note:
* - This must be kept in sync with `check_nested_depth`
* - A map in cuDF is represented as a list of structs
*
* @param inputType the input type
* @return the max stack size that xxhash64 will use for this input type.
*/
def computeMaxStackSize(inputType: DataType): Int = {
computeMaxStackSizeForFlatten(flatMap(inputType))
}
}

case class GpuHiveHash(children: Seq[Expression]) extends GpuHashExpression {
override def dataType: DataType = IntegerType

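
For intuition, the same depth can be computed directly on the Spark type without flattening, as the review thread above suggests (a map costs two levels because cuDF represents it as a list of structs). A standalone sketch of the rule in Python — the `depth` helper is my own illustration, not the PR's code:

from pyspark.sql.types import (ArrayType, IntegerType, MapType, StringType,
                               StructField, StructType)

# Structs and arrays-of-structs add one stack level, maps add two,
# and primitives count as one.
def depth(dt):
    if isinstance(dt, MapType):
        return 2 + max(depth(dt.keyType), depth(dt.valueType))
    if isinstance(dt, ArrayType):
        if isinstance(dt.elementType, StructType):
            return 1 + depth(dt.elementType)
        return depth(dt.elementType)
    if isinstance(dt, StructType):
        return 1 + max(depth(f.dataType) for f in dt.fields)
    return 1

# map<string, array<struct<a:int>>> -> 2 + max(1, 1 + (1 + 1)) = 5
t = MapType(StringType(), ArrayType(StructType([StructField("a", IntegerType())])))
print(depth(t))  # 5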
@@ -48,5 +48,6 @@ package com.nvidia.spark.rapids.shims
import com.nvidia.spark.rapids.TypeSig

object XxHash64Shims {
-val supportedTypes: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128
+val supportedTypes: TypeSig = (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+  TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested()
}
2 changes: 1 addition & 1 deletion tools/generated_files/320/supportedExprs.csv
@@ -637,7 +637,7 @@ WindowExpression,S, ,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S,
WindowSpecDefinition,S, ,None,project,partition,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,PS,NS,NS,NS
WindowSpecDefinition,S, ,None,project,value,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,PS,NS,NS,NS
WindowSpecDefinition,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,PS,NS,NS,NS
-XxHash64,S,`xxhash64`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS,NS,NS
+XxHash64,S,`xxhash64`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS
XxHash64,S,`xxhash64`,None,project,result,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Year,S,`year`,None,project,input,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Year,S,`year`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
2 changes: 1 addition & 1 deletion tools/generated_files/supportedExprs.csv
@@ -637,7 +637,7 @@ WindowExpression,S, ,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S,
WindowSpecDefinition,S, ,None,project,partition,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,PS,NS,NS,NS
WindowSpecDefinition,S, ,None,project,value,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,PS,NS,NS,NS
WindowSpecDefinition,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,PS,NS,NS,NS
-XxHash64,S,`xxhash64`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS,NS,NS
+XxHash64,S,`xxhash64`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS
XxHash64,S,`xxhash64`,None,project,result,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Year,S,`year`,None,project,input,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Year,S,`year`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
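
To see exactly which type columns the CSV rows above changed, one can diff the old and new value lists; the column order here is an assumption based on the type order in docs/supported_ops.md:

# Hedged sketch: which support codes flipped between the old and new rows.
types = ["BOOLEAN", "BYTE", "SHORT", "INT", "LONG", "FLOAT", "DOUBLE", "DATE",
         "TIMESTAMP", "STRING", "DECIMAL", "NULL", "BINARY", "CALENDAR",
         "ARRAY", "MAP", "STRUCT", "UDT", "DAYTIME", "YEARMONTH"]
old = "S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS,NS,NS".split(",")
new = "S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS".split(",")
print([t for t, o, n in zip(types, old, new) if o != n])  # ['ARRAY', 'MAP', 'STRUCT']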