[SPARK-16512] Adding an insertNullOnErrors option. #400

Open · wants to merge 1 commit into master
2 changes: 1 addition & 1 deletion README.md
@@ -59,7 +59,7 @@ When reading files the API accepts several options:
* `comment`: skip lines beginning with this character. Default is `"#"`. Disable comments by setting this to `null`.
* `nullValue`: specifies a string that indicates a null value; any fields matching this string will be set to null in the DataFrame
* `dateFormat`: specifies a string that indicates the date format to use when reading dates or timestamps. Custom date formats follow the formats at [`java.text.SimpleDateFormat`](https://docs.oracle.com/javase/7/docs/api/java/text/SimpleDateFormat.html). This applies to both `DateType` and `TimestampType`. By default it is `null`, which means times and dates are parsed via `java.sql.Timestamp.valueOf()` and `java.sql.Date.valueOf()`.
* `insertNullOnErrors`: defaults to `false`. When `true`, a parse exception encountered while reading the CSV, such as a malformed number in a numeric column, is treated as a null value rather than failing or dropping the row. See the usage sketch after this list.
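A minimal usage sketch of the new option through the data source API (assumes this PR is applied; the schema, path, and `sc` are illustrative):

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types._

val sqlContext = new SQLContext(sc) // assumes an existing SparkContext `sc`

val schema = StructType(Seq(
  StructField("year", IntegerType, nullable = true),
  StructField("make", StringType, nullable = true)))

// A field that fails to cast to its schema type (e.g. "new" in the
// integer year column) becomes null instead of aborting the scan.
val df = sqlContext.read
  .format("com.databricks.spark.csv")
  .schema(schema)
  .option("header", "true")
  .option("insertNullOnErrors", "true")
  .load("cars.csv")
```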
The package also supports saving simple (non-nested) DataFrames. When writing files the API accepts several options:
* `path`: location of files.
* `header`: when set to true, the header (from the schema in the DataFrame) will be written at the first line.
17 changes: 15 additions & 2 deletions src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -43,6 +43,7 @@ class CsvParser extends Serializable {
private var nullValue: String = ""
private var dateFormat: String = null
private var maxCharsPerCol: Int = 100000
private var treatParseExceptionAsNull: Boolean = false

def withUseHeader(flag: Boolean): CsvParser = {
this.useHeader = flag
@@ -124,6 +125,16 @@ class CsvParser extends Serializable {
this
}

/**
* If set to true, dirty data, for example a string in a numeric column
* or a malformed date, will not cause a failure.
* Instead, the offending value will be null in the resulting data.
*/
def withInsertNullOnError(flag: Boolean): CsvParser = {
this.treatParseExceptionAsNull = flag
this
}

def withMaxCharsPerCol(maxCharsPerCol: Int): CsvParser = {
this.maxCharsPerCol = maxCharsPerCol
this
@@ -150,7 +161,8 @@
codec,
nullValue,
dateFormat,
maxCharsPerCol)(sqlContext)
maxCharsPerCol,
treatParseExceptionAsNull)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}

@@ -173,7 +185,8 @@
codec,
nullValue,
dateFormat,
maxCharsPerCol)(sqlContext)
maxCharsPerCol,
treatParseExceptionAsNull)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}
}
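A hedged sketch of the resulting builder usage (assumes the PR is applied; `carsSchema` and the path are illustrative, and the flag only matters once columns have non-string types):

```scala
import com.databricks.spark.csv.CsvParser

// Fields that fail to cast to their schema type become null
// instead of failing the whole read.
val df = new CsvParser()
  .withSchema(carsSchema)
  .withUseHeader(true)
  .withInsertNullOnError(true)
  .csvFile(sqlContext, "src/test/resources/cars_dirty.csv")
```

The "Insert Null On Error with schema" test later in this diff exercises exactly this chain.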
8 changes: 5 additions & 3 deletions src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -50,7 +50,8 @@ case class CsvRelation protected[spark] (
codec: String = null,
nullValue: String = "",
dateFormat: String = null,
maxCharsPerCol: Int = 100000)(@transient val sqlContext: SQLContext)
maxCharsPerCol: Int = 100000,
insertNullOnErrors: Boolean)(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with PrunedScan with InsertableRelation {

// Share date format object as it is expensive to parse date pattern.
@@ -119,7 +120,7 @@ case class CsvRelation protected[spark] (
while (index < schemaFields.length) {
val field = schemaFields(index)
rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
treatEmptyValuesAsNulls, nullValue, simpleDateFormatter)
treatEmptyValuesAsNulls, nullValue, simpleDateFormatter, insertNullOnErrors)
index = index + 1
}
Some(Row.fromSeq(rowArray))
@@ -199,7 +200,8 @@ case class CsvRelation protected[spark] (
field.nullable,
treatEmptyValuesAsNulls,
nullValue,
simpleDateFormatter
simpleDateFormatter,
insertNullOnErrors
)
subIndex = subIndex + 1
}
11 changes: 10 additions & 1 deletion src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -128,6 +128,14 @@ class DefaultSource
val charset = parameters.getOrElse("charset", TextFile.DEFAULT_CHARSET.name())
// TODO validate charset?

val treatParseExceptionAsNull = parameters.getOrElse("insertNullOnErrors", "false")
val insertNullOnErrorFlag = if (treatParseExceptionAsNull == "false") {
false
} else if (treatParseExceptionAsNull == "true") {
true
} else {
throw new Exception("insertNullOnErrors flag can be true or false")
}
val inferSchema = parameters.getOrElse("inferSchema", "false")
val inferSchemaFlag = if (inferSchema == "false") {
false
@@ -168,7 +176,8 @@
codec,
nullValue,
dateFormat,
maxCharsPerCol)(sqlContext)
maxCharsPerCol,
insertNullOnErrorFlag)(sqlContext)
}

override def createRelation(
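Since `DefaultSource` pulls the flag out of the `parameters` map, it should also be reachable from the pure-SQL path. A hedged sketch (table name and path are illustrative; assumes the PR is applied):

```scala
// Every OPTIONS key lands in the `parameters` map that createRelation
// receives, so insertNullOnErrors can be set from SQL as well.
sqlContext.sql(
  """CREATE TEMPORARY TABLE cars
    |USING com.databricks.spark.csv
    |OPTIONS (path "cars_dirty.csv", header "true", insertNullOnErrors "true")
  """.stripMargin)
```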
5 changes: 3 additions & 2 deletions src/main/scala/com/databricks/spark/csv/package.scala
@@ -60,7 +60,8 @@ package object csv {
ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
treatEmptyValuesAsNulls = false,
inferCsvSchema = inferSchema)(sqlContext)
inferCsvSchema = inferSchema,
insertNullOnErrors = true)(sqlContext)
sqlContext.baseRelationToDataFrame(csvRelation)
}

@@ -85,7 +86,7 @@ package object csv {
ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
treatEmptyValuesAsNulls = false,
inferCsvSchema = inferSchema)(sqlContext)
inferCsvSchema = inferSchema,
insertNullOnErrors = false)(sqlContext)
sqlContext.baseRelationToDataFrame(csvRelation)
}
}
51 changes: 30 additions & 21 deletions src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -21,7 +21,6 @@ import java.text.{SimpleDateFormat, NumberFormat}
import java.util.Locale

import org.apache.spark.sql.types._

import scala.util.Try

/**
@@ -45,31 +44,38 @@ object TypeCast {
nullable: Boolean = true,
treatEmptyValuesAsNulls: Boolean = false,
nullValue: String = "",
dateFormatter: SimpleDateFormat = null): Any = {
dateFormatter: SimpleDateFormat = null,
insertNullOnError: Boolean = false): Any = {
if (datum == nullValue &&
nullable ||
(treatEmptyValuesAsNulls && datum == "")){
null
} else {
castType match {
case _: ByteType => datum.toByte
case _: ShortType => datum.toShort
case _: IntegerType => datum.toInt
case _: LongType => datum.toLong
case _: FloatType => Try(datum.toFloat)
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
case _: DoubleType => Try(datum.toDouble)
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
case _: BooleanType => datum.toBoolean
case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
case _: TimestampType if dateFormatter != null =>
new Timestamp(dateFormatter.parse(datum).getTime)
case _: TimestampType => Timestamp.valueOf(datum)
case _: DateType if dateFormatter != null =>
new Date(dateFormatter.parse(datum).getTime)
case _: DateType => Date.valueOf(datum)
case _: StringType => datum
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
try {
castType match {
case _: ByteType => datum.toByte
case _: ShortType => datum.toShort
case _: IntegerType => datum.toInt
case _: LongType => datum.toLong
case _: FloatType => Try(datum.toFloat)
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
case _: DoubleType => Try(datum.toDouble)
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
case _: BooleanType => datum.toBoolean
case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
case _: TimestampType if dateFormatter != null =>
new Timestamp(dateFormatter.parse(datum).getTime)
case _: TimestampType => Timestamp.valueOf(datum)
case _: DateType if dateFormatter != null =>
new Date(dateFormatter.parse(datum).getTime)
case _: DateType => Date.valueOf(datum)
case _: StringType => datum
// use the dedicated exception type so the catch below never nulls it out
case _ => throw new UnsupportedTypeException(s"Unsupported type: ${castType.typeName}")
}
}
catch {
case e: UnsupportedTypeException => throw e
case e: Throwable => if (insertNullOnError && nullable) null else throw e
}
}
}
@@ -102,3 +108,6 @@
}
}
}

class UnsupportedTypeException(message: String = null, cause: Throwable = null)
extends RuntimeException(message, cause)
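A hedged, REPL-style sketch of the resulting `castTo` semantics (assumes the PR is applied; argument names follow the signature above):

```scala
import com.databricks.spark.csv.util.TypeCast
import org.apache.spark.sql.types.IntegerType

// Nullable column with insertNullOnError: a bad token becomes null.
TypeCast.castTo("not-a-number", IntegerType, nullable = true,
  treatEmptyValuesAsNulls = false, nullValue = "",
  dateFormatter = null, insertNullOnError = true) // => null

// Non-nullable column: the parse exception still propagates, so the
// schema's not-null guarantee is never silently violated.
// TypeCast.castTo("not-a-number", IntegerType, nullable = false,
//   insertNullOnError = true) // => throws NumberFormatException
```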
5 changes: 5 additions & 0 deletions src/test/resources/cars_dirty.csv
@@ -0,0 +1,5 @@
year,make,model,price,comment,blank
2012,Tesla,S"80,000.65"
2013.5,Ford,E350,35,000,"Go get one now they are going fast"
2015,,Volt,5,000
new,"",Volt,5000.00
34 changes: 34 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -32,6 +32,7 @@ import org.scalatest.Matchers._
abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
val carsFile = "src/test/resources/cars.csv"
val carsMalformedFile = "src/test/resources/cars-malformed.csv"
val carsDirtyFile = "src/test/resources/cars_dirty.csv"
val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
val carsTsvFile = "src/test/resources/cars.tsv"
val carsAltFile = "src/test/resources/cars-alternative.csv"
@@ -69,6 +70,13 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
}
}

test("Dirty Data CSV"){
val results = sqlContext.csvFile(
carsDirtyTsvFile, parserLib = parserLib
).collect()
assert(results.length == 4)
}

test("DSL test") {
val results = sqlContext
.csvFile(carsFile, parserLib = parserLib)
@@ -216,6 +224,32 @@
assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt"))
}

test("Insert Null On Error with schema"){
val carsSchema = new StructType(
Array(
StructField("year", IntegerType, nullable = true),
StructField("make", StringType, nullable = true),
StructField("model", StringType, nullable = true),
StructField("price", DoubleType, nullable = true),
StructField("comment", StringType, nullable = true),
StructField("blank", IntegerType, nullable = true)
)
)

val results = new CsvParser()
.withSchema(carsSchema)
.withUseHeader(true)
.withDelimiter(',')
.withQuoteChar('\"')
.withInsertNullOnError(true)
.csvFile(sqlContext, carsDirtyFile)
.select("year", "make")
.collect()

assert(results(0).toSeq == Seq(2012, "Tesla"))
assert(results(1).toSeq == Seq(null, "Ford"))
assert(results(2).toSeq == Seq(2015, null))
assert(results(3).toSeq == Seq(null, null))
}

test("DSL test roundtrip nulls") {
// Create temp directory
TestUtils.deleteRecursively(new File(tempEmptyDir))
48 changes: 47 additions & 1 deletion src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
@@ -36,7 +36,53 @@ class TypeCastSuite extends FunSuite {
}
}

test("Can parse escaped characters") {
test("Parse exception is caught correctly") {

def testParseException(castType: DataType, badValues: Seq[String]): Unit = {
badValues.foreach { testValue =>
assert(TypeCast.castTo(testValue, castType, true, false, "", null, true) == null)
// if not nullable, the value is not nulled out; the parse exception propagates
intercept[Exception] {
TypeCast.castTo(testValue, castType, false, false, "", null, true)
}
}
}

assert(TypeCast.castTo("10", ByteType, true, false, "", null, true) == 10)
testParseException(ByteType, Seq("10.5", "s", "true"))

assert(TypeCast.castTo("10", ShortType, true, false, "", null, true) == 10)
testParseException(ShortType, Seq("s", "true"))

assert(TypeCast.castTo("10", IntegerType, true, false, "", null, true) == 10)
testParseException(IntegerType, Seq("10.5", "s", "true"))

assert(TypeCast.castTo("10", LongType, true, false, "", null, true) == 10)
testParseException(LongType, Seq("10.5", "s", "true"))

assert(TypeCast.castTo("1.00", FloatType, true, false, "", null, true) == 1.0)
testParseException(FloatType, Seq("s", "true"))

assert(TypeCast.castTo("1.00", DoubleType, true, false, "", null, true) == 1.0)
testParseException(DoubleType, Seq("s", "true"))

assert(TypeCast.castTo("true", BooleanType, true, false, "", null, true) == true)
testParseException(BooleanType, Seq("s", "5"))

val timestamp = "2015-01-01 00:00:00"
assert(TypeCast.castTo(timestamp, TimestampType, true, false, "", null, true)
== Timestamp.valueOf(timestamp))
testParseException(TimestampType, Seq("5", "string"))

assert(TypeCast.castTo("2015-01-01", DateType, true, false, "", null, true)
== Date.valueOf("2015-01-01"))
testParseException(DateType, Seq("5", "string", timestamp))

}

test("Can parse escaped characters") {
assert(TypeCast.toChar("""\t""") === '\t')
assert(TypeCast.toChar("""\r""") === '\r')
assert(TypeCast.toChar("""\b""") === '\b')