Orc writes don't fully support Booleans with nulls (#11763)
kuhushukla authored Dec 7, 2024
1 parent 738c8e3 commit fb2f72d
Showing 9 changed files with 156 additions and 42 deletions.
8 changes: 5 additions & 3 deletions integration_tests/src/main/python/datasourcev2_write_test.py
@@ -18,7 +18,7 @@
from data_gen import gen_df, decimal_gens, non_utc_allow
from marks import *
from spark_session import is_hive_available, is_spark_330_or_later, with_cpu_session, with_gpu_session
from hive_parquet_write_test import _hive_bucket_gens, _hive_array_gens, _hive_struct_gens
from hive_parquet_write_test import _hive_bucket_gens_sans_bools, _hive_array_gens, _hive_struct_gens
from hive_parquet_write_test import read_single_bucket

_hive_write_conf = {
@@ -33,9 +33,11 @@
@allow_non_gpu(*non_utc_allow)
def test_write_hive_bucketed_table(spark_tmp_table_factory, file_format):
num_rows = 2048

# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen
def gen_table(spark):
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(_hive_bucket_gens)]
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(_hive_bucket_gens_sans_bools)]
types_sql_str = ','.join('{} {}'.format(
name, gen.data_type.simpleString()) for name, gen in gen_list)
col_names_str = ','.join(name for name, gen in gen_list)
5 changes: 3 additions & 2 deletions integration_tests/src/main/python/hive_parquet_write_test.py
@@ -25,9 +25,10 @@
# "GpuInsertIntoHiveTable" for Parquet write.
_write_to_hive_conf = {"spark.sql.hive.convertMetastoreParquet": False}

_hive_bucket_gens = [
boolean_gen, byte_gen, short_gen, int_gen, long_gen, string_gen, float_gen, double_gen,
_hive_bucket_gens_sans_bools = [
byte_gen, short_gen, int_gen, long_gen, string_gen, float_gen, double_gen,
DateGen(start=date(1590, 1, 1)), _restricted_timestamp()]
_hive_bucket_gens = [boolean_gen] + _hive_bucket_gens_sans_bools

_hive_basic_gens = _hive_bucket_gens + [
DecimalGen(precision=19, scale=1, nullable=True),
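The refactor keeps a boolean-free list for write-path suites while read-oriented tests retain full coverage; a minimal sketch of how the two lists relate (names as defined in the hunk above):

```python
# Write-path tests skip booleans until NVIDIA/spark-rapids#11762 is fixed.
write_gens = _hive_bucket_gens_sans_bools
# Read-oriented tests keep full coverage; this is exactly _hive_bucket_gens.
full_gens = [boolean_gen] + _hive_bucket_gens_sans_bools
```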
10 changes: 8 additions & 2 deletions integration_tests/src/main/python/hive_write_test.py
@@ -29,8 +29,11 @@ def _restricted_timestamp(nullable=True):
end=datetime(2262, 4, 11, tzinfo=timezone.utc),
nullable=nullable)

# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen
_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
string_gen, DateGen(start=date(1590, 1, 1)),
_restricted_timestamp()
] + decimal_gens

@@ -45,8 +48,11 @@ def _restricted_timestamp(nullable=True):
ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10),
ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]]))]

# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen
_map_gens = [simple_string_to_string_map_gen] + [MapGen(f(nullable=False), f()) for f in [
BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen,
ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen,
lambda nullable=True: _restricted_timestamp(nullable=nullable),
lambda nullable=True: DateGen(start=date(1590, 1, 1), nullable=nullable),
lambda nullable=True: DecimalGen(precision=15, scale=1, nullable=nullable),
42 changes: 33 additions & 9 deletions integration_tests/src/main/python/orc_test.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -112,8 +112,11 @@ def test_basic_read(std_input_path, name, read_func, v1_enabled_list, orc_impl,
#E at org.apache.orc.TypeDescription.parseInt(TypeDescription.java:244)
#E at org.apache.orc.TypeDescription.parseType(TypeDescription.java:362)
# ...
# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen
orc_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
string_gen, DateGen(start=date(1590, 1, 1)),
orc_timestamp_gen] + decimal_gens

orc_basic_struct_gen = StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(orc_basic_gens)])
@@ -201,8 +204,11 @@ def test_read_round_trip(spark_tmp_path, orc_gens, read_func, reader_confs, v1_e
read_func(data_path),
conf=all_confs)

# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen
orc_pred_push_gens = [
byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen,
byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen,
# Once https://github.com/NVIDIA/spark-rapids/issues/139 is fixed replace this with
# date_gen
@@ -277,8 +283,11 @@ def test_compress_read_round_trip(spark_tmp_path, compress, v1_enabled_list, rea
def test_simple_partitioned_read(spark_tmp_path, v1_enabled_list, reader_confs):
# Once https://github.com/NVIDIA/spark-rapids/issues/131 is fixed
# we should go with a more standard set of generators
# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen
orc_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
string_gen, DateGen(start=date(1590, 1, 1)),
orc_timestamp_gen]
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
first_data_path = spark_tmp_path + '/ORC_DATA/key=0/key2=20'
@@ -344,8 +353,11 @@ def test_partitioned_read_just_partitions(spark_tmp_path, v1_enabled_list, reade
def test_merge_schema_read(spark_tmp_path, v1_enabled_list, reader_confs):
# Once https://github.com/NVIDIA/spark-rapids/issues/131 is fixed
# we should go with a more standard set of generators
# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen
orc_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
string_gen, DateGen(start=date(1590, 1, 1)),
orc_timestamp_gen]
first_gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
first_data_path = spark_tmp_path + '/ORC_DATA/key=0'
@@ -825,8 +837,11 @@ def test_read_round_trip_for_multithreaded_combining(spark_tmp_path, gens, keep_
@pytest.mark.parametrize('keep_order', [True, pytest.param(False, marks=pytest.mark.ignore_order(local=True))])
@allow_non_gpu(*non_utc_allow_orc_scan)
def test_simple_partitioned_read_for_multithreaded_combining(spark_tmp_path, keep_order):
# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen
orc_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
string_gen, DateGen(start=date(1590, 1, 1)),
orc_timestamp_gen]
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
first_data_path = spark_tmp_path + '/ORC_DATA/key=0/key2=20'
@@ -927,7 +942,10 @@ def test_orc_column_name_with_dots(spark_tmp_path, reader_confs):
("f.g", int_gen),
("h", string_gen)])),
("i.j", long_gen)])),
("k", boolean_gen)]
# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen for column k
("k", int_gen)]
with_cpu_session(lambda spark: gen_df(spark, gens).write.orc(data_path))
assert_gpu_and_cpu_are_equal_collect(lambda spark: reader(spark), conf=all_confs)
assert_gpu_and_cpu_are_equal_collect(lambda spark: reader(spark).selectExpr("`a.b`"), conf=all_confs)
@@ -945,7 +963,10 @@ def test_orc_with_null_column(spark_tmp_path, reader_confs):
def gen_null_df(spark):
return spark.createDataFrame(
[(None, None, None, None, None)],
"c1 int, c2 long, c3 float, c4 double, c5 boolean")
# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen
"c1 int, c2 long, c3 float, c4 double, c5 int")

assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: gen_null_df(spark).write.orc(path),
@@ -966,7 +987,10 @@ def test_orc_with_null_column_with_1m_rows(spark_tmp_path, reader_confs):
def gen_null_df(spark):
return spark.createDataFrame(
data,
"c1 int, c2 long, c3 float, c4 double, c5 boolean")
# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen
"c1 int, c2 long, c3 float, c4 double, c5 int")
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: gen_null_df(spark).write.orc(path),
lambda spark, path: spark.read.orc(path),
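For context, the problem these test changes work around can be sketched as follows; this is a hypothetical repro (path and `spark` session assumed, plugin enabled with boolean ORC writes forced on), not output from the issue itself:

```python
# Sketch: nullable booleans written to ORC by the GPU could be read back
# incorrectly; rapidsai/cudf#6763 tracks the underlying writer issue.
from pyspark.sql import Row

df = spark.createDataFrame([Row(b=True), Row(b=None), Row(b=False)], "b boolean")
df.write.mode("overwrite").orc("/tmp/bool_repro")  # hypothetical path
spark.read.orc("/tmp/bool_repro").show()  # may not round-trip the nulls correctly
```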
51 changes: 43 additions & 8 deletions integration_tests/src/main/python/orc_write_test.py
@@ -24,9 +24,11 @@
from pyspark.sql.types import *

pytestmark = pytest.mark.nightly_resource_consuming_test

# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen.
orc_write_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
string_gen, DateGen(start=date(1590, 1, 1)),
TimestampGen(start=datetime(1970, 1, 1, tzinfo=timezone.utc)) ] + \
decimal_gens

@@ -52,7 +54,8 @@
all_nulls_map_gen,
all_empty_map_gen]

orc_write_basic_struct_gen = StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(orc_write_basic_gens)])
orc_write_basic_struct_gen = StructGen(
[['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(orc_write_basic_gens)])

orc_write_struct_gens_sample = [orc_write_basic_struct_gen,
StructGen([['child0', byte_gen], ['child1', orc_write_basic_struct_gen]]),
@@ -62,15 +65,18 @@
ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10),
ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10),
ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]]))]

# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen.
orc_write_basic_map_gens = [simple_string_to_string_map_gen] + [MapGen(f(nullable=False), f()) for f in [
BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen,
ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen,
# Using timestamps from 1970 to work around a cudf ORC bug
# https://github.com/NVIDIA/spark-rapids/issues/140.
lambda nullable=True: TimestampGen(start=datetime(1970, 1, 1, tzinfo=timezone.utc), nullable=nullable),
lambda nullable=True: DateGen(start=date(1590, 1, 1), nullable=nullable),
lambda nullable=True: DecimalGen(precision=15, scale=1, nullable=nullable),
lambda nullable=True: DecimalGen(precision=36, scale=5, nullable=nullable)]]
lambda nullable=True: DecimalGen(precision=36, scale=5, nullable=nullable)]] + [MapGen(
f(nullable=False), f(nullable=False)) for f in [IntegerGen]]

orc_write_gens_list = [orc_write_basic_gens,
orc_write_struct_gens_sample,
@@ -79,6 +85,7 @@
pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/139')),
pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/140'))]

bool_gen = [BooleanGen(nullable=True), BooleanGen(nullable=False)]
@pytest.mark.parametrize('orc_gens', orc_write_gens_list, ids=idfn)
@pytest.mark.parametrize('orc_impl', ["native", "hive"])
@allow_non_gpu(*non_utc_allow)
@@ -91,6 +98,30 @@ def test_write_round_trip(spark_tmp_path, orc_gens, orc_impl):
data_path,
conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True})

@pytest.mark.parametrize('orc_gens', [bool_gen], ids=idfn)
@pytest.mark.parametrize('orc_impl', ["native", "hive"])
@allow_non_gpu('ExecutedCommandExec', 'DataWritingCommandExec', 'WriteFilesExec')
def test_write_round_trip_bools_only_fallback(spark_tmp_path, orc_gens, orc_impl):
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
data_path = spark_tmp_path + '/ORC_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.orc(path),
lambda spark, path: spark.read.orc(path),
data_path,
conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True})

@pytest.mark.parametrize('orc_gens', [bool_gen], ids=idfn)
@pytest.mark.parametrize('orc_impl', ["native", "hive"])
def test_write_round_trip_bools_only_no_fallback(spark_tmp_path, orc_gens, orc_impl):
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
data_path = spark_tmp_path + '/ORC_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.orc(path),
lambda spark, path: spark.read.orc(path),
data_path,
conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True,
'spark.rapids.sql.format.orc.write.boolType.enabled': True})

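The two new tests above pin down both sides of the config: with `spark.rapids.sql.format.orc.write.boolType.enabled` at its default of false, the write is tagged off the GPU (hence the `allow_non_gpu` exec names in the fallback test), while setting it to true keeps the write on the GPU. A standalone sketch of the fallback case (hypothetical path, `spark` session assumed):

```python
# With boolType.enabled left false (the default), this ORC write is expected
# to fall back to the CPU rather than risk bad boolean data.
df = spark.createDataFrame([(True,), (None,), (False,)], "b boolean")
df.write.mode("overwrite").orc("/tmp/orc_bools")  # hypothetical path
```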
@pytest.mark.parametrize('orc_gen', orc_write_odd_empty_strings_gens_sample, ids=idfn)
@pytest.mark.parametrize('orc_impl', ["native", "hive"])
def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl):
@@ -103,7 +134,8 @@ def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl):
conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True})

orc_part_write_gens = [
byte_gen, short_gen, int_gen, long_gen, boolean_gen,
# Add back boolean_gen when https://github.com/rapidsai/cudf/issues/6763 is fixed
byte_gen, short_gen, int_gen, long_gen,
# Some file systems have issues with UTF8 strings so to help the test pass even there
StringGen('(\\w| ){0,50}'),
# Once https://github.com/NVIDIA/spark-rapids/issues/139 is fixed replace this with
@@ -345,7 +377,10 @@ def test_orc_write_column_name_with_dots(spark_tmp_path):
("f.g", int_gen),
("h", string_gen)])),
("i.j", long_gen)])),
("k", boolean_gen)]
# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen for column k
("k", int_gen)]
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: gen_df(spark, gens).coalesce(1).write.orc(path),
lambda spark, path: spark.read.orc(path),
11 changes: 8 additions & 3 deletions integration_tests/src/main/python/schema_evolution_test.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,7 +34,9 @@

# List of additional column data generators to use when adding columns
_additional_gens = [
boolean_gen,
# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, add back boolean_gen
byte_gen,
short_gen,
int_gen,
@@ -49,7 +51,10 @@
# simple_string_to_string_map_gen),
ArrayGen(_custom_date_gen),
struct_gen_decimal128,
StructGen([("c0", ArrayGen(long_gen)), ("c1", boolean_gen)]),
# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
# https://github.com/rapidsai/cudf/issues/6763.
# Once the first issue is fixed, change c1 back from int_gen to boolean_gen
StructGen([("c0", ArrayGen(long_gen)), ("c1", int_gen)]),
]

def get_additional_columns():
10 changes: 10 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1268,6 +1268,14 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
.booleanConf
.createWithDefault(true)

val ENABLE_ORC_BOOL = conf("spark.rapids.sql.format.orc.write.boolType.enabled")
.doc("When set to false disables boolean columns for ORC writes. " +
"Set to true if you want to experiment. " +
"See https://github.com/NVIDIA/spark-rapids/issues/11736.")
.internal()
.booleanConf
.createWithDefault(false)

val ENABLE_EXPAND_PREPROJECT = conf("spark.rapids.sql.expandPreproject.enabled")
.doc("When set to false disables the pre-projection for GPU Expand. " +
"Pre-projection leverages the tiered projection to evaluate expressions that " +
Expand Down Expand Up @@ -3028,6 +3036,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val maxNumOrcFilesParallel: Int = get(ORC_MULTITHREAD_READ_MAX_NUM_FILES_PARALLEL)

lazy val isOrcBoolTypeEnabled: Boolean = get(ENABLE_ORC_BOOL)

lazy val isCsvEnabled: Boolean = get(ENABLE_CSV)

lazy val isCsvReadEnabled: Boolean = get(ENABLE_CSV_READ)
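For anyone experimenting, the new knob composes with the existing ORC write enable. A sketch of turning both on (config keys as defined above; assumes they are settable as runtime session confs):

```python
# Opt back into GPU boolean ORC writes (internal config, default false).
spark.conf.set("spark.rapids.sql.format.orc.write.enabled", "true")
spark.conf.set("spark.rapids.sql.format.orc.write.boolType.enabled", "true")
```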
GpuOrcFileFormat.scala
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.orc.{OrcFileFormat, OrcOptions, OrcUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.types._

object GpuOrcFileFormat extends Logging {
@@ -83,6 +84,11 @@ object GpuOrcFileFormat extends Logging {
// [[org.apache.spark.sql.execution.datasources.DaysWritable]] object
// which is a subclass of [[org.apache.hadoop.hive.serde2.io.DateWritable]].
val types = schema.map(_.dataType).toSet
val hasBools = schema.exists { field =>
TrampolineUtil.dataTypeExistsRecursively(field.dataType, t =>
t.isInstanceOf[BooleanType])
}

if (types.exists(GpuOverrides.isOrContainsDateOrTimestamp(_))) {
if (!GpuOverrides.isUTCTimezone()) {
meta.willNotWorkOnGpu("Only UTC timezone is supported for ORC. " +
@@ -91,6 +97,10 @@
}
}

if (hasBools && !meta.conf.isOrcBoolTypeEnabled) {
meta.willNotWorkOnGpu("Nullable Booleans cannot work in certain cases with the ORC writer. " +
"See https://github.com/rapidsai/cudf/issues/6763")
}
FileFormatChecks.tag(meta, schema, OrcFormatType, WriteFileOp)

val sqlConf = spark.sessionState.conf
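The tagging walks each field's type recursively, so booleans nested in structs, arrays, and maps are caught as well. A Python sketch of an equivalent check over a PySpark schema (the plugin itself uses `TrampolineUtil.dataTypeExistsRecursively`; this is an illustration, not the plugin API):

```python
from pyspark.sql.types import (ArrayType, BooleanType, DataType, MapType,
                               StructType)

def contains_bool(dt: DataType) -> bool:
    # True if dt is, or contains anywhere within it, a BooleanType.
    if isinstance(dt, BooleanType):
        return True
    if isinstance(dt, StructType):
        return any(contains_bool(f.dataType) for f in dt.fields)
    if isinstance(dt, ArrayType):
        return contains_bool(dt.elementType)
    if isinstance(dt, MapType):
        return contains_bool(dt.keyType) or contains_bool(dt.valueType)
    return False

# Example: a boolean hidden inside an array within a struct is still detected.
assert contains_bool(StructType().add("a", ArrayType(BooleanType())))
```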