From b556e596714c323bf3286e612fa4f2be0c3ab280 Mon Sep 17 00:00:00 2001
From: Addison Higham
Date: Wed, 6 Jan 2016 00:40:11 -0800
Subject: [PATCH] Add nullValue being respected when parsing CSVs

This change makes it so that we look for a user-specified nullValue
throughout CSV parsing. This allows for handling CSVs that use
something other than an empty string to represent nulls. It reuses the
same flag as CSV saving, `nullValue`. This change should be
non-breaking.

This also pushes this behavior into inferSchema so that inferred
schemas properly reflect the user-given null value.

Author: Addison Higham

Closes #224 from addisonj/master.
---
 README.md                                     | 11 ++++---
 .../com/databricks/spark/csv/CsvParser.scala  | 12 +++++--
 .../databricks/spark/csv/CsvRelation.scala    | 11 ++++---
 .../databricks/spark/csv/DefaultSource.scala  |  4 ++-
 .../spark/csv/util/InferSchema.scala          | 20 ++++++++----
 .../databricks/spark/csv/util/TypeCast.scala  | 11 +++++--
 src/test/resources/null_null_numbers.csv      |  4 +++
 src/test/resources/null_slashn_numbers.csv    |  4 +++
 .../com/databricks/spark/csv/CsvSuite.scala   | 32 +++++++++++++++++++
 .../spark/csv/util/InferSchemaSuite.scala     |  9 ++++++
 .../spark/csv/util/TypeCastSuite.scala        |  9 ++++++
 11 files changed, 107 insertions(+), 20 deletions(-)
 create mode 100644 src/test/resources/null_null_numbers.csv
 create mode 100644 src/test/resources/null_slashn_numbers.csv

diff --git a/README.md b/README.md
index e2aba30..e21752a 100755
--- a/README.md
+++ b/README.md
@@ -56,6 +56,7 @@ When reading files the API accepts several options:
 * `inferSchema`: automatically infers column types. It requires one extra pass over the data and is false by default
 * `comment`: skip lines beginning with this character. Default is `"#"`. Disable comments by setting this to `null`.
 * `codec`: compression codec to use when saving to file. Should be the fully qualified name of a class implementing `org.apache.hadoop.io.compress.CompressionCodec`. Defaults to no compression when a codec is not specified.
+* `nullValue`: specify a string that indicates a null value; any fields matching this string will be set as nulls in the DataFrame

 The package also support saving simple (non-nested) DataFrame. When saving you can specify the delimiter and whether we should generate a header row for the table. See following examples for more details.
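
For reference, the new option is passed like any other read option; a minimal sketch reusing the README's cars.csv example (the "N/A" token is an arbitrary illustrative choice, not a default):

    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)
    // Any field whose raw text is exactly "N/A" is read back as null.
    val df = sqlContext.load(
        "com.databricks.spark.csv",
        Map("path" -> "cars.csv", "header" -> "true", "nullValue" -> "N/A"))
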
@@ -109,7 +110,7 @@ import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerT

 val sqlContext = new SQLContext(sc)
 val customSchema = StructType(
-    StructField("year", IntegerType, true), 
+    StructField("year", IntegerType, true),
     StructField("make", StringType, true),
     StructField("model", StringType, true),
     StructField("comment", StringType, true),
@@ -155,7 +156,7 @@ import org.apache.spark.sql.SQLContext

 val sqlContext = new SQLContext(sc)
 val df = sqlContext.load(
-    "com.databricks.spark.csv", 
+    "com.databricks.spark.csv",
     Map("path" -> "cars.csv", "header" -> "true", "inferSchema" -> "true"))
 val selectedData = df.select("year", "model")
 selectedData.save("newcars.csv", "com.databricks.spark.csv")
@@ -168,14 +169,14 @@ import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerT

 val sqlContext = new SQLContext(sc)
 val customSchema = StructType(
-    StructField("year", IntegerType, true), 
+    StructField("year", IntegerType, true),
     StructField("make", StringType, true),
     StructField("model", StringType, true),
     StructField("comment", StringType, true),
     StructField("blank", StringType, true))

 val df = sqlContext.load(
-    "com.databricks.spark.csv", 
+    "com.databricks.spark.csv",
     schema = customSchema,
     Map("path" -> "cars.csv", "header" -> "true"))

@@ -210,7 +211,7 @@ import org.apache.spark.sql.types.*;
 SQLContext sqlContext = new SQLContext(sc);

 StructType customSchema = new StructType(new StructField[] {
-    new StructField("year", DataTypes.IntegerType, true, Metadata.empty()), 
+    new StructField("year", DataTypes.IntegerType, true, Metadata.empty()),
     new StructField("make", DataTypes.StringType, true, Metadata.empty()),
     new StructField("model", DataTypes.StringType, true, Metadata.empty()),
     new StructField("comment", DataTypes.StringType, true, Metadata.empty()),
diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
index 0a2f914..370ab3c 100644
--- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -40,6 +40,7 @@ class CsvParser extends Serializable {
   private var charset: String = TextFile.DEFAULT_CHARSET.name()
   private var inferSchema: Boolean = false
   private var codec: String = null
+  private var nullValue: String = ""

   def withUseHeader(flag: Boolean): CsvParser = {
     this.useHeader = flag
@@ -111,6 +112,11 @@ class CsvParser extends Serializable {
     this
   }

+  def withNullValue(nullValue: String): CsvParser = {
+    this.nullValue = nullValue
+    this
+  }
+
   /** Returns a Schema RDD for the given CSV path. */
   @throws[RuntimeException]
   def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
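
The builder DSL picks the value up through the new withNullValue setter; a minimal sketch, again with an arbitrary "N/A" token:

    val df = new CsvParser()
        .withUseHeader(true)
        .withNullValue("N/A") // fields equal to "N/A" become null
        .csvFile(sqlContext, "cars.csv")
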
@@ -129,7 +135,8 @@ class CsvParser extends Serializable {
       treatEmptyValuesAsNulls,
       schema,
       inferSchema,
-      codec)(sqlContext)
+      codec,
+      nullValue)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }

@@ -149,7 +156,8 @@ class CsvParser extends Serializable {
       treatEmptyValuesAsNulls,
       schema,
       inferSchema,
-      codec)(sqlContext)
+      codec,
+      nullValue)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }
 }
diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index dcab9c8..5a09176 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -46,7 +46,8 @@ case class CsvRelation protected[spark] (
     treatEmptyValuesAsNulls: Boolean,
     userSchema: StructType = null,
     inferCsvSchema: Boolean,
-    codec: String = null)(@transient val sqlContext: SQLContext)
+    codec: String = null,
+    nullValue: String = "")(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with PrunedScan with InsertableRelation {

   /**
@@ -116,7 +117,7 @@ case class CsvRelation protected[spark] (
       while (index < schemaFields.length) {
         val field = schemaFields(index)
         rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
-          treatEmptyValuesAsNulls)
+          treatEmptyValuesAsNulls, nullValue)
         index = index + 1
       }
       Some(Row.fromSeq(rowArray))
@@ -189,7 +190,9 @@ case class CsvRelation protected[spark] (
             indexSafeTokens(index),
             field.dataType,
             field.nullable,
-            treatEmptyValuesAsNulls)
+            treatEmptyValuesAsNulls,
+            nullValue
+          )
           subIndex = subIndex + 1
         }
         Some(Row.fromSeq(rowArray.take(requiredSize)))
@@ -235,7 +238,7 @@ case class CsvRelation protected[spark] (
       firstRow.zipWithIndex.map { case (value, index) => s"C$index"}
     }
     if (this.inferCsvSchema) {
-      InferSchema(tokenRdd(header), header)
+      InferSchema(tokenRdd(header), header, nullValue)
     } else {
       // By default fields are assumed to be StringType
       val schemaFields = header.map { fieldName =>
diff --git a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
index 13abf04..c2e1481 100755
--- a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
+++ b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -136,6 +136,7 @@ class DefaultSource
     } else {
       throw new Exception("Infer schema flag can be true or false")
     }
+    val nullValue = parameters.getOrElse("nullValue", "")

     val codec = parameters.getOrElse("codec", null)

@@ -154,7 +155,8 @@ class DefaultSource
       treatEmptyValuesAsNullsFlag,
       schema,
       inferSchemaFlag,
-      codec)(sqlContext)
+      codec,
+      nullValue)(sqlContext)
   }

   override def createRelation(
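
Because DefaultSource pulls nullValue out of the parameters map, the option is also reachable from SQL; a sketch using the OPTIONS clause documented in the README (table name and path are illustrative):

    sqlContext.sql(
      """CREATE TEMPORARY TABLE cars
        |USING com.databricks.spark.csv
        |OPTIONS (path "cars.csv", header "true", nullValue "N/A")""".stripMargin)
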
diff --git a/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala b/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
index dc9db6c..d9991a2 100644
--- a/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
+++ b/src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
@@ -31,10 +31,15 @@ private[csv] object InferSchema {
    *     2. Merge row types to find common type
    *     3. Replace any null types with string type
    */
-  def apply(tokenRdd: RDD[Array[String]], header: Array[String]): StructType = {
+  def apply(
+      tokenRdd: RDD[Array[String]],
+      header: Array[String],
+      nullValue: String = ""): StructType = {
     val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
-    val rootTypes: Array[DataType] = tokenRdd.aggregate(startType)(inferRowType, mergeRowTypes)
+    val rootTypes: Array[DataType] = tokenRdd.aggregate(startType)(
+      inferRowType(nullValue),
+      mergeRowTypes)

     val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
       StructField(thisHeader, rootType, nullable = true)
@@ -43,10 +48,11 @@ private[csv] object InferSchema {
     StructType(structFields)
   }

-  private def inferRowType(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
+  private def inferRowType(nullValue: String)
+      (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
     var i = 0
     while (i < math.min(rowSoFar.length, next.length)) {  // May have columns on right missing.
-      rowSoFar(i) = inferField(rowSoFar(i), next(i))
+      rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue)
       i+=1
     }
     rowSoFar
@@ -68,8 +74,10 @@ private[csv] object InferSchema {
    * Infer type of string field. Given known type Double, and a string "1", there is no
    * point checking if it is an Int, as the final type must be Double or higher.
    */
-  private[csv] def inferField(typeSoFar: DataType, field: String): DataType = {
-    if (field == null || field.isEmpty) {
+  private[csv] def inferField(typeSoFar: DataType,
+      field: String,
+      nullValue: String = ""): DataType = {
+    if (field == null || field.isEmpty || field == nullValue) {
       typeSoFar
     } else {
       typeSoFar match {
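
With nullValue threaded into inferField, a matching token is treated like an empty field: it leaves the accumulated type untouched instead of widening the column to StringType. A sketch of the folding, written as it would run inside the com.databricks.spark.csv.util package (inferField is private[csv]):

    // "\N" no longer forces the column to StringType:
    InferSchema.inferField(IntegerType, "\\N", "\\N")  // stays IntegerType
    // ordinary tokens still widen NullType as before:
    InferSchema.inferField(NullType, "12", "\\N")      // becomes IntegerType
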
carsAltFile = "src/test/resources/cars-alternative.csv" val carsUnbalancedQuotesFile = "src/test/resources/cars-unbalanced-quotes.csv" val nullNumbersFile = "src/test/resources/null-numbers.csv" + val nullNullNumbersFile = "src/test/resources/null_null_numbers.csv" + val nullSlashNNumbersFile = "src/test/resources/null_slashn_numbers.csv" val emptyFile = "src/test/resources/empty.csv" val ageFile = "src/test/resources/ages.csv" val escapeFile = "src/test/resources/escape.csv" @@ -572,6 +574,36 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll { assert(results(2).toSeq === Seq("", 24)) } + test("DSL test nullable fields with user defined null value of \"null\"") { + val results = new CsvParser() + .withSchema(StructType(List(StructField("name", StringType, false), + StructField("age", IntegerType, true)))) + .withUseHeader(true) + .withParserLib(parserLib) + .withNullValue("null") + .csvFile(sqlContext, nullNullNumbersFile) + .collect() + + assert(results.head.toSeq === Seq("alice", 35)) + assert(results(1).toSeq === Seq("bob", null)) + assert(results(2).toSeq === Seq("null", 24)) + } + + test("DSL test nullable fields with user defined null value of \"\\N\"") { + val results = new CsvParser() + .withSchema(StructType(List(StructField("name", StringType, false), + StructField("age", IntegerType, true)))) + .withUseHeader(true) + .withParserLib(parserLib) + .withNullValue("\\N") + .csvFile(sqlContext, nullSlashNNumbersFile) + .collect() + + assert(results.head.toSeq === Seq("alice", 35)) + assert(results(1).toSeq === Seq("bob", null)) + assert(results(2).toSeq === Seq("\\N", 24)) + } + test("Commented lines in CSV data") { val results: Array[Row] = new CsvParser() .withDelimiter(',') diff --git a/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala b/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala index 1d43c06..d713649 100644 --- a/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala @@ -15,6 +15,15 @@ class InferSchemaSuite extends FunSuite { assert(InferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType) } + test("Null fields are handled properly when a nullValue is specified") { + assert(InferSchema.inferField(NullType, "null", "null") == NullType) + assert(InferSchema.inferField(StringType, "null", "null") == StringType) + assert(InferSchema.inferField(LongType, "null", "null") == LongType) + assert(InferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType) + assert(InferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) + assert(InferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) + } + test("String fields types are inferred correctly from other types") { assert(InferSchema.inferField(LongType, "1.0") == DoubleType) assert(InferSchema.inferField(LongType, "test") == StringType) diff --git a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala index f2e93fc..b8e6e71 100644 --- a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala @@ -94,4 +94,13 @@ class TypeCastSuite extends FunSuite { assert(TypeCast.castTo("1,00", FloatType) == 1.0) assert(TypeCast.castTo("1,00", DoubleType) == 1.0) } + + test("Can handle mapping user specified nullValues") { + assert(TypeCast.castTo("null", StringType, true, false, "null") == null) + 
assert(TypeCast.castTo("\\N", ByteType, true, false, "\\N") == null) + assert(TypeCast.castTo("", ShortType, true, false) == null) + assert(TypeCast.castTo("null", StringType, true, true, "null") == null) + assert(TypeCast.castTo("", StringType, true, false, "") == "") + assert(TypeCast.castTo("", StringType, true, true, "") == null) + } }