Make hive column matches not case-sensitive (#11327)
Signed-off-by: Robert (Bobby) Evans <[email protected]>
revans2 authored Aug 15, 2024
1 parent a876df0 commit 25be396
Showing 2 changed files with 55 additions and 1 deletion.
50 changes: 50 additions & 0 deletions integration_tests/src/main/python/hive_delimited_text_test.py
@@ -120,6 +120,24 @@ def read_impl(spark):

return read_impl

# Builds a reader that registers the Hive text table with an all-UPPERCASE
# schema and then SELECTs the columns by their lowercase names, so the read
# succeeds only if column-name matching is case-insensitive.
def read_hive_text_sql_wrong_case(data_path, schema, spark_tmp_table_factory, options=None):
    if options is None:
        options = {}

    def mk_upper(f):
        return StructField(f.name.upper(), f.dataType)

    upper_s = StructType(list(map(mk_upper, schema.fields)))
    print("CONVERTED " + str(schema) + " TO " + str(upper_s))
    lower_fields = ','.join(map(lambda name: name.lower(), schema.fieldNames()))
    opts = copy_and_update(options, {'schema': upper_s})

    def read_impl(spark):
        tmp_name = spark_tmp_table_factory.get()
        spark.catalog.createTable(tmp_name, source='hive', path=data_path, **opts)
        return spark.sql("SELECT " + lower_fields + " FROM " + tmp_name)

    return read_impl


non_utc_allow_for_test_basic_hive_text_read=['HiveTableScanExec', 'DataWritingCommandExec', 'WriteFilesExec'] if is_not_utc() else []
@pytest.mark.skipif(is_spark_cdh(),
@@ -195,6 +213,20 @@ def test_basic_hive_text_read(std_input_path, name, schema, spark_tmp_table_factory
conf=hive_text_enabled_conf)


@pytest.mark.skipif(is_spark_cdh(),
reason="Hive text reads are disabled on CDH, as per "
"https://github.com/NVIDIA/spark-rapids/pull/7628")
@approximate_float
@pytest.mark.parametrize('name,schema,options', [
('hive-delim-text/simple-boolean-values', make_schema(BooleanType()), {})
], ids=idfn)
@allow_non_gpu(*non_utc_allow_for_test_basic_hive_text_read)
def test_case_insensitive_hive_text_read(std_input_path, name, schema, spark_tmp_table_factory, options):
    assert_gpu_and_cpu_are_equal_collect(read_hive_text_sql_wrong_case(std_input_path + '/' + name,
                                                                       schema, spark_tmp_table_factory, options),
                                         conf=hive_text_enabled_conf)


hive_text_supported_gens = [
StringGen('(\\w| |\t|\ud720){0,10}', nullable=False),
StringGen('[aAbB ]{0,10}'),
@@ -297,6 +329,24 @@ def test_hive_text_round_trip_partitioned(spark_tmp_path, data_gen, spark_tmp_table_factory):
conf=hive_text_enabled_conf)


@pytest.mark.skipif(is_spark_cdh(),
reason="Hive text reads are disabled on CDH, as per "
"https://github.com/NVIDIA/spark-rapids/pull/7628")
@approximate_float
@allow_non_gpu("EqualTo,IsNotNull,Literal", *non_utc_allow_for_test_basic_hive_text_read) # Accounts for partition predicate: `WHERE dt='1'`
@pytest.mark.parametrize('data_gen', [boolean_gen], ids=idfn)
def test_hive_text_round_trip_partitioned_case_insensitive(spark_tmp_path, data_gen, spark_tmp_table_factory):
    gen = StructGen([('my_field', data_gen)], nullable=False)
    data_path = spark_tmp_path + '/hive_text_table'
    table_name = spark_tmp_table_factory.get()

    with_cpu_session(lambda spark: create_hive_text_table_partitioned(spark, gen, table_name, data_path))

    # With case-sensitive matching, 'DT' would have to be written as 'dt' to
    # match the partition column.
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: read_hive_text_table_partitions(spark, table_name, "DT='1'"),
        conf=hive_text_enabled_conf)

@pytest.mark.skipif(is_spark_cdh(),
reason="Hive text reads are disabled on CDH, as per "
"https://github.com/NVIDIA/spark-rapids/pull/7628")
6 changes: 5 additions & 1 deletion GpuHiveTableScanExec.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeReference, AttributeSeq, AttributeSet, BindReferences, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.execution.{ExecSubqueryExpression, LeafExecNode, SQLExecution}
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionDirectory, PartitionedFile}
@@ -233,7 +234,10 @@ case class GpuHiveTableScanExec(requestedAttributes: Seq[Attribute],
    val requestedCols = requestedAttributes.filter(a => !partitionKeys.contains(a.name))
      .toList
    val distinctColumns = requestedCols.distinct
    val distinctFields = distinctColumns.map(a => tableSchema.apply(a.name))  // old, case-sensitive lookup removed by this commit
    // In Hive, column names are case-insensitive, but the default tableSchema
    // lookup is case-sensitive.
    val fieldMap = CaseInsensitiveMap(tableSchema.map(f => (f.name, f)).toMap)
    val distinctFields = distinctColumns.map(a => fieldMap(a.name))
    StructType(distinctFields)
  }
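
To make the behavior concrete, here is a small, self-contained Scala sketch (not part of the commit; the object name and the two-column schema are invented for illustration, and it assumes spark-catalyst is on the classpath). Spark's StructType.apply(name) resolves field names case-sensitively and throws IllegalArgumentException on a casing mismatch, while wrapping the same fields in CaseInsensitiveMap, as the hunk above does, lets the lookup succeed regardless of casing:

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object CaseInsensitiveLookupSketch {
  def main(args: Array[String]): Unit = {
    // A hypothetical table schema with lower-case column names.
    val tableSchema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))

    // Case-sensitive: tableSchema("ID") would throw IllegalArgumentException
    // because no field is named exactly "ID".

    // Case-insensitive: index the same fields through CaseInsensitiveMap,
    // so a query that spells the column "ID" or "NaMe" still resolves.
    val fieldMap = CaseInsensitiveMap(tableSchema.map(f => (f.name, f)).toMap)
    println(fieldMap("ID").name)   // prints "id"
    println(fieldMap("NaMe").name) // prints "name"
  }
}

The fieldMap line is the same construction as in the hunk above; the only difference from plain tableSchema.apply is the case-normalizing wrapper.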

