Add nullValue being respected when parsing CSVs
This change makes it so that we look for a user-specified nullValue
throughout CSV parsing. This allows handling of CSVs that use
something other than an empty string to represent nulls.

It reuses the same flag as CSV saving, `nullValue`. This change should
be non-breaking.

This also pushes this behavior into inferSchema so that inferred schemas
will properly reflect the user-specified null value.
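
For example, reading a CSV whose nulls are written as `\N` (a minimal usage sketch in the style of the README's load-with-options example; `cars.csv` and the `\N` token are illustrative):

import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
// Tokens equal to "\N" come back as null, and schema inference ignores them.
val df = sqlContext.load(
  "com.databricks.spark.csv",
  Map("path" -> "cars.csv", "header" -> "true", "inferSchema" -> "true", "nullValue" -> "\\N"))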

Author: Addison Higham <[email protected]>

Closes #224 from addisonj/master.
Addison Higham authored and falaki committed Jan 6, 2016
1 parent 44964a2 commit b556e59
Showing 11 changed files with 107 additions and 20 deletions.
11 changes: 6 additions & 5 deletions README.md
@@ -56,6 +56,7 @@ When reading files the API accepts several options:
* `inferSchema`: automatically infers column types. It requires one extra pass over the data and is false by default
* `comment`: skip lines beginning with this character. Default is `"#"`. Disable comments by setting this to `null`.
* `codec`: compression codec to use when saving to file. Should be the fully qualified name of a class implementing `org.apache.hadoop.io.compress.CompressionCodec`. Defaults to no compression when a codec is not specified.
* `nullValue`: specify a string that indicates a null value; any fields matching this string will be set as nulls in the DataFrame

The package also supports saving a simple (non-nested) DataFrame. When saving, you can specify the delimiter and whether we should generate a header row for the table. See the following examples for more details.

@@ -109,7 +110,7 @@ import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerT

val sqlContext = new SQLContext(sc)
val customSchema = StructType(
StructField("year", IntegerType, true),
StructField("year", IntegerType, true),
StructField("make", StringType, true),
StructField("model", StringType, true),
StructField("comment", StringType, true),
@@ -155,7 +156,7 @@ import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
val df = sqlContext.load(
"com.databricks.spark.csv",
"com.databricks.spark.csv",
Map("path" -> "cars.csv", "header" -> "true", "inferSchema" -> "true"))
val selectedData = df.select("year", "model")
selectedData.save("newcars.csv", "com.databricks.spark.csv")
@@ -168,14 +169,14 @@ import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerT

val sqlContext = new SQLContext(sc)
val customSchema = StructType(
StructField("year", IntegerType, true),
StructField("year", IntegerType, true),
StructField("make", StringType, true),
StructField("model", StringType, true),
StructField("comment", StringType, true),
StructField("blank", StringType, true))

val df = sqlContext.load(
"com.databricks.spark.csv",
"com.databricks.spark.csv",
schema = customSchema,
Map("path" -> "cars.csv", "header" -> "true"))

@@ -210,7 +211,7 @@ import org.apache.spark.sql.types.*;

SQLContext sqlContext = new SQLContext(sc);
StructType customSchema = new StructType(new StructField[] {
new StructField("year", DataTypes.IntegerType, true, Metadata.empty()),
new StructField("year", DataTypes.IntegerType, true, Metadata.empty()),
new StructField("make", DataTypes.StringType, true, Metadata.empty()),
new StructField("model", DataTypes.StringType, true, Metadata.empty()),
new StructField("comment", DataTypes.StringType, true, Metadata.empty()),
12 changes: 10 additions & 2 deletions src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -40,6 +40,7 @@ class CsvParser extends Serializable {
private var charset: String = TextFile.DEFAULT_CHARSET.name()
private var inferSchema: Boolean = false
private var codec: String = null
private var nullValue: String = ""

def withUseHeader(flag: Boolean): CsvParser = {
this.useHeader = flag
@@ -111,6 +112,11 @@ class CsvParser extends Serializable {
this
}

def withNullValue(nullValue: String): CsvParser = {
this.nullValue = nullValue
this
}

/** Returns a Schema RDD for the given CSV path. */
@throws[RuntimeException]
def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
@@ -129,7 +135,8 @@
treatEmptyValuesAsNulls,
schema,
inferSchema,
codec)(sqlContext)
codec,
nullValue)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}

@@ -149,7 +156,8 @@
treatEmptyValuesAsNulls,
schema,
inferSchema,
codec)(sqlContext)
codec,
nullValue)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}
}
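
For reference, a quick sketch of the new builder-style DSL (only methods shown in this diff are used; `data.csv` and `sqlContext` are assumed to be in scope as in the README examples):

// With no explicit schema the columns default to nullable strings,
// so tokens equal to "\N" are simply returned as null.
val df = new CsvParser()
  .withUseHeader(true)
  .withNullValue("\\N")
  .csvFile(sqlContext, "data.csv")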
11 changes: 7 additions & 4 deletions src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -46,7 +46,8 @@ case class CsvRelation protected[spark] (
treatEmptyValuesAsNulls: Boolean,
userSchema: StructType = null,
inferCsvSchema: Boolean,
codec: String = null)(@transient val sqlContext: SQLContext)
codec: String = null,
nullValue: String = "")(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with PrunedScan with InsertableRelation {

/**
@@ -116,7 +117,7 @@ case class CsvRelation protected[spark] (
while (index < schemaFields.length) {
val field = schemaFields(index)
rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
treatEmptyValuesAsNulls)
treatEmptyValuesAsNulls, nullValue)
index = index + 1
}
Some(Row.fromSeq(rowArray))
@@ -189,7 +190,9 @@ case class CsvRelation protected[spark] (
indexSafeTokens(index),
field.dataType,
field.nullable,
treatEmptyValuesAsNulls)
treatEmptyValuesAsNulls,
nullValue
)
subIndex = subIndex + 1
}
Some(Row.fromSeq(rowArray.take(requiredSize)))
@@ -235,7 +238,7 @@ case class CsvRelation protected[spark] (
firstRow.zipWithIndex.map { case (value, index) => s"C$index"}
}
if (this.inferCsvSchema) {
InferSchema(tokenRdd(header), header)
InferSchema(tokenRdd(header), header, nullValue)
} else {
// By default fields are assumed to be StringType
val schemaFields = header.map { fieldName =>
4 changes: 3 additions & 1 deletion src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -136,6 +136,7 @@ class DefaultSource
} else {
throw new Exception("Infer schema flag can be true or false")
}
val nullValue = parameters.getOrElse("nullValue", "")

val codec = parameters.getOrElse("codec", null)

@@ -154,7 +155,8 @@
treatEmptyValuesAsNullsFlag,
schema,
inferSchemaFlag,
codec)(sqlContext)
codec,
nullValue)(sqlContext)
}

override def createRelation(
20 changes: 14 additions & 6 deletions src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
@@ -31,10 +31,15 @@ private[csv] object InferSchema {
* 2. Merge row types to find common type
* 3. Replace any null types with string type
*/
def apply(tokenRdd: RDD[Array[String]], header: Array[String]): StructType = {
def apply(
tokenRdd: RDD[Array[String]],
header: Array[String],
nullValue: String = ""): StructType = {

val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val rootTypes: Array[DataType] = tokenRdd.aggregate(startType)(inferRowType, mergeRowTypes)
val rootTypes: Array[DataType] = tokenRdd.aggregate(startType)(
inferRowType(nullValue),
mergeRowTypes)

val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
StructField(thisHeader, rootType, nullable = true)
@@ -43,10 +48,11 @@
StructType(structFields)
}

private def inferRowType(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
private def inferRowType(nullValue: String)
(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
var i = 0
while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
rowSoFar(i) = inferField(rowSoFar(i), next(i))
rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue)
i+=1
}
rowSoFar
@@ -68,8 +74,10 @@
* Infer type of string field. Given known type Double, and a string "1", there is no
* point checking if it is an Int, as the final type must be Double or higher.
*/
private[csv] def inferField(typeSoFar: DataType, field: String): DataType = {
if (field == null || field.isEmpty) {
private[csv] def inferField(typeSoFar: DataType,
field: String,
nullValue: String = ""): DataType = {
if (field == null || field.isEmpty || field == nullValue) {
typeSoFar
} else {
typeSoFar match {
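
To illustrate the effect on inference (a sketch with calls of the same shape as the InferSchemaSuite tests below; values taken from the `\N` test fixture): with `nullValue` set to `\N`, an age column containing 35, \N, 24 still infers as IntegerType, because a token equal to `nullValue` leaves the running type unchanged:

InferSchema.inferField(NullType, "35", "\\N")      // IntegerType
InferSchema.inferField(IntegerType, "\\N", "\\N")  // IntegerType: token matches nullValue, type unchanged
InferSchema.inferField(IntegerType, "24", "\\N")   // IntegerType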
11 changes: 9 additions & 2 deletions src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -43,8 +43,15 @@ object TypeCast {
datum: String,
castType: DataType,
nullable: Boolean = true,
treatEmptyValuesAsNulls: Boolean = false): Any = {
if (datum == "" && nullable && (!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls)){
treatEmptyValuesAsNulls: Boolean = false,
nullValue: String = ""): Any = {
// if nullValue is not an empty string, don't require treatEmptyValuesAsNulls
// to be set to true
val nullValueIsNotEmpty = nullValue != ""
if (datum == nullValue &&
nullable &&
(!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls || nullValueIsNotEmpty)
){
null
} else {
castType match {
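
In other words (a simplified, standalone sketch of the rule above, not the library's full `castTo`; `isNullToken` is a hypothetical helper): a token maps to null when it equals `nullValue`, the field is nullable, and either the field is non-string, `treatEmptyValuesAsNulls` is set, or `nullValue` is itself non-empty:

def isNullToken(
    datum: String,
    isStringColumn: Boolean,
    nullable: Boolean,
    treatEmptyValuesAsNulls: Boolean,
    nullValue: String): Boolean = {
  val nullValueIsNotEmpty = nullValue != ""
  datum == nullValue && nullable &&
    (!isStringColumn || treatEmptyValuesAsNulls || nullValueIsNotEmpty)
}

isNullToken("\\N", isStringColumn = true, nullable = true,
  treatEmptyValuesAsNulls = false, nullValue = "\\N")   // true: non-empty nullValue matches
isNullToken("", isStringColumn = true, nullable = true,
  treatEmptyValuesAsNulls = false, nullValue = "")      // false: "" stays "" unless treatEmptyValuesAsNulls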
4 changes: 4 additions & 0 deletions src/test/resources/null_null_numbers.csv
@@ -0,0 +1,4 @@
name,age
alice,35
bob,null
null,24
4 changes: 4 additions & 0 deletions src/test/resources/null_slashn_numbers.csv
@@ -0,0 +1,4 @@
name,age
alice,35
bob,\N
\N,24
32 changes: 32 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -33,6 +33,8 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
val carsAltFile = "src/test/resources/cars-alternative.csv"
val carsUnbalancedQuotesFile = "src/test/resources/cars-unbalanced-quotes.csv"
val nullNumbersFile = "src/test/resources/null-numbers.csv"
val nullNullNumbersFile = "src/test/resources/null_null_numbers.csv"
val nullSlashNNumbersFile = "src/test/resources/null_slashn_numbers.csv"
val emptyFile = "src/test/resources/empty.csv"
val ageFile = "src/test/resources/ages.csv"
val escapeFile = "src/test/resources/escape.csv"
@@ -572,6 +574,36 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
assert(results(2).toSeq === Seq("", 24))
}

test("DSL test nullable fields with user defined null value of \"null\"") {
val results = new CsvParser()
.withSchema(StructType(List(StructField("name", StringType, false),
StructField("age", IntegerType, true))))
.withUseHeader(true)
.withParserLib(parserLib)
.withNullValue("null")
.csvFile(sqlContext, nullNullNumbersFile)
.collect()

assert(results.head.toSeq === Seq("alice", 35))
assert(results(1).toSeq === Seq("bob", null))
assert(results(2).toSeq === Seq("null", 24))
}

test("DSL test nullable fields with user defined null value of \"\\N\"") {
val results = new CsvParser()
.withSchema(StructType(List(StructField("name", StringType, false),
StructField("age", IntegerType, true))))
.withUseHeader(true)
.withParserLib(parserLib)
.withNullValue("\\N")
.csvFile(sqlContext, nullSlashNNumbersFile)
.collect()

assert(results.head.toSeq === Seq("alice", 35))
assert(results(1).toSeq === Seq("bob", null))
assert(results(2).toSeq === Seq("\\N", 24))
}

test("Commented lines in CSV data") {
val results: Array[Row] = new CsvParser()
.withDelimiter(',')
9 changes: 9 additions & 0 deletions src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala
@@ -15,6 +15,15 @@ class InferSchemaSuite extends FunSuite {
assert(InferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
}

test("Null fields are handled properly when a nullValue is specified") {
assert(InferSchema.inferField(NullType, "null", "null") == NullType)
assert(InferSchema.inferField(StringType, "null", "null") == StringType)
assert(InferSchema.inferField(LongType, "null", "null") == LongType)
assert(InferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
assert(InferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
assert(InferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
}

test("String fields types are inferred correctly from other types") {
assert(InferSchema.inferField(LongType, "1.0") == DoubleType)
assert(InferSchema.inferField(LongType, "test") == StringType)
9 changes: 9 additions & 0 deletions src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
@@ -94,4 +94,13 @@ class TypeCastSuite extends FunSuite {
assert(TypeCast.castTo("1,00", FloatType) == 1.0)
assert(TypeCast.castTo("1,00", DoubleType) == 1.0)
}

test("Can handle mapping user specified nullValues") {
assert(TypeCast.castTo("null", StringType, true, false, "null") == null)
assert(TypeCast.castTo("\\N", ByteType, true, false, "\\N") == null)
assert(TypeCast.castTo("", ShortType, true, false) == null)
assert(TypeCast.castTo("null", StringType, true, true, "null") == null)
assert(TypeCast.castTo("", StringType, true, false, "") == "")
assert(TypeCast.castTo("", StringType, true, true, "") == null)
}
}
