Skip to content

Commit

Permalink
upmerge
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Dec 6, 2023
1 parent ae41666 commit b393caa
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 96 deletions.
1 change: 1 addition & 0 deletions integration_tests/src/main/python/csv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def test_csv_fallback(spark_tmp_path, read_func, disable_conf, spark_tmp_table_f
cpu_fallback_class_name=get_non_gpu_allowed()[0],
conf=updated_conf)

# todo add '' to represent no date format specified
# Date format patterns exercised by the parametrized date-format tests below.
csv_supported_date_formats = [
    'yyyy-MM-dd', 'yyyy/MM/dd',
    'yyyy-MM', 'yyyy/MM',
    'MM-yyyy', 'MM/yyyy',
    'MM-dd-yyyy', 'MM/dd/yyyy',
    'dd-MM-yyyy', 'dd/MM/yyyy',
]
@pytest.mark.parametrize('date_format', csv_supported_date_formats, ids=idfn)
Expand Down
229 changes: 157 additions & 72 deletions integration_tests/src/main/python/json_test.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions integration_tests/src/test/resources/dates.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{ "number": "2020-09-16" }
{ "number": " 2020-09-16" }
{ "number": "2020-09-16 " }
{ "number": " 2020-09-17" }
{ "number": "2020-09-18 " }
{ "number": " 2020-09-19 " }
{ "number": "1581-01-01" }
{ "number": "1583-01-01" }
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ abstract class CSVPartitionReaderBase[BUFF <: LineBufferer, FACT <: LineBufferer
}
}

override def dateFormat: String = GpuCsvUtils.dateFormatInRead(parsedOptions)
override def dateFormat: Option[String] = Some(GpuCsvUtils.dateFormatInRead(parsedOptions))
override def timestampFormat: String = GpuCsvUtils.timestampFormatInRead(parsedOptions)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.Optional

import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DecimalUtils, DType, RegexProgram, Scalar}
import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DType, DecimalUtils, RegexProgram, Scalar}
import ai.rapids.cudf
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
Expand Down Expand Up @@ -1376,7 +1376,7 @@ object GpuCast {
}
}

private def castStringToDateAnsi(input: ColumnVector, ansiMode: Boolean): ColumnVector = {
def castStringToDateAnsi(input: ColumnVector, ansiMode: Boolean): ColumnVector = {
val result = castStringToDate(input)
if (ansiMode) {
// When ANSI mode is enabled, we need to throw an exception if any values could not be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package com.nvidia.spark.rapids

import java.time.DateTimeException
import java.util.Optional

import scala.collection.mutable.ListBuffer
Expand All @@ -36,7 +35,7 @@ import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.datasources.{HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.{ExceptionTimeParserPolicy, GpuToTimestamp, LegacyTimeParserPolicy}
import org.apache.spark.sql.rapids.{GpuToTimestamp, LegacyTimeParserPolicy}
import org.apache.spark.sql.types.{DataTypes, DecimalType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

Expand Down Expand Up @@ -372,30 +371,40 @@ abstract class GpuTextBasedPartitionReader[BUFF <: LineBufferer, FACT <: LineBuf
}
}

def dateFormat: String
def dateFormat: Option[String]
def timestampFormat: String

/**
 * Converts a string column to dates of the given cudf type using the reader's
 * configured date format. Delegates to the three-argument overload with
 * `failOnInvalid = true`, so invalid values are treated as errors by default.
 */
def castStringToDate(input: ColumnVector, dt: DType): ColumnVector =
  castStringToDate(input, dt, failOnInvalid = true)

// NOTE(review): this span appears to interleave two diff hunks of the same method —
// the pre-change body (isTimestamp/asTimestamp validation, down to the ifElse call)
// and the post-change body (the `dateFormat match` using a regex and
// GpuCast.convertDateOrNull). As captured here the braces do not balance; restore
// exactly one of the two variants before compiling.
def castStringToDate(input: ColumnVector, dt: DType, failOnInvalid: Boolean): ColumnVector = {
// --- pre-change implementation (uses `dateFormat` as a plain String) ---
val cudfFormat = DateUtils.toStrf(dateFormat, parseString = true)
withResource(input.strip()) { stripped =>
withResource(stripped.isTimestamp(cudfFormat)) { isDate =>
// Under the EXCEPTION time-parser policy, any non-matching value aborts the read.
if (failOnInvalid && GpuOverrides.getTimeParserPolicy == ExceptionTimeParserPolicy) {
withResource(isDate.all()) { all =>
if (all.isValid && !all.getBoolean) {
throw new DateTimeException("One or more values is not a valid date")
}
}
}
// Values that fail the isTimestamp check are replaced with nulls.
withResource(stripped.asTimestamp(dt, cudfFormat)) { asDate =>
withResource(Scalar.fromNull(dt)) { nullScalar =>
isDate.ifElse(asDate, nullScalar)

// --- post-change implementation (uses `dateFormat` as an Option[String]) ---
// TODO make these same changes for timestamps and add tests

// Fall back to Spark's default pattern when no date format was specified.
val dateFormatPattern = dateFormat.getOrElse("yyyy-MM-dd")

val cudfFormat = DateUtils.toStrf(dateFormatPattern, parseString = true)

dateFormat match {
case Some(_) =>
// An explicit format: build a validation regex from the pattern by substituting
// digit classes for the year/month/day fields, then convert non-matches to null.
val twoDigits = raw"\d{2}"
val fourDigits = raw"\d{4}"

val regexRoot = dateFormatPattern
.replace("yyyy", fourDigits)
.replace("MM", twoDigits)
.replace("dd", twoDigits)
GpuCast.convertDateOrNull(input, "^" + regexRoot + "$", cudfFormat)
case _ =>
// legacy behavior
// TODO this is similar to, but different from GpuJsonToStructsShim
// withResource(Scalar.fromString(" ")) { space =>
withResource(input.strip()) { trimmed =>
// TODO add tests for EXCEPTION policy handling
GpuCast.castStringToDateAnsi(trimmed, ansiMode = false) // TODO
}
}
}
// }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,6 @@ class JsonPartitionReader(
}
}

override def dateFormat: String = GpuJsonUtils.dateFormatInRead(parsedOptions)
override def dateFormat: Option[String] = GpuJsonUtils.optionalDateFormatInRead(parsedOptions)
override def timestampFormat: String = GpuJsonUtils.timestampFormatInRead(parsedOptions)
}

0 comments on commit b393caa

Please sign in to comment.