diff --git a/README.md b/README.md
index 3dffa2f..b6c9e32 100755
--- a/README.md
+++ b/README.md
@@ -59,7 +59,8 @@ When reading files the API accepts several options:
 * `comment`: skip lines beginning with this character. Default is `"#"`. Disable comments by setting this to `null`.
 * `nullValue`: specifies a string that indicates a null value, any fields matching this string will be set as nulls in the DataFrame
 * `dateFormat`: specifies a string that indicates the date format to use when reading dates or timestamps. Custom date formats follow the formats at [`java.text.SimpleDateFormat`](https://docs.oracle.com/javase/7/docs/api/java/text/SimpleDateFormat.html). This applies to both `DateType` and `TimestampType`. By default, it is `null` which means trying to parse times and date by `java.sql.Timestamp.valueOf()` and `java.sql.Date.valueOf()`.
+* `insertNullOnErrors`: defaults to `false`. When `true`, a parse exception encountered while reading the CSV, such as a malformed number in a numeric column, produces a `null` in that field instead of failing or dropping the row.
 
 The package also supports saving simple (non-nested) DataFrame. When writing files the API accepts several options:
 * `path`: location of files.
 * `header`: when set to true, the header (from the schema in the DataFrame) will be written at the first line.
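A minimal usage sketch for the new reader option, not part of the patch (the file path and schema are illustrative; Spark 1.4+ reader API). Note the option only matters for typed columns: with the default all-string schema nothing can fail to parse.

    import org.apache.spark.sql.types._

    // Typed columns are required for a cast to be able to fail.
    val customSchema = StructType(Seq(
      StructField("year", IntegerType, nullable = true),
      StructField("make", StringType, nullable = true)))

    val cars = sqlContext.read
      .format("com.databricks.spark.csv")
      .schema(customSchema)
      .option("header", "true")
      .option("insertNullOnErrors", "true") // dirty cells become nulls instead of errors
      .load("cars.csv")
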
diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
index 8ed496b..551d033 100644
--- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -43,6 +43,7 @@ class CsvParser extends Serializable {
   private var nullValue: String = ""
   private var dateFormat: String = null
   private var maxCharsPerCol: Int = 100000
+  private var treatParseExceptionAsNull: Boolean = false
 
   def withUseHeader(flag: Boolean): CsvParser = {
     this.useHeader = flag
@@ -124,6 +125,16 @@ class CsvParser extends Serializable {
     this
   }
 
+  /**
+   * If set to true, dirty data, for example a string in a numeric column
+   * or a malformed date, will not cause a failure.
+   * Instead, that value becomes null in the resulting DataFrame.
+   */
+  def withInsertNullOnError(flag: Boolean): CsvParser = {
+    this.treatParseExceptionAsNull = flag
+    this
+  }
+
   def withMaxCharsPerCol(maxCharsPerCol: Int): CsvParser = {
     this.maxCharsPerCol = maxCharsPerCol
     this
@@ -150,7 +161,8 @@ class CsvParser extends Serializable {
       codec,
       nullValue,
       dateFormat,
-      maxCharsPerCol)(sqlContext)
+      maxCharsPerCol,
+      treatParseExceptionAsNull)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }
 
@@ -173,7 +185,8 @@ class CsvParser extends Serializable {
       codec,
       nullValue,
       dateFormat,
-      maxCharsPerCol)(sqlContext)
+      maxCharsPerCol,
+      treatParseExceptionAsNull)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }
 }
diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index 7efc6bd..3179211 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -50,7 +50,8 @@ case class CsvRelation protected[spark] (
     codec: String = null,
     nullValue: String = "",
     dateFormat: String = null,
-    maxCharsPerCol: Int = 100000)(@transient val sqlContext: SQLContext)
+    maxCharsPerCol: Int = 100000,
+    insertNullOnErrors: Boolean)(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with PrunedScan with InsertableRelation {
 
   // Share date format object as it is expensive to parse date pattern.
@@ -119,7 +120,7 @@ case class CsvRelation protected[spark] (
       while (index < schemaFields.length) {
         val field = schemaFields(index)
         rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
-          treatEmptyValuesAsNulls, nullValue, simpleDateFormatter)
+          treatEmptyValuesAsNulls, nullValue, simpleDateFormatter, insertNullOnErrors)
         index = index + 1
       }
       Some(Row.fromSeq(rowArray))
@@ -199,7 +200,8 @@ case class CsvRelation protected[spark] (
             field.nullable,
             treatEmptyValuesAsNulls,
             nullValue,
-            simpleDateFormatter
+            simpleDateFormatter,
+            insertNullOnErrors
           )
           subIndex = subIndex + 1
         }
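The same switch on the programmatic builder, as exercised by the "Insert Null On Error with schema" test further down; a sketch assuming a `carsSchema` like the one defined there:

    val df = new CsvParser()
      .withSchema(carsSchema)
      .withUseHeader(true)
      .withInsertNullOnError(true) // parse failures in nullable columns yield nulls
      .csvFile(sqlContext, "src/test/resources/cars_dirty.csv")
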
diff --git a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
index 04c8ef4..b8032aa 100755
--- a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
+++ b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -128,6 +128,14 @@ class DefaultSource
     val charset = parameters.getOrElse("charset", TextFile.DEFAULT_CHARSET.name())
     // TODO validate charset?
+    val treatParseExceptionAsNull = parameters.getOrElse("insertNullOnErrors", "false")
+    val insertNullOnErrorFlag = if (treatParseExceptionAsNull == "false") {
+      false
+    } else if (treatParseExceptionAsNull == "true") {
+      true
+    } else {
+      throw new Exception("insertNullOnErrors flag can be true or false")
+    }
     val inferSchema = parameters.getOrElse("inferSchema", "false")
     val inferSchemaFlag = if (inferSchema == "false") {
       false
@@ -168,7 +176,8 @@ class DefaultSource
       codec,
       nullValue,
       dateFormat,
-      maxCharsPerCol)(sqlContext)
+      maxCharsPerCol,
+      insertNullOnErrorFlag)(sqlContext)
   }
 
   override def createRelation(
diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala
index 2cce4af..edeb05e 100755
--- a/src/main/scala/com/databricks/spark/csv/package.scala
+++ b/src/main/scala/com/databricks/spark/csv/package.scala
@@ -60,7 +60,8 @@ package object csv {
       ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
       treatEmptyValuesAsNulls = false,
-      inferCsvSchema = inferSchema)(sqlContext)
+      inferCsvSchema = inferSchema,
+      insertNullOnErrors = false)(sqlContext)
     sqlContext.baseRelationToDataFrame(csvRelation)
   }
 
@@ -85,7 +86,8 @@ package object csv {
       ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
       treatEmptyValuesAsNulls = false,
-      inferCsvSchema = inferSchema)(sqlContext)
+      inferCsvSchema = inferSchema,
+      insertNullOnErrors = false)(sqlContext)
     sqlContext.baseRelationToDataFrame(csvRelation)
   }
 }
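At the cell level the flag changes `TypeCast.castTo` as follows; expected behaviour sketched with named arguments from the new signature below:

    import com.databricks.spark.csv.util.TypeCast
    import org.apache.spark.sql.types.IntegerType

    TypeCast.castTo("2012", IntegerType, nullable = true, insertNullOnError = true) // 2012
    TypeCast.castTo("new", IntegerType, nullable = true, insertNullOnError = true)  // null
    // A non-nullable column still fails fast:
    TypeCast.castTo("new", IntegerType, nullable = false, insertNullOnError = true) // throws NumberFormatException
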
diff --git a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
index 8c3474b..2f6e4c4 100644
--- a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
+++ b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -21,7 +21,8 @@ import java.text.{SimpleDateFormat, NumberFormat}
 import java.util.Locale
 
 import org.apache.spark.sql.types._
 
 import scala.util.Try
+import scala.util.control.NonFatal
 
 /**
@@ -45,31 +46,39 @@ object TypeCast {
       nullable: Boolean = true,
       treatEmptyValuesAsNulls: Boolean = false,
       nullValue: String = "",
-      dateFormatter: SimpleDateFormat = null): Any = {
+      dateFormatter: SimpleDateFormat = null,
+      insertNullOnError: Boolean = false): Any = {
     if (datum == nullValue && nullable || (treatEmptyValuesAsNulls && datum == "")){
       null
     } else {
-      castType match {
-        case _: ByteType => datum.toByte
-        case _: ShortType => datum.toShort
-        case _: IntegerType => datum.toInt
-        case _: LongType => datum.toLong
-        case _: FloatType => Try(datum.toFloat)
-          .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
-        case _: DoubleType => Try(datum.toDouble)
-          .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
-        case _: BooleanType => datum.toBoolean
-        case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
-        case _: TimestampType if dateFormatter != null =>
-          new Timestamp(dateFormatter.parse(datum).getTime)
-        case _: TimestampType => Timestamp.valueOf(datum)
-        case _: DateType if dateFormatter != null =>
-          new Date(dateFormatter.parse(datum).getTime)
-        case _: DateType => Date.valueOf(datum)
-        case _: StringType => datum
-        case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
-      }
+      try {
+        castType match {
+          case _: ByteType => datum.toByte
+          case _: ShortType => datum.toShort
+          case _: IntegerType => datum.toInt
+          case _: LongType => datum.toLong
+          case _: FloatType => Try(datum.toFloat)
+            .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
+          case _: DoubleType => Try(datum.toDouble)
+            .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
+          case _: BooleanType => datum.toBoolean
+          case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
+          case _: TimestampType if dateFormatter != null =>
+            new Timestamp(dateFormatter.parse(datum).getTime)
+          case _: TimestampType => Timestamp.valueOf(datum)
+          case _: DateType if dateFormatter != null =>
+            new Date(dateFormatter.parse(datum).getTime)
+          case _: DateType => Date.valueOf(datum)
+          case _: StringType => datum
+          case _ => throw new UnsupportedTypeException(s"Unsupported type: ${castType.typeName}")
+        }
+      } catch {
+        // Only data errors may become nulls, and only in nullable columns;
+        // an unsupported schema type is always rethrown.
+        case e: UnsupportedTypeException => throw e
+        case NonFatal(e) => if (insertNullOnError && nullable) null else throw e
+      }
     }
   }
 
   /**
@@ -102,3 +111,6 @@
     }
   }
 }
+
+class UnsupportedTypeException(message: String = null, cause: Throwable = null)
+  extends RuntimeException(message, cause)
diff --git a/src/test/resources/cars_dirty.csv b/src/test/resources/cars_dirty.csv
new file mode 100644
index 0000000..97f5518
--- /dev/null
+++ b/src/test/resources/cars_dirty.csv
@@ -0,0 +1,5 @@
+year,make,model,price,comment,blank
+2012,Tesla,S"80,000.65"
+2013.5,Ford,E350,35,000,"Go get one now they are going fast"
+2015,,Volt,5,000
+new,"",Volt,5000.00
\ No newline at end of file
diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
index 6058fe9..adcdb00 100755
--- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -32,6 +32,7 @@ import org.scalatest.Matchers._
 abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
   val carsFile = "src/test/resources/cars.csv"
   val carsMalformedFile = "src/test/resources/cars-malformed.csv"
+  val carsDirtyFile = "src/test/resources/cars_dirty.csv"
   val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
   val carsTsvFile = "src/test/resources/cars.tsv"
   val carsAltFile = "src/test/resources/cars-alternative.csv"
@@ -69,6 +70,15 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("Dirty Data CSV") {
+    // Without an explicit schema every column is read as a string,
+    // so this file loads even though several cells are dirty.
+    val results = sqlContext.csvFile(
+      carsDirtyFile, parserLib = parserLib
+    ).collect()
+    assert(results.length == 4)
+  }
+
   test("DSL test") {
     val results = sqlContext
       .csvFile(carsFile, parserLib = parserLib)
@@ -216,6 +226,34 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt"))
   }
 
+  test("Insert Null On Error with schema") {
+    val carsSchema = new StructType(
+      Array(
+        StructField("year", IntegerType, nullable = true),
+        StructField("make", StringType, nullable = true),
+        StructField("model", StringType, nullable = true),
+        StructField("price", DoubleType, nullable = true),
+        StructField("comment", StringType, nullable = true),
+        StructField("blank", IntegerType, nullable = true)
+      )
+    )
+
+    val results = new CsvParser()
+      .withSchema(carsSchema)
+      .withUseHeader(true)
+      .withDelimiter(',')
+      .withQuoteChar('\"')
+      .withInsertNullOnError(true)
+      .csvFile(sqlContext, carsDirtyFile)
+      .select("year", "make")
+      .collect()
+
+    assert(results(0).toSeq == Seq(2012, "Tesla"))
+    assert(results(1).toSeq == Seq(null, "Ford"))
+    assert(results(2).toSeq == Seq(2015, null))
+    assert(results(3).toSeq == Seq(null, null))
+  }
+
   test("DSL test roundtrip nulls") {
     // Create temp directory
     TestUtils.deleteRecursively(new File(tempEmptyDir))
diff --git a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
index 9f3df23..ef55c23 100644
--- a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
@@ -36,7 +36,49 @@ class TypeCastSuite extends FunSuite {
     }
   }
 
+  test("Parse exception is caught correctly") {
+    def testParseException(castType: DataType, badValues: Seq[String]): Unit = {
+      badValues.foreach { testValue =>
+        assert(TypeCast.castTo(testValue, castType, true, false, "", null, true) == null)
+        // A non-nullable column must still fail on a bad value.
+        intercept[Exception] {
+          TypeCast.castTo(testValue, castType, false, false, "", null, true)
+        }
+      }
+    }
+
+    assert(TypeCast.castTo("10", ByteType, true, false, "", null, true) == 10)
+    testParseException(ByteType, Seq("10.5", "s", "true"))
+
+    assert(TypeCast.castTo("10", ShortType, true, false, "", null, true) == 10)
+    testParseException(ShortType, Seq("s", "true"))
+
+    assert(TypeCast.castTo("10", IntegerType, true, false, "", null, true) == 10)
+    testParseException(IntegerType, Seq("10.5", "s", "true"))
+
+    assert(TypeCast.castTo("10", LongType, true, false, "", null, true) == 10)
+    testParseException(LongType, Seq("10.5", "s", "true"))
+
+    assert(TypeCast.castTo("1.00", FloatType, true, false, "", null, true) == 1.0)
+    testParseException(FloatType, Seq("s", "true"))
+
+    assert(TypeCast.castTo("1.00", DoubleType, true, false, "", null, true) == 1.0)
+    testParseException(DoubleType, Seq("s", "true"))
+
+    assert(TypeCast.castTo("true", BooleanType, true, false, "", null, true) == true)
+    testParseException(BooleanType, Seq("s", "5"))
+
+    val timestamp = "2015-01-01 00:00:00"
+    assert(TypeCast.castTo(timestamp, TimestampType, true, false, "", null, true)
+      == Timestamp.valueOf(timestamp))
+    testParseException(TimestampType, Seq("5", "string"))
+
+    assert(TypeCast.castTo("2015-01-01", DateType, true, false, "", null, true)
+      == Date.valueOf("2015-01-01"))
+    testParseException(DateType, Seq("5", "string", timestamp))
+  }
+
   test("Can parse escaped characters") {
     assert(TypeCast.toChar("""\t""") === '\t')
     assert(TypeCast.toChar("""\r""") === '\r')
     assert(TypeCast.toChar("""\b""") === '\b')
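For completeness, the option should also be reachable through the pure-SQL path that `DefaultSource` serves, since it is read from the `parameters` map; a hypothetical sketch (table name, column list, and path are illustrative):

    sqlContext.sql(
      """CREATE TEMPORARY TABLE cars (year int, make string, model string)
        |USING com.databricks.spark.csv
        |OPTIONS (path "src/test/resources/cars_dirty.csv", header "true", insertNullOnErrors "true")
      """.stripMargin)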