diff --git a/docs/compatibility.md b/docs/compatibility.md
index 2a950f9069e..cf18d0e72ad 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -349,7 +349,7 @@ with Spark, and can be enabled by setting `spark.rapids.sql.expression.JsonToStr
 
 Dates are partially supported but there are some known issues:
 
-- Only the default `dateFormat` of `yyyy-MM-dd` is supported. The query will fall back to CPU if any other format
+- Only the default `dateFormat` of `yyyy-MM-dd` is supported in Spark 3.1.x. The query will fall back to CPU if any other format
   is specified ([#9667](https://github.com/NVIDIA/spark-rapids/issues/9667))
 - Strings containing integers with more than four digits will be parsed as null ([#9664](https://github.com/NVIDIA/spark-rapids/issues/9664)) whereas Spark versions prior to 3.4
diff --git a/integration_tests/src/main/python/json_test.py b/integration_tests/src/main/python/json_test.py
index 69de6a326c3..2fa8bcf1fe6 100644
--- a/integration_tests/src/main/python/json_test.py
+++ b/integration_tests/src/main/python/json_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, NVIDIA CORPORATION.
+# Copyright (c) 2021-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -37,6 +37,8 @@
     DoubleGen(no_nans=False)
 ]
 
+optional_whitespace_regex = '[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?'
+
 _enable_all_types_conf = {
     'spark.rapids.sql.format.json.enabled': 'true',
     'spark.rapids.sql.format.json.read.enabled': 'true',
@@ -81,6 +83,35 @@
 _string_schema = StructType([
     StructField('a', StringType())])
 
+json_supported_date_formats = [
+    None, # represents not specifying a format (which is different from explicitly specifying the default format in some Spark versions)
+    'yyyy-MM-dd', 'yyyy/MM/dd',
+    'yyyy-MM', 'yyyy/MM',
+    'MM-yyyy', 'MM/yyyy',
+    'MM-dd-yyyy', 'MM/dd/yyyy',
+    'dd-MM-yyyy', 'dd/MM/yyyy']
+
+json_supported_ts_parts = [
+    "'T'HH:mm:ss.SSSXXX",
+    "'T'HH:mm:ss[.SSS][XXX]",
+    "'T'HH:mm:ss.SSS",
+    "'T'HH:mm:ss[.SSS]",
+    "'T'HH:mm:ss",
+    "'T'HH:mm[:ss]",
+    "'T'HH:mm"]
+
+json_supported_timestamp_formats = [
+    None, # represents not specifying a format (which is different from explicitly specifying the default format in some Spark versions)
+]
+for date_part in json_supported_date_formats:
+    if date_part:
+        # use date format without time component
+        json_supported_timestamp_formats.append(date_part)
+        # use date format and each supported time format
+        for ts_part in json_supported_ts_parts:
+            json_supported_timestamp_formats.append(date_part + ts_part)
+
+
 def read_json_df(data_path, schema, spark_tmp_table_factory_ignored, options = {}):
     def read_impl(spark):
         reader = spark.read
@@ -153,9 +184,7 @@ def test_json_input_meta(spark_tmp_path, v1_enabled_list):
             conf=updated_conf)
 
 allow_non_gpu_for_json_scan = ['FileSourceScanExec', 'BatchScanExec'] if is_not_utc() else []
-json_supported_date_formats = ['yyyy-MM-dd', 'yyyy/MM/dd', 'yyyy-MM', 'yyyy/MM',
-        'MM-yyyy', 'MM/yyyy', 'MM-dd-yyyy', 'MM/dd/yyyy', 'dd-MM-yyyy', 'dd/MM/yyyy']
-@pytest.mark.parametrize('date_format', json_supported_date_formats, ids=idfn)
+@pytest.mark.parametrize('date_format', [None, 'yyyy-MM-dd'] if is_before_spark_320() else json_supported_date_formats, ids=idfn)
 @pytest.mark.parametrize('v1_enabled_list', ["", "json"])
 @allow_non_gpu(*allow_non_gpu_for_json_scan)
 def test_json_date_formats_round_trip(spark_tmp_path, date_format, v1_enabled_list):
@@ -163,75 +192,82 @@ def
test_json_date_formats_round_trip(spark_tmp_path, date_format, v1_enabled_li data_path = spark_tmp_path + '/JSON_DATA' schema = gen.data_type updated_conf = copy_and_update(_enable_all_types_conf, {'spark.sql.sources.useV1SourceList': v1_enabled_list}) - with_cpu_session( - lambda spark : gen_df(spark, gen).write\ - .option('dateFormat', date_format)\ - .json(data_path)) + + def create_test_data(spark): + write = gen_df(spark, gen).write + if date_format: + write = write.option('dateFormat', date_format) + return write.json(data_path) + + with_cpu_session(lambda spark : create_test_data(spark)) + + def do_read(spark): + read = spark.read.schema(schema) + if date_format: + read = read.option('dateFormat', date_format) + return read.json(data_path) + assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.read\ - .schema(schema)\ - .option('dateFormat', date_format)\ - .json(data_path), + lambda spark: do_read(spark), conf=updated_conf) -json_supported_ts_parts = ['', # Just the date - "'T'HH:mm:ss.SSSXXX", - "'T'HH:mm:ss[.SSS][XXX]", - "'T'HH:mm:ss.SSS", - "'T'HH:mm:ss[.SSS]", - "'T'HH:mm:ss", - "'T'HH:mm[:ss]", - "'T'HH:mm"] not_utc_allow_for_test_json_scan = ['BatchScanExec', 'FileSourceScanExec'] if is_not_utc() else [] -@pytest.mark.parametrize('ts_part', json_supported_ts_parts) -@pytest.mark.parametrize('date_format', json_supported_date_formats) -@pytest.mark.parametrize('v1_enabled_list', ["", "json"]) @allow_non_gpu(*not_utc_allow_for_test_json_scan) -def test_json_ts_formats_round_trip(spark_tmp_path, date_format, ts_part, v1_enabled_list): - full_format = date_format + ts_part +@pytest.mark.parametrize('timestamp_format', json_supported_timestamp_formats) +@pytest.mark.parametrize('v1_enabled_list', ["", "json"]) +def test_json_ts_formats_round_trip(spark_tmp_path, timestamp_format, v1_enabled_list): data_gen = TimestampGen() gen = StructGen([('a', data_gen)], nullable=False) data_path = spark_tmp_path + '/JSON_DATA' schema = gen.data_type - with_cpu_session( - lambda spark : gen_df(spark, gen).write\ - .option('timestampFormat', full_format)\ - .json(data_path)) + + def create_test_data(spark): + write = gen_df(spark, gen).write + if timestamp_format: + write = write.option('timestampFormat', timestamp_format) + write.json(data_path) + + with_cpu_session(lambda spark: create_test_data(spark)) updated_conf = copy_and_update(_enable_all_types_conf, {'spark.sql.sources.useV1SourceList': v1_enabled_list}) + + def do_read(spark): + read = spark.read.schema(schema) + if timestamp_format: + read = read.option('timestampFormat', timestamp_format) + return read.json(data_path) + assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.read\ - .schema(schema)\ - .option('timestampFormat', full_format)\ - .json(data_path), + lambda spark: do_read(spark), conf=updated_conf) @allow_non_gpu('FileSourceScanExec', 'ProjectExec') -@pytest.mark.skipif(is_before_spark_341(), reason='`TIMESTAMP_NTZ` is only supported in PySpark 341+.') -@pytest.mark.parametrize('ts_part', json_supported_ts_parts) -@pytest.mark.parametrize('date_format', json_supported_date_formats) +@pytest.mark.skipif(is_before_spark_341(), reason='`TIMESTAMP_NTZ` is only supported in PySpark 341+') +@pytest.mark.parametrize('timestamp_format', json_supported_timestamp_formats) @pytest.mark.parametrize("timestamp_type", ["TIMESTAMP_LTZ", "TIMESTAMP_NTZ"]) -def test_json_ts_formats_round_trip_ntz_v1(spark_tmp_path, date_format, ts_part, timestamp_type): - json_ts_formats_round_trip_ntz(spark_tmp_path, date_format, ts_part, 
timestamp_type, 'json', 'FileSourceScanExec') +def test_json_ts_formats_round_trip_ntz_v1(spark_tmp_path, timestamp_format, timestamp_type): + json_ts_formats_round_trip_ntz(spark_tmp_path, timestamp_format, timestamp_type, 'json', 'FileSourceScanExec') @allow_non_gpu('BatchScanExec', 'ProjectExec') -@pytest.mark.skipif(is_before_spark_341(), reason='`TIMESTAMP_NTZ` is only supported in PySpark 341+.') -@pytest.mark.parametrize('ts_part', json_supported_ts_parts) -@pytest.mark.parametrize('date_format', json_supported_date_formats) +@pytest.mark.skipif(is_before_spark_341(), reason='`TIMESTAMP_NTZ` is only supported in PySpark 341+') +@pytest.mark.parametrize('timestamp_format', json_supported_timestamp_formats) @pytest.mark.parametrize("timestamp_type", ["TIMESTAMP_LTZ", "TIMESTAMP_NTZ"]) -def test_json_ts_formats_round_trip_ntz_v2(spark_tmp_path, date_format, ts_part, timestamp_type): - json_ts_formats_round_trip_ntz(spark_tmp_path, date_format, ts_part, timestamp_type, '', 'BatchScanExec') +def test_json_ts_formats_round_trip_ntz_v2(spark_tmp_path, timestamp_format, timestamp_type): + json_ts_formats_round_trip_ntz(spark_tmp_path, timestamp_format, timestamp_type, '', 'BatchScanExec') -def json_ts_formats_round_trip_ntz(spark_tmp_path, date_format, ts_part, timestamp_type, v1_enabled_list, cpu_scan_class): - full_format = date_format + ts_part +def json_ts_formats_round_trip_ntz(spark_tmp_path, timestamp_format, timestamp_type, v1_enabled_list, cpu_scan_class): data_gen = TimestampGen(tzinfo=None if timestamp_type == "TIMESTAMP_NTZ" else timezone.utc) gen = StructGen([('a', data_gen)], nullable=False) data_path = spark_tmp_path + '/JSON_DATA' schema = gen.data_type - with_cpu_session( - lambda spark : gen_df(spark, gen).write \ - .option('timestampFormat', full_format) \ - .json(data_path)) + + def create_test_data(spark): + write = gen_df(spark, gen).write + if timestamp_format: + write = write.option('timestampFormat', timestamp_format) + write.json(data_path) + + with_cpu_session(lambda spark: create_test_data(spark)) updated_conf = copy_and_update(_enable_all_types_conf, { 'spark.sql.sources.useV1SourceList': v1_enabled_list, @@ -239,10 +275,10 @@ def json_ts_formats_round_trip_ntz(spark_tmp_path, date_format, ts_part, timesta }) def do_read(spark): - return spark.read \ - .schema(schema) \ - .option('timestampFormat', full_format) \ - .json(data_path) + read = spark.read.schema(schema) + if timestamp_format: + read = read.option('timestampFormat', timestamp_format) + return read.json(data_path) if timestamp_type == "TIMESTAMP_LTZ": @@ -286,20 +322,31 @@ def do_read(spark): _float_schema, _double_schema, _decimal_10_2_schema, _decimal_10_3_schema, \ _date_schema]) @pytest.mark.parametrize('read_func', [read_json_df, read_json_sql]) -@pytest.mark.parametrize('allow_non_numeric_numbers', ["true", "false"]) -@pytest.mark.parametrize('allow_numeric_leading_zeros', ["true"]) +@pytest.mark.parametrize('allow_non_numeric_numbers', ['true', 'false']) +@pytest.mark.parametrize('allow_numeric_leading_zeros', [ + 'true', + pytest.param('false', marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/9588')), +]) @pytest.mark.parametrize('ansi_enabled', ["true", "false"]) @allow_non_gpu(*not_utc_allow_for_test_json_scan) -def test_basic_json_read(std_input_path, filename, schema, read_func, allow_non_numeric_numbers, allow_numeric_leading_zeros, ansi_enabled, spark_tmp_table_factory): +@pytest.mark.parametrize('date_format', [None, 'yyyy-MM-dd']) +def 
test_basic_json_read(std_input_path, filename, schema, read_func, allow_non_numeric_numbers, \
+        allow_numeric_leading_zeros, ansi_enabled, spark_tmp_table_factory, date_format):
     updated_conf = copy_and_update(_enable_all_types_conf,
                                    {'spark.sql.ansi.enabled': ansi_enabled,
                                     'spark.sql.legacy.timeParserPolicy': 'CORRECTED'})
+    options = {"allowNonNumericNumbers": allow_non_numeric_numbers,
+               "allowNumericLeadingZeros": allow_numeric_leading_zeros,
+               }
+
+    if date_format:
+        options['dateFormat'] = date_format
+
     assert_gpu_and_cpu_are_equal_collect(
         read_func(std_input_path + '/' + filename,
-                  schema,
-                  spark_tmp_table_factory,
-                  { "allowNonNumericNumbers": allow_non_numeric_numbers,
-                    "allowNumericLeadingZeros": allow_numeric_leading_zeros}),
+                  schema,
+                  spark_tmp_table_factory,
+                  options),
         conf=updated_conf)
 
 @ignore_order
@@ -368,6 +415,39 @@ def test_json_read_valid_dates(std_input_path, filename, schema, read_func, ansi
     else:
         assert_gpu_and_cpu_are_equal_collect(f, conf=updated_conf)
 
+@pytest.mark.parametrize('date_gen_pattern', [
+    '[0-9]{1,4}-[0-3]{1,2}-[0-3]{1,2}',
+    '[0-9]{1,2}-[0-3]{1,2}-[0-9]{1,4}',
+    '[1-9]{4}-[1-3]{2}-[1-3]{2}',
+    '[1-9]{4}-[1-3]{1,2}-[1-3]{1,2}',
+    '[1-3]{1,2}-[1-3]{1,2}-[1-9]{4}',
+    '[1-3]{1,2}/[1-3]{1,2}/[1-9]{4}',
+])
+@pytest.mark.parametrize('schema', [StructType([StructField('value', DateType())])])
+@pytest.mark.parametrize('date_format', [None, 'yyyy-MM-dd'] if is_before_spark_320() else json_supported_date_formats)
+@pytest.mark.parametrize('ansi_enabled', [True, False])
+@pytest.mark.parametrize('allow_numeric_leading_zeros', [True, False])
+@allow_non_gpu(*allow_non_gpu_for_json_scan)
+def test_json_read_generated_dates(spark_tmp_table_factory, spark_tmp_path, date_gen_pattern, schema, date_format, \
+        ansi_enabled, allow_numeric_leading_zeros):
+    # create test data with json strings where a subset are valid dates
+    # example format: {"value":"3481-1-31"}
+    path = spark_tmp_path + '/JSON_DATA'
+
+    data_gen = StringGen(optional_whitespace_regex + date_gen_pattern + optional_whitespace_regex, nullable=False)
+
+    with_cpu_session(lambda spark: gen_df(spark, data_gen).write.json(path))
+
+    updated_conf = copy_and_update(_enable_all_types_conf, {
+        'spark.sql.ansi.enabled': ansi_enabled,
+        'spark.sql.legacy.timeParserPolicy': 'CORRECTED'})
+
+    options = { 'allowNumericLeadingZeros': allow_numeric_leading_zeros }
+    if date_format:
+        options['dateFormat'] = date_format
+
+    f = read_json_df(path, schema, spark_tmp_table_factory, options)
+    assert_gpu_and_cpu_are_equal_collect(f, conf = updated_conf)
 
 @approximate_float
 @pytest.mark.parametrize('filename', [
@@ -376,16 +456,19 @@ def test_json_read_valid_dates(std_input_path, filename, schema, read_func, ansi
 @pytest.mark.parametrize('schema', [_date_schema])
 @pytest.mark.parametrize('read_func', [read_json_df, read_json_sql])
 @pytest.mark.parametrize('ansi_enabled', ["true", "false"])
+@pytest.mark.parametrize('date_format', [None, 'yyyy-MM-dd'] if is_before_spark_320() else json_supported_date_formats)
 @pytest.mark.parametrize('time_parser_policy', [
     pytest.param('LEGACY', marks=pytest.mark.allow_non_gpu('FileSourceScanExec')),
     pytest.param('CORRECTED', marks=pytest.mark.allow_non_gpu(*not_utc_json_scan_allow)),
     pytest.param('EXCEPTION', marks=pytest.mark.allow_non_gpu(*not_utc_json_scan_allow))
 ])
-def test_json_read_invalid_dates(std_input_path, filename, schema, read_func, ansi_enabled, time_parser_policy, spark_tmp_table_factory):
+def test_json_read_invalid_dates(std_input_path, filename, schema, read_func, ansi_enabled, date_format, \
+        time_parser_policy, spark_tmp_table_factory):
     updated_conf = copy_and_update(_enable_all_types_conf,
                                    {'spark.sql.ansi.enabled': ansi_enabled,
                                     'spark.sql.legacy.timeParserPolicy': time_parser_policy })
-    f = read_func(std_input_path + '/' + filename, schema, spark_tmp_table_factory, {})
+    options = { 'dateFormat': date_format } if date_format else {}
+    f = read_func(std_input_path + '/' + filename, schema, spark_tmp_table_factory, options)
     if time_parser_policy == 'EXCEPTION':
         assert_gpu_and_cpu_error(
             df_fun=lambda spark: f(spark).collect(),
@@ -551,15 +634,15 @@ def test_from_json_struct_decimal():
 
 @pytest.mark.parametrize('date_gen', [
     # "yyyy-MM-dd"
-    "\"[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?[1-8]{1}[0-9]{3}-[0-3]{1,2}-[0-3]{1,2}[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?\"",
+    "\"" + optional_whitespace_regex + "[1-8]{1}[0-9]{3}-[0-3]{1,2}-[0-3]{1,2}" + optional_whitespace_regex + "\"",
     # "yyyy-MM"
-    "\"[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?[1-8]{1}[0-9]{3}-[0-3]{1,2}[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?\"",
+    "\"" + optional_whitespace_regex + "[1-8]{1}[0-9]{3}-[0-3]{1,2}" + optional_whitespace_regex + "\"",
     # "yyyy"
-    "\"[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?[0-9]{4}[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?\"",
+    "\"" + optional_whitespace_regex + "[0-9]{4}" + optional_whitespace_regex + "\"",
     # "dd/MM/yyyy"
-    "\"[0-9]{2}/[0-9]{2}/[1-8]{1}[0-9]{3}\"",
+    "\"" + optional_whitespace_regex + "[0-9]{2}/[0-9]{2}/[1-8]{1}[0-9]{3}" + optional_whitespace_regex + "\"",
     # special constant values
-    "\"(now|today|tomorrow|epoch)\"",
+    "\"" + optional_whitespace_regex + "(now|today|tomorrow|epoch)" + optional_whitespace_regex + "\"",
     # "nnnnn" (number of days since epoch prior to Spark 3.4, throws exception from 3.4)
     pytest.param("\"[0-9]{5}\"", marks=pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/9664")),
     # integral
@@ -569,38 +652,30 @@ def test_from_json_struct_decimal():
     # boolean
     "(true|false)"
 ])
-@pytest.mark.parametrize('date_format', [
-    pytest.param("", marks=pytest.mark.allow_non_gpu(*non_utc_project_allow)),
-    pytest.param("yyyy-MM-dd", marks=pytest.mark.allow_non_gpu(*non_utc_project_allow)),
-    # https://github.com/NVIDIA/spark-rapids/issues/9667
-    pytest.param("dd/MM/yyyy", marks=pytest.mark.allow_non_gpu('ProjectExec')),
-])
-@pytest.mark.parametrize('time_parser_policy', [
-    pytest.param("LEGACY", marks=pytest.mark.allow_non_gpu('ProjectExec')),
-    "CORRECTED"
-])
-def test_from_json_struct_date(date_gen, date_format, time_parser_policy):
+@pytest.mark.parametrize('date_format', [None, 'yyyy-MM-dd'] if is_before_spark_320() else json_supported_date_formats)
+@allow_non_gpu(*non_utc_project_allow)
+def test_from_json_struct_date(date_gen, date_format):
     json_string_gen = StringGen(r'{ "a": ' + date_gen + ' }') \
         .with_special_case('{ "a": null }') \
         .with_special_case('null')
-    options = { 'dateFormat': date_format } if len(date_format) > 0 else { }
+    options = { 'dateFormat': date_format } if date_format else { }
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, json_string_gen) \
             .select(f.col('a'), f.from_json('a', 'struct<a:date>', options)),
-        conf={"spark.rapids.sql.expression.JsonToStructs": True,
-              'spark.sql.legacy.timeParserPolicy': time_parser_policy})
+        conf={'spark.rapids.sql.expression.JsonToStructs': True,
+              'spark.sql.legacy.timeParserPolicy': 'CORRECTED'})
 
 @allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('date_gen', ["\"[1-8]{1}[0-9]{3}-[0-3]{1,2}-[0-3]{1,2}\""]) @pytest.mark.parametrize('date_format', [ - "", + None, "yyyy-MM-dd", ]) def test_from_json_struct_date_fallback_legacy(date_gen, date_format): json_string_gen = StringGen(r'{ "a": ' + date_gen + ' }') \ .with_special_case('{ "a": null }') \ .with_special_case('null') - options = { 'dateFormat': date_format } if len(date_format) > 0 else { } + options = { 'dateFormat': date_format } if date_format else { } assert_gpu_fallback_collect( lambda spark : unary_op_df(spark, json_string_gen) \ .select(f.col('a'), f.from_json('a', 'struct', options)), @@ -608,6 +683,7 @@ def test_from_json_struct_date_fallback_legacy(date_gen, date_format): conf={"spark.rapids.sql.expression.JsonToStructs": True, 'spark.sql.legacy.timeParserPolicy': 'LEGACY'}) +@pytest.mark.skipif(is_spark_320_or_later(), reason="We only fallback for non-default formats prior to 320") @allow_non_gpu('ProjectExec') @pytest.mark.parametrize('date_gen', ["\"[1-8]{1}[0-9]{3}-[0-3]{1,2}-[0-3]{1,2}\""]) @pytest.mark.parametrize('date_format', [ @@ -618,7 +694,7 @@ def test_from_json_struct_date_fallback_non_default_format(date_gen, date_format json_string_gen = StringGen(r'{ "a": ' + date_gen + ' }') \ .with_special_case('{ "a": null }') \ .with_special_case('null') - options = { 'dateFormat': date_format } if len(date_format) > 0 else { } + options = { 'dateFormat': date_format } assert_gpu_fallback_collect( lambda spark : unary_op_df(spark, json_string_gen) \ .select(f.col('a'), f.from_json('a', 'struct', options)), @@ -631,22 +707,22 @@ def test_from_json_struct_date_fallback_non_default_format(date_gen, date_format @pytest.mark.parametrize('timestamp_gen', [ # "yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]" - "\"[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?[1-8]{1}[0-9]{3}-[0-3]{1,2}-[0-3]{1,2}T[0-9]{1,2}:[0-9]{1,2}:[0-9]{1,2}(\\.[0-9]{1,6})?Z?[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]}?\"", + "\"" + optional_whitespace_regex + "[1-8]{1}[0-9]{3}-[0-3]{1,2}-[0-3]{1,2}T[0-9]{1,2}:[0-9]{1,2}:[0-9]{1,2}(\\.[0-9]{1,6})?Z?" + optional_whitespace_regex + "\"", # "yyyy-MM-dd" - "\"[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?[1-8]{1}[0-9]{3}-[0-3]{1,2}-[0-3]{1,2}[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?\"", + "\"" + optional_whitespace_regex + "[1-8]{1}[0-9]{3}-[0-3]{1,2}-[0-3]{1,2}" + optional_whitespace_regex + "\"", # "yyyy-MM" - "\"[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?[1-8]{1}[0-9]{3}-[0-3]{1,2}[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?\"", + "\"" + optional_whitespace_regex + "[1-8]{1}[0-9]{3}-[0-3]{1,2}" + optional_whitespace_regex + "\"", # "yyyy" - "\"[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?" 
+ yyyy_start_0001 + "[ \t\xA0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]?\"", + "\"" + optional_whitespace_regex + yyyy_start_0001 + optional_whitespace_regex + "\"", # "dd/MM/yyyy" - "\"[0-9]{2}/[0-9]{2}/[1-8]{1}[0-9]{3}\"", + "\"" + optional_whitespace_regex + "[0-9]{2}/[0-9]{2}/[1-8]{1}[0-9]{3}" + optional_whitespace_regex + "\"", # special constant values - pytest.param("\"(now|today|tomorrow|epoch)\"", marks=pytest.mark.xfail(condition=is_before_spark_320(), reason="https://github.com/NVIDIA/spark-rapids/issues/9724")), + pytest.param("\"" + optional_whitespace_regex + "(now|today|tomorrow|epoch)" + optional_whitespace_regex + "\"", marks=pytest.mark.xfail(condition=is_before_spark_320(), reason="https://github.com/NVIDIA/spark-rapids/issues/9724")), # "nnnnn" (number of days since epoch prior to Spark 3.4, throws exception from 3.4) - pytest.param("\"[0-9]{5}\"", marks=pytest.mark.skip(reason="https://github.com/NVIDIA/spark-rapids/issues/9664")), + pytest.param("\"" + optional_whitespace_regex + "[0-9]{5}" + optional_whitespace_regex + "\"", marks=pytest.mark.skip(reason="https://github.com/NVIDIA/spark-rapids/issues/9664")), # integral - pytest.param("[0-9]{1,5}", marks=pytest.mark.skip(reason="https://github.com/NVIDIA/spark-rapids/issues/9588")), - "[1-9]{1,8}", + pytest.param("[0-9]{1,5}", marks=pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/9588")), + pytest.param("[1-9]{1,8}", marks=pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/4940")), # floating-point "[0-9]{0,2}\.[0-9]{1,2}" # boolean @@ -654,9 +730,9 @@ def test_from_json_struct_date_fallback_non_default_format(date_gen, date_format ]) @pytest.mark.parametrize('timestamp_format', [ # Even valid timestamp format, CPU fallback happens still since non UTC is not supported for json. 
- pytest.param("", marks=pytest.mark.allow_non_gpu(*non_utc_project_allow)), - pytest.param("yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]", marks=pytest.mark.allow_non_gpu(*non_utc_project_allow)), + pytest.param(None, marks=pytest.mark.allow_non_gpu(*non_utc_project_allow)), # https://github.com/NVIDIA/spark-rapids/issues/9723 + pytest.param("yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]", marks=pytest.mark.allow_non_gpu('ProjectExec')), pytest.param("yyyy-MM-dd'T'HH:mm:ss.SSSXXX", marks=pytest.mark.allow_non_gpu('ProjectExec')), pytest.param("dd/MM/yyyy'T'HH:mm:ss[.SSS][XXX]", marks=pytest.mark.allow_non_gpu('ProjectExec')), ]) @@ -668,8 +744,9 @@ def test_from_json_struct_date_fallback_non_default_format(date_gen, date_format def test_from_json_struct_timestamp(timestamp_gen, timestamp_format, time_parser_policy, ansi_enabled): json_string_gen = StringGen(r'{ "a": ' + timestamp_gen + ' }') \ .with_special_case('{ "a": null }') \ + .with_special_case('{ "a": "6395-12-21T56:86:40.205705Z" }') \ .with_special_case('null') - options = { 'timestampFormat': timestamp_format } if len(timestamp_format) > 0 else { } + options = { 'timestampFormat': timestamp_format } if timestamp_format else { } assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, json_string_gen) \ .select(f.col('a'), f.from_json('a', 'struct', options)), @@ -680,14 +757,14 @@ def test_from_json_struct_timestamp(timestamp_gen, timestamp_format, time_parser @allow_non_gpu('ProjectExec') @pytest.mark.parametrize('timestamp_gen', ["\"[1-8]{1}[0-9]{3}-[0-3]{1,2}-[0-3]{1,2}T[0-9]{1,2}:[0-9]{1,2}:[0-9]{1,2}(\\.[0-9]{1,6})?Z?\""]) @pytest.mark.parametrize('timestamp_format', [ - "", + None, "yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]", ]) def test_from_json_struct_timestamp_fallback_legacy(timestamp_gen, timestamp_format): json_string_gen = StringGen(r'{ "a": ' + timestamp_gen + ' }') \ .with_special_case('{ "a": null }') \ .with_special_case('null') - options = { 'timestampFormat': timestamp_format } if len(timestamp_format) > 0 else { } + options = { 'timestampFormat': timestamp_format } if timestamp_format else { } assert_gpu_fallback_collect( lambda spark : unary_op_df(spark, json_string_gen) \ .select(f.col('a'), f.from_json('a', 'struct', options)), @@ -705,7 +782,7 @@ def test_from_json_struct_timestamp_fallback_non_default_format(timestamp_gen, t json_string_gen = StringGen(r'{ "a": ' + timestamp_gen + ' }') \ .with_special_case('{ "a": null }') \ .with_special_case('null') - options = { 'timestampFormat': timestamp_format } if len(timestamp_format) > 0 else { } + options = { 'timestampFormat': timestamp_format } if timestamp_format else { } assert_gpu_fallback_collect( lambda spark : unary_op_df(spark, json_string_gen) \ .select(f.col('a'), f.from_json('a', 'struct', options)), diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala index 611e9ce43a1..51d695904a3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCSVScan.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -359,7 +359,7 @@ abstract class CSVPartitionReaderBase[BUFF <: LineBufferer, FACT <: LineBufferer } } - override def dateFormat: String = GpuCsvUtils.dateFormatInRead(parsedOptions) + override def dateFormat: Option[String] = Some(GpuCsvUtils.dateFormatInRead(parsedOptions)) override def timestampFormat: String = GpuCsvUtils.timestampFormatInRead(parsedOptions) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 9bf9144db0e..1e70090d0a7 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1301,7 +1301,8 @@ object GpuCast { def convertDateOrNull( input: ColumnVector, regex: String, - cudfFormat: String): ColumnVector = { + cudfFormat: String, + failOnInvalid: Boolean = false): ColumnVector = { val prog = new RegexProgram(regex, CaptureGroups.NON_CAPTURE) val isValidDate = withResource(input.matchesRe(prog)) { isMatch => @@ -1311,6 +1312,13 @@ object GpuCast { } withResource(isValidDate) { _ => + if (failOnInvalid) { + withResource(isValidDate.all()) { all => + if (all.isValid && !all.getBoolean) { + throw new DateTimeException("One or more values is not a valid date") + } + } + } withResource(Scalar.fromNull(DType.TIMESTAMP_DAYS)) { orElse => withResource(input.asTimestampDays(cudfFormat)) { asDays => isValidDate.ifElse(asDays, orElse) @@ -1393,7 +1401,7 @@ object GpuCast { } } - private def castStringToDateAnsi(input: ColumnVector, ansiMode: Boolean): ColumnVector = { + def castStringToDateAnsi(input: ColumnVector, ansiMode: Boolean): ColumnVector = { val result = castStringToDate(input) if (ansiMode) { // When ANSI mode is enabled, we need to throw an exception if any values could not be diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTextBasedPartitionReader.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTextBasedPartitionReader.scala index 09da238459d..2a7bff24c50 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTextBasedPartitionReader.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTextBasedPartitionReader.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,6 +17,7 @@ package com.nvidia.spark.rapids import java.time.DateTimeException +import java.util import java.util.Optional import scala.collection.mutable.ListBuffer @@ -26,7 +27,6 @@ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.DateUtils.{toStrf, TimestampFormatConversionException} import com.nvidia.spark.rapids.jni.CastStrings import com.nvidia.spark.rapids.shims.GpuTypeShims -import java.util import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.compress.CompressionCodecFactory @@ -372,18 +372,14 @@ abstract class GpuTextBasedPartitionReader[BUFF <: LineBufferer, FACT <: LineBuf } } - def dateFormat: String + def dateFormat: Option[String] def timestampFormat: String def castStringToDate(input: ColumnVector, dt: DType): ColumnVector = { - castStringToDate(input, dt, failOnInvalid = true) - } - - def castStringToDate(input: ColumnVector, dt: DType, failOnInvalid: Boolean): ColumnVector = { - val cudfFormat = DateUtils.toStrf(dateFormat, parseString = true) + val cudfFormat = DateUtils.toStrf(dateFormat.getOrElse("yyyy-MM-dd"), parseString = true) withResource(input.strip()) { stripped => withResource(stripped.isTimestamp(cudfFormat)) { isDate => - if (failOnInvalid && GpuOverrides.getTimeParserPolicy == ExceptionTimeParserPolicy) { + if (GpuOverrides.getTimeParserPolicy == ExceptionTimeParserPolicy) { withResource(isDate.all()) { all => if (all.isValid && !all.getBoolean) { throw new DateTimeException("One or more values is not a valid date") diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala index 04f28ef045d..ae2e3a877e3 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -25,7 +25,7 @@ import ai.rapids.cudf import ai.rapids.cudf.{CaptureGroups, ColumnVector, DType, NvtxColor, RegexProgram, Scalar, Schema, Table} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.shims.{ColumnDefaultValuesShims, LegacyBehaviorPolicyShim, ShimFilePartitionReaderFactory} +import com.nvidia.spark.rapids.shims.{ColumnDefaultValuesShims, GpuJsonToStructsShim, LegacyBehaviorPolicyShim, ShimFilePartitionReaderFactory} import org.apache.hadoop.conf.Configuration import org.apache.spark.broadcast.Broadcast @@ -113,16 +113,15 @@ object GpuJsonScan { val hasDates = TrampolineUtil.dataTypeExistsRecursively(dt, _.isInstanceOf[DateType]) if (hasDates) { - GpuJsonUtils.optionalDateFormatInRead(parsedOptions) match { - case None | Some("yyyy-MM-dd") => - // this is fine - case dateFormat => - meta.willNotWorkOnGpu(s"GpuJsonToStructs unsupported dateFormat $dateFormat") - } + GpuJsonToStructsShim.tagDateFormatSupport(meta, + GpuJsonUtils.optionalDateFormatInRead(parsedOptions)) } val hasTimestamps = TrampolineUtil.dataTypeExistsRecursively(dt, _.isInstanceOf[TimestampType]) if (hasTimestamps) { + GpuJsonToStructsShim.tagTimestampFormatSupport(meta, + GpuJsonUtils.optionalTimestampFormatInRead(parsedOptions)) + GpuJsonUtils.optionalTimestampFormatInRead(parsedOptions) match { case None | Some("yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]") => // this is fine @@ -163,10 +162,16 @@ object GpuJsonScan { tagSupportOptions(parsedOptions, meta) val types = readSchema.map(_.dataType) - if (types.contains(DateType)) { + + val hasDates = TrampolineUtil.dataTypeExistsRecursively(readSchema, _.isInstanceOf[DateType]) + if (hasDates) { + GpuTextBasedDateUtils.tagCudfFormat(meta, GpuJsonUtils.dateFormatInRead(parsedOptions), parseString = true) + GpuJsonToStructsShim.tagDateFormatSupportFromScan(meta, + GpuJsonUtils.optionalDateFormatInRead(parsedOptions)) + // For date type, timezone needs to be checked also. This is because JVM timezone is used // to get days offset before rebasing Julian to Gregorian in Spark while not in Rapids. // @@ -446,6 +451,10 @@ class JsonPartitionReader( } } + override def castStringToDate(input: ColumnVector, dt: DType): ColumnVector = { + GpuJsonToStructsShim.castJsonStringToDateFromScan(input, dt, dateFormat) + } + /** * JSON has strict rules about valid numeric formats. See https://www.json.org/ for specification. * @@ -490,6 +499,6 @@ class JsonPartitionReader( } } - override def dateFormat: String = GpuJsonUtils.dateFormatInRead(parsedOptions) + override def dateFormat: Option[String] = GpuJsonUtils.optionalDateFormatInRead(parsedOptions) override def timestampFormat: String = GpuJsonUtils.timestampFormatInRead(parsedOptions) } diff --git a/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/GpuJsonToStructsShim.scala b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/GpuJsonToStructsShim.scala index 0dffe6a35fa..2d8b2fb9136 100644 --- a/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/GpuJsonToStructsShim.scala +++ b/sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/GpuJsonToStructsShim.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,36 +17,28 @@ {"spark": "311"} {"spark": "312"} {"spark": "313"} -{"spark": "320"} -{"spark": "321"} -{"spark": "321cdh"} -{"spark": "321db"} -{"spark": "322"} -{"spark": "323"} -{"spark": "324"} -{"spark": "330"} -{"spark": "330cdh"} -{"spark": "330db"} -{"spark": "331"} -{"spark": "332"} -{"spark": "332cdh"} -{"spark": "332db"} -{"spark": "333"} -{"spark": "334"} spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids.shims -import ai.rapids.cudf.{ColumnVector, Scalar} +import ai.rapids.cudf.{ColumnVector, DType, Scalar} +import com.nvidia.spark.rapids.{GpuCast, GpuOverrides, RapidsMeta} import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.GpuCast import org.apache.spark.sql.catalyst.json.GpuJsonUtils +import org.apache.spark.sql.rapids.ExceptionTimeParserPolicy object GpuJsonToStructsShim { + def tagDateFormatSupport(meta: RapidsMeta[_, _, _], dateFormat: Option[String]): Unit = { + dateFormat match { + case None | Some("yyyy-MM-dd") => + case dateFormat => + meta.willNotWorkOnGpu(s"GpuJsonToStructs unsupported dateFormat $dateFormat") + } + } def castJsonStringToDate(input: ColumnVector, options: Map[String, String]): ColumnVector = { - GpuJsonUtils.dateFormatInRead(options) match { - case "yyyy-MM-dd" => + GpuJsonUtils.optionalDateFormatInRead(options) match { + case None | Some("yyyy-MM-dd") => withResource(Scalar.fromString(" ")) { space => withResource(input.strip(space)) { trimmed => GpuCast.castStringToDate(trimmed) @@ -58,6 +50,27 @@ object GpuJsonToStructsShim { } } + def tagDateFormatSupportFromScan(meta: RapidsMeta[_, _, _], dateFormat: Option[String]): Unit = { + tagDateFormatSupport(meta, dateFormat) + } + + def castJsonStringToDateFromScan(input: ColumnVector, dt: DType, + dateFormat: Option[String]): ColumnVector = { + dateFormat match { + case None | Some("yyyy-MM-dd") => + withResource(input.strip()) { trimmed => + GpuCast.castStringToDateAnsi(trimmed, ansiMode = + GpuOverrides.getTimeParserPolicy == ExceptionTimeParserPolicy) + } + case other => + // should be unreachable due to GpuOverrides checks + throw new IllegalStateException(s"Unsupported dateFormat $other") + } + } + + def tagTimestampFormatSupport(meta: RapidsMeta[_, _, _], + timestampFormat: Option[String]): Unit = {} + def castJsonStringToTimestamp(input: ColumnVector, options: Map[String, String]): ColumnVector = { withResource(Scalar.fromString(" ")) { space => diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GpuJsonToStructsShim.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GpuJsonToStructsShim.scala new file mode 100644 index 00000000000..0c94c5c1e1f --- /dev/null +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/GpuJsonToStructsShim.scala @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +/*** spark-rapids-shim-json-lines +{"spark": "320"} +{"spark": "321"} +{"spark": "321cdh"} +{"spark": "321db"} +{"spark": "322"} +{"spark": "323"} +{"spark": "324"} +{"spark": "330"} +{"spark": "330cdh"} +{"spark": "330db"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332cdh"} +{"spark": "332db"} +{"spark": "333"} +{"spark": "334"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import ai.rapids.cudf.{ColumnVector, DType, Scalar} +import com.nvidia.spark.rapids.{DateUtils, GpuCast, GpuOverrides, RapidsMeta} +import com.nvidia.spark.rapids.Arm.withResource + +import org.apache.spark.sql.rapids.ExceptionTimeParserPolicy + +object GpuJsonToStructsShim { + + def tagDateFormatSupport(meta: RapidsMeta[_, _, _], dateFormat: Option[String]): Unit = { + // dateFormat is ignored by JsonToStructs in Spark 3.2.x and 3.3.x because it just + // performs a regular cast from string to date + } + + def castJsonStringToDate(input: ColumnVector, options: Map[String, String]): ColumnVector = { + // dateFormat is ignored in from_json in Spark 3.2 + withResource(Scalar.fromString(" ")) { space => + withResource(input.strip(space)) { trimmed => + GpuCast.castStringToDate(trimmed) + } + } + } + + def tagDateFormatSupportFromScan(meta: RapidsMeta[_, _, _], dateFormat: Option[String]): Unit = { + } + + def castJsonStringToDateFromScan(input: ColumnVector, dt: DType, + dateFormat: Option[String]): ColumnVector = { + dateFormat match { + case None => + // legacy behavior + withResource(input.strip()) { trimmed => + GpuCast.castStringToDateAnsi(trimmed, ansiMode = + GpuOverrides.getTimeParserPolicy == ExceptionTimeParserPolicy) + } + case Some(fmt) => + withResource(input.strip()) { trimmed => + val regexRoot = fmt + .replace("yyyy", raw"\d{4}") + .replace("MM", raw"\d{1,2}") + .replace("dd", raw"\d{1,2}") + val cudfFormat = DateUtils.toStrf(fmt, parseString = true) + GpuCast.convertDateOrNull(trimmed, "^" + regexRoot + "$", cudfFormat, + failOnInvalid = GpuOverrides.getTimeParserPolicy == ExceptionTimeParserPolicy) + } + } + } + + def tagTimestampFormatSupport(meta: RapidsMeta[_, _, _], + timestampFormat: Option[String]): Unit = { + // we only support the case where no format is specified + timestampFormat.foreach(f => meta.willNotWorkOnGpu(s"Unsupported timestampFormat: $f")) + } + + def castJsonStringToTimestamp(input: ColumnVector, + options: Map[String, String]): ColumnVector = { + // legacy behavior + withResource(Scalar.fromString(" ")) { space => + withResource(input.strip(space)) { trimmed => + // from_json doesn't respect ansi mode + GpuCast.castStringToTimestamp(trimmed, ansiMode = false) + } + } + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuJsonToStructsShim.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuJsonToStructsShim.scala index 8c53323d018..c05ebd2fa7c 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuJsonToStructsShim.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuJsonToStructsShim.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,14 +23,18 @@ spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids.shims -import ai.rapids.cudf.{ColumnVector, Scalar} +import ai.rapids.cudf.{ColumnVector, DType, Scalar} +import com.nvidia.spark.rapids.{DateUtils, GpuCast, GpuOverrides, RapidsMeta} import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.GpuCast import org.apache.spark.sql.catalyst.json.GpuJsonUtils +import org.apache.spark.sql.rapids.ExceptionTimeParserPolicy object GpuJsonToStructsShim { + def tagDateFormatSupport(meta: RapidsMeta[_, _, _], dateFormat: Option[String]): Unit = { + } + def castJsonStringToDate(input: ColumnVector, options: Map[String, String]): ColumnVector = { GpuJsonUtils.optionalDateFormatInRead(options) match { case None => @@ -40,14 +44,46 @@ object GpuJsonToStructsShim { GpuCast.castStringToDate(trimmed) } } - case Some("yyyy-MM-dd") => - GpuCast.convertDateOrNull(input, "^[0-9]{4}-[0-9]{2}-[0-9]{2}$", "%Y-%m-%d") - case other => - // should be unreachable due to GpuOverrides checks - throw new IllegalStateException(s"Unsupported dateFormat $other") + case Some(f) => + // from_json does not respect EXCEPTION policy + jsonStringToDate(input, f, failOnInvalid = false) + } + } + + def tagDateFormatSupportFromScan(meta: RapidsMeta[_, _, _], dateFormat: Option[String]): Unit = { + } + + def castJsonStringToDateFromScan(input: ColumnVector, dt: DType, + dateFormat: Option[String]): ColumnVector = { + dateFormat match { + case None => + // legacy behavior + withResource(input.strip()) { trimmed => + GpuCast.castStringToDateAnsi(trimmed, ansiMode = + GpuOverrides.getTimeParserPolicy == ExceptionTimeParserPolicy) + } + case Some(f) => + jsonStringToDate(input, f, + GpuOverrides.getTimeParserPolicy == ExceptionTimeParserPolicy) } } + private def jsonStringToDate(input: ColumnVector, dateFormatPattern: String, + failOnInvalid: Boolean): ColumnVector = { + val regexRoot = dateFormatPattern + .replace("yyyy", raw"\d{4}") + .replace("MM", raw"\d{2}") + .replace("dd", raw"\d{2}") + val cudfFormat = DateUtils.toStrf(dateFormatPattern, parseString = true) + GpuCast.convertDateOrNull(input, "^" + regexRoot + "$", cudfFormat, failOnInvalid) + } + + def tagTimestampFormatSupport(meta: RapidsMeta[_, _, _], + timestampFormat: Option[String]): Unit = { + // we only support the case where no format is specified + timestampFormat.foreach(f => meta.willNotWorkOnGpu(s"Unsupported timestampFormat: $f")) + } + def castJsonStringToTimestamp(input: ColumnVector, options: Map[String, String]): ColumnVector = { options.get("timestampFormat") match { @@ -59,9 +95,6 @@ object GpuJsonToStructsShim { GpuCast.castStringToTimestamp(trimmed, ansiMode = false) } } - case Some("yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]") => - GpuCast.convertTimestampOrNull(input, - "^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(\\.[0-9]{1,6})?Z?$", "%Y-%m-%d") case other => // should be unreachable due to GpuOverrides checks throw new IllegalStateException(s"Unsupported timestampFormat $other")