Skip to content

Commit

Permalink
fix test regression
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Feb 1, 2022
1 parent 001d549 commit 085a5cb
Showing 1 changed file with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -59,7 +59,8 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite {
conf: SparkConf = new SparkConf(),
execsAllowedNonGpu: Seq[String] = Seq.empty,
batchSize: Int = 0,
repart: Int = 1)
repart: Int = 1,
maxFloatDiff: Double = 0.0)
(fn: DataFrame => DataFrame) {
if (batchSize > 0) {
makeBatchedBytes(batchSize, conf)
Expand All @@ -69,7 +70,7 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite {
testSparkResultsAreEqual(testName, df,
conf = conf, repart = repart,
execsAllowedNonGpu = execsAllowedNonGpu,
incompat = true, sort = true)(fn)
incompat = true, sort = true, maxFloatDiff = maxFloatDiff)(fn)
}

def firstDf(spark: SparkSession): DataFrame = {
Expand Down Expand Up @@ -637,6 +638,7 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite {
FLOAT_TEST_testSparkResultsAreEqual(
"doubles basic aggregates group by doubles",
doubleCsvDf,
maxFloatDiff = 0.0001,
conf = makeBatchedBytes(3, enableCsvConf())) {
frame => frame.groupBy("doubles").agg(
lit(456f),
Expand All @@ -653,6 +655,7 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite {
FLOAT_TEST_testSparkResultsAreEqual(
"doubles basic aggregates group by more_doubles",
doubleCsvDf,
maxFloatDiff = 0.0001,
conf = makeBatchedBytes(3, enableCsvConf())) {
frame => frame.groupBy("more_doubles").agg(
lit(456f),
Expand Down

0 comments on commit 085a5cb

Please sign in to comment.