Skip to content

Commit

Permalink
Shortcut common type inference cases to fail fast, speed up inference (#660)
Browse files Browse the repository at this point in the history

* Shortcut to fail date/time parsing if not a date/time

* Don't use exceptions for date/time control flow in the parsing method to make inference faster

* Also shortcut int/float/double parsing where obviously not parseable
  • Loading branch information
srowen authored Sep 7, 2023
1 parent 3d76b79 commit 994e357
Showing 1 changed file with 54 additions and 31 deletions.
85 changes: 54 additions & 31 deletions src/main/scala/com/databricks/spark/xml/util/TypeCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package com.databricks.spark.xml.util

import java.math.BigDecimal
import java.sql.{Date, Timestamp}
import java.text.{NumberFormat, ParsePosition}
import java.text.NumberFormat
import java.time.{Instant, LocalDate, ZoneId}
import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder}
import java.util.Locale
Expand All @@ -26,8 +26,6 @@ import scala.util.control.Exception._
import org.apache.spark.sql.types._
import com.databricks.spark.xml.XmlOptions

import java.time.temporal.TemporalQueries

/**
* Utility functions for type casting
*/
Expand Down Expand Up @@ -63,8 +61,14 @@ private[xml] object TypeCast {
case _: BooleanType => parseXmlBoolean(datum)
case dt: DecimalType =>
Decimal(new BigDecimal(datum.replaceAll(",", "")), dt.precision, dt.scale)
case _: TimestampType => parseXmlTimestamp(datum, options)
case _: DateType => parseXmlDate(datum, options)
case _: TimestampType =>
parseXmlTimestamp(datum, options).getOrElse {
throw new IllegalArgumentException(s"cannot convert value $datum to Timestamp")
}
case _: DateType =>
parseXmlDate(datum, options).getOrElse {
throw new IllegalArgumentException(s"cannot convert value $datum to Date")
}
case _: StringType => datum
case _ => throw new IllegalArgumentException(s"Unsupported type: ${castType.typeName}")
}
Expand All @@ -85,17 +89,26 @@ private[xml] object TypeCast {
DateTimeFormatter.ISO_DATE
)

private def parseXmlDate(value: String, options: XmlOptions): Date = {
val formatters = options.dateFormat.map(DateTimeFormatter.ofPattern).
map(supportedXmlDateFormatters :+ _).getOrElse(supportedXmlDateFormatters)
formatters.foreach { format =>
private def parseXmlDate(value: String, options: XmlOptions): Option[Date] = {
// A little shortcut to avoid trying many formatters in the common case that
// the input isn't a date. All built-in formats will start with a digit.
if (value.nonEmpty && Character.isDigit(value.head)) {
supportedXmlDateFormatters.foreach { format =>
try {
return Some(Date.valueOf(LocalDate.parse(value, format)))
} catch {
case _: Exception => // continue
}
}
}
options.dateFormat.map(DateTimeFormatter.ofPattern).foreach { format =>
try {
return Date.valueOf(LocalDate.parse(value, format))
return Some(Date.valueOf(LocalDate.parse(value, format)))
} catch {
case _: Exception => // continue
}
}
throw new IllegalArgumentException(s"cannot convert value $value to Date")
None
}

private val supportedXmlTimestampFormatters = Seq(
Expand All @@ -115,12 +128,16 @@ private[xml] object TypeCast {
DateTimeFormatter.ISO_INSTANT
)

private def parseXmlTimestamp(value: String, options: XmlOptions): Timestamp = {
supportedXmlTimestampFormatters.foreach { format =>
try {
return Timestamp.from(Instant.from(format.parse(value)))
} catch {
case _: Exception => // continue
private def parseXmlTimestamp(value: String, options: XmlOptions): Option[Timestamp] = {
// A little shortcut to avoid trying many formatters in the common case that
// the input isn't a timestamp. All built-in formats will start with a digit.
if (value.nonEmpty && Character.isDigit(value.head)) {
supportedXmlTimestampFormatters.foreach { format =>
try {
return Some(Timestamp.from(Instant.from(format.parse(value))))
} catch {
case _: Exception => // continue
}
}
}
options.timestampFormat.foreach { formatString =>
Expand All @@ -138,12 +155,12 @@ private[xml] object TypeCast {
DateTimeFormatter.ofPattern(formatString).withZone(options.timezone.map(ZoneId.of).orNull)
}
try {
return Timestamp.from(Instant.from(format.parse(value)))
return Some(Timestamp.from(Instant.from(format.parse(value))))
} catch {
case _: Exception => // continue
}
}
throw new IllegalArgumentException(s"cannot convert value $value to Timestamp")
None
}


Expand Down Expand Up @@ -196,6 +213,12 @@ private[xml] object TypeCast {
} else {
value
}
// A little shortcut for the common case that the input isn't a double: reject it up front
// unless signSafeValue starts with a digit or a period, skipping the remaining checks.
if (signSafeValue.isEmpty ||
!(Character.isDigit(signSafeValue.head) || signSafeValue.head == '.')) {
return false
}
// Rule out strings ending in D or F, as they will parse as double but should be disallowed
if (value.nonEmpty && (value.last match {
case 'd' | 'D' | 'f' | 'F' => true
Expand All @@ -212,6 +235,11 @@ private[xml] object TypeCast {
} else {
value
}
// A little shortcut for the common case that the input isn't a number: reject it up front
// unless signSafeValue starts with a digit, avoiding a thrown-and-caught exception in toInt.
if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) {
return false
}
(allCatch opt signSafeValue.toInt).isDefined
}

Expand All @@ -221,25 +249,20 @@ private[xml] object TypeCast {
} else {
value
}
// A little shortcut for the common case that the input isn't a number: reject it up front
// unless signSafeValue starts with a digit, avoiding a thrown-and-caught exception in toLong.
if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) {
return false
}
(allCatch opt signSafeValue.toLong).isDefined
}

/**
 * Reports whether `value` can be read as a timestamp.
 *
 * Delegates to `parseXmlTimestamp`, which signals "not a timestamp" by returning
 * an empty `Option`, and checks that it produced a result. The value is parsed
 * exactly once; no exception is used to communicate a failed parse.
 *
 * @param value   text to test
 * @param options XML options passed through to `parseXmlTimestamp`
 * @return `true` if `parseXmlTimestamp` yields a non-empty result, `false` otherwise
 */
private[xml] def isTimestamp(value: String, options: XmlOptions): Boolean = {
  // The former try/catch wrapper only repeated this call and discarded its result,
  // so it has been dropped: the Option alone carries the answer.
  parseXmlTimestamp(value, options).nonEmpty
}

/**
 * Reports whether `value` can be read as a date.
 *
 * Delegates to `parseXmlDate`, which signals "not a date" by returning an empty
 * `Option`, and checks that it produced a result. The value is parsed exactly
 * once; no exception is used to communicate a failed parse.
 *
 * @param value   text to test
 * @param options XML options passed through to `parseXmlDate`
 * @return `true` if `parseXmlDate` yields a non-empty result, `false` otherwise
 */
private[xml] def isDate(value: String, options: XmlOptions): Boolean = {
  // The former try/catch wrapper only repeated this call and discarded its result,
  // so it has been dropped: the Option alone carries the answer.
  parseXmlDate(value, options).nonEmpty
}

private[xml] def signSafeToLong(value: String, options: XmlOptions): Long = {
Expand Down

0 comments on commit 994e357

Please sign in to comment.