Add nullValue being respected when parsing CSVs #224

Closed · wants to merge 2 commits
11 changes: 6 additions & 5 deletions README.md
@@ -56,6 +56,7 @@ When reading files the API accepts several options:
* `inferSchema`: automatically infers column types. It requires one extra pass over the data and is false by default
* `comment`: skip lines beginning with this character. Default is `"#"`. Disable comments by setting this to `null`.
* `codec`: compression codec to use when saving to file. Should be the fully qualified name of a class implementing `org.apache.hadoop.io.compress.CompressionCodec`. Defaults to no compression when a codec is not specified.
* `nullValue`: specify a string that indicates a null value; any field matching this string will be set to null in the DataFrame (see the example below)
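For example, with the Map-based `load` API used elsewhere in this README (a minimal sketch; the `cars.csv` path and the `\N` token are placeholder choices):

```scala
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
// Any field whose raw text equals "\N" becomes null in the resulting DataFrame.
val df = sqlContext.load(
    "com.databricks.spark.csv",
    Map("path" -> "cars.csv", "header" -> "true", "nullValue" -> "\\N"))
```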

The package also supports saving simple (non-nested) DataFrames. When saving, you can specify the delimiter and whether a header row should be generated for the table. See the following examples for more details.

@@ -109,7 +110,7 @@ import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType}

val sqlContext = new SQLContext(sc)
val customSchema = StructType(
StructField("year", IntegerType, true),
StructField("year", IntegerType, true),
StructField("make", StringType, true),
StructField("model", StringType, true),
StructField("comment", StringType, true),
@@ -155,7 +156,7 @@ import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
val df = sqlContext.load(
"com.databricks.spark.csv",
"com.databricks.spark.csv",
Map("path" -> "cars.csv", "header" -> "true", "inferSchema" -> "true"))
val selectedData = df.select("year", "model")
selectedData.save("newcars.csv", "com.databricks.spark.csv")
@@ -168,14 +169,14 @@ import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType}

val sqlContext = new SQLContext(sc)
val customSchema = StructType(
StructField("year", IntegerType, true),
StructField("year", IntegerType, true),
StructField("make", StringType, true),
StructField("model", StringType, true),
StructField("comment", StringType, true),
StructField("blank", StringType, true))

val df = sqlContext.load(
"com.databricks.spark.csv",
"com.databricks.spark.csv",
schema = customSchema,
Map("path" -> "cars.csv", "header" -> "true"))

@@ -210,7 +211,7 @@ import org.apache.spark.sql.types.*;

SQLContext sqlContext = new SQLContext(sc);
StructType customSchema = new StructType(new StructField[] {
new StructField("year", DataTypes.IntegerType, true, Metadata.empty()),
new StructField("year", DataTypes.IntegerType, true, Metadata.empty()),
new StructField("make", DataTypes.StringType, true, Metadata.empty()),
new StructField("model", DataTypes.StringType, true, Metadata.empty()),
new StructField("comment", DataTypes.StringType, true, Metadata.empty()),
12 changes: 10 additions & 2 deletions src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -40,6 +40,7 @@ class CsvParser extends Serializable {
private var charset: String = TextFile.DEFAULT_CHARSET.name()
private var inferSchema: Boolean = false
private var codec: String = null
private var nullValue: String = ""

def withUseHeader(flag: Boolean): CsvParser = {
this.useHeader = flag
@@ -111,6 +112,11 @@ class CsvParser extends Serializable {
this
}

def withNullValue(nullValue: String): CsvParser = {
this.nullValue = nullValue
this
}

/** Returns a Schema RDD for the given CSV path. */
@throws[RuntimeException]
def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
@@ -129,7 +135,8 @@
treatEmptyValuesAsNulls,
schema,
inferSchema,
codec)(sqlContext)
codec,
nullValue)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}

@@ -149,7 +156,8 @@
treatEmptyValuesAsNulls,
schema,
inferSchema,
codec)(sqlContext)
codec,
nullValue)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}
}
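A hedged usage sketch of the new setter, chained with builder methods that already exist on this class (the input path is a placeholder):

```scala
val df = new CsvParser()
  .withUseHeader(true)
  .withNullValue("\\N")            // raw fields equal to \N parse as null
  .csvFile(sqlContext, "cars.csv")
```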
11 changes: 7 additions & 4 deletions src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -46,7 +46,8 @@ case class CsvRelation protected[spark] (
treatEmptyValuesAsNulls: Boolean,
userSchema: StructType = null,
inferCsvSchema: Boolean,
codec: String = null)(@transient val sqlContext: SQLContext)
codec: String = null,
nullValue: String = "")(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with PrunedScan with InsertableRelation {

/**
@@ -116,7 +117,7 @@ case class CsvRelation protected[spark] (
while (index < schemaFields.length) {
val field = schemaFields(index)
rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
treatEmptyValuesAsNulls)
treatEmptyValuesAsNulls, nullValue)
index = index + 1
}
Some(Row.fromSeq(rowArray))
@@ -189,7 +190,9 @@ case class CsvRelation protected[spark] (
indexSafeTokens(index),
field.dataType,
field.nullable,
treatEmptyValuesAsNulls)
treatEmptyValuesAsNulls,
nullValue
)
subIndex = subIndex + 1
}
Some(Row.fromSeq(rowArray.take(requiredSize)))
@@ -235,7 +238,7 @@ case class CsvRelation protected[spark] (
firstRow.zipWithIndex.map { case (value, index) => s"C$index"}
}
if (this.inferCsvSchema) {
InferSchema(tokenRdd(header), header)
InferSchema(tokenRdd(header), header, nullValue)
} else {
// By default fields are assumed to be StringType
val schemaFields = header.map { fieldName =>
4 changes: 3 additions & 1 deletion src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -136,6 +136,7 @@ class DefaultSource
} else {
throw new Exception("Infer schema flag can be true or false")
}
val nullValue = parameters.getOrElse("nullValue", "")

val codec = parameters.getOrElse("codec", null)

@@ -154,7 +155,8 @@
treatEmptyValuesAsNullsFlag,
schema,
inferSchemaFlag,
codec)(sqlContext)
codec,
nullValue)(sqlContext)
}

override def createRelation(
20 changes: 14 additions & 6 deletions src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
@@ -30,10 +30,15 @@ private[csv] object InferSchema {
* 2. Merge row types to find common type
* 3. Replace any null types with string type
*/
def apply(tokenRdd: RDD[Array[String]], header: Array[String]): StructType = {
def apply(
tokenRdd: RDD[Array[String]],
header: Array[String],
nullValue: String = ""): StructType = {

val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val rootTypes: Array[DataType] = tokenRdd.aggregate(startType)(inferRowType, mergeRowTypes)
val rootTypes: Array[DataType] = tokenRdd.aggregate(startType)(
inferRowType(nullValue),
mergeRowTypes)

val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
StructField(thisHeader, rootType, nullable = true)
@@ -42,10 +47,11 @@
StructType(structFields)
}

private def inferRowType(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
private def inferRowType(nullValue: String)
(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
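// Curried so that inferRowType(nullValue) has the shape
// (Array[DataType], Array[String]) => Array[DataType] that tokenRdd.aggregate
// expects for its sequence operator, while still carrying the configured null token.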
var i = 0
while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
rowSoFar(i) = inferField(rowSoFar(i), next(i))
rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue)
i+=1
}
rowSoFar
@@ -67,8 +73,10 @@
* Infer type of string field. Given known type Double, and a string "1", there is no
* point checking if it is an Int, as the final type must be Double or higher.
*/
private[csv] def inferField(typeSoFar: DataType, field: String): DataType = {
if (field == null || field.isEmpty) {
private[csv] def inferField(typeSoFar: DataType,
field: String,
nullValue: String = ""): DataType = {
if (field == null || field.isEmpty || field == nullValue) {
typeSoFar
} else {
typeSoFar match {
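To make the effect concrete, a hedged walk-through of one column's tokens (these calls mirror the new `InferSchemaSuite` tests below and must run inside the `com.databricks.spark.csv.util` package, since `InferSchema` is package-private; the results assume the existing integer-first inference order):

```scala
import org.apache.spark.sql.types._

val t0 = InferSchema.inferField(NullType, "1")     // IntegerType
val t1 = InferSchema.inferField(t0, "\\N", "\\N")  // IntegerType: the null token leaves the type unchanged
val t2 = InferSchema.inferField(t1, "2", "\\N")    // IntegerType
// Without a matching nullValue, "\N" would demote the column to StringType.
```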
11 changes: 9 additions & 2 deletions src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -43,8 +43,15 @@ object TypeCast {
datum: String,
castType: DataType,
nullable: Boolean = true,
treatEmptyValuesAsNulls: Boolean = false): Any = {
if (datum == "" && nullable && (!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls)){
treatEmptyValuesAsNulls: Boolean = false,
nullValue: String = ""): Any = {
// if nullValue is not an empty string, don't require treatEmptyValuesAsNulls
// to be set to true
val nullValueIsNotEmpty = nullValue != ""
if (datum == nullValue &&
nullable &&
(!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls || nullValueIsNotEmpty)
){
null
} else {
castType match {
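A few hedged examples of the predicate above, mirroring the new `TypeCastSuite` assertions at the end of this diff:

```scala
import org.apache.spark.sql.types._

TypeCast.castTo("\\N", IntegerType, true, false, "\\N") // null: datum matches the non-empty nullValue
TypeCast.castTo("\\N", StringType, true, false, "\\N")  // null: a non-empty nullValue bypasses the StringType guard
TypeCast.castTo("", StringType, true, false, "")        // "": an empty nullValue still needs treatEmptyValuesAsNulls for strings
```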
4 changes: 4 additions & 0 deletions src/test/resources/null_null_numbers.csv
@@ -0,0 +1,4 @@
name,age
alice,35
bob,null
null,24
4 changes: 4 additions & 0 deletions src/test/resources/null_slashn_numbers.csv
@@ -0,0 +1,4 @@
name,age
alice,35
bob,\N
\N,24
32 changes: 32 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -33,6 +33,8 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
val carsAltFile = "src/test/resources/cars-alternative.csv"
val carsUnbalancedQuotesFile = "src/test/resources/cars-unbalanced-quotes.csv"
val nullNumbersFile = "src/test/resources/null-numbers.csv"
val nullNullNumbersFile = "src/test/resources/null_null_numbers.csv"
val nullSlashNNumbersFile = "src/test/resources/null_slashn_numbers.csv"
val emptyFile = "src/test/resources/empty.csv"
val ageFile = "src/test/resources/ages.csv"
val escapeFile = "src/test/resources/escape.csv"
@@ -572,6 +574,36 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
assert(results(2).toSeq === Seq("", 24))
}

test("DSL test nullable fields with user defined null value of \"null\"") {
val results = new CsvParser()
.withSchema(StructType(List(StructField("name", StringType, false),
StructField("age", IntegerType, true))))
.withUseHeader(true)
.withParserLib(parserLib)
.withNullValue("null")
.csvFile(sqlContext, nullNullNumbersFile)
.collect()

assert(results.head.toSeq === Seq("alice", 35))
assert(results(1).toSeq === Seq("bob", null))
assert(results(2).toSeq === Seq("null", 24))
}

test("DSL test nullable fields with user defined null value of \"\\N\"") {
val results = new CsvParser()
.withSchema(StructType(List(StructField("name", StringType, false),
StructField("age", IntegerType, true))))
.withUseHeader(true)
.withParserLib(parserLib)
.withNullValue("\\N")
.csvFile(sqlContext, nullSlashNNumbersFile)
.collect()

assert(results.head.toSeq === Seq("alice", 35))
assert(results(1).toSeq === Seq("bob", null))
assert(results(2).toSeq === Seq("\\N", 24))
}

test("Commented lines in CSV data") {
val results: Array[Row] = new CsvParser()
.withDelimiter(',')
9 changes: 9 additions & 0 deletions src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala
@@ -15,6 +15,15 @@ class InferSchemaSuite extends FunSuite {
assert(InferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
}

test("Null fields are handled properly when a nullValue is specified") {
assert(InferSchema.inferField(NullType, "null", "null") == NullType)
assert(InferSchema.inferField(StringType, "null", "null") == StringType)
assert(InferSchema.inferField(LongType, "null", "null") == LongType)
assert(InferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
assert(InferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
assert(InferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
}

test("String fields types are inferred correctly from other types") {
assert(InferSchema.inferField(LongType, "1.0") == DoubleType)
assert(InferSchema.inferField(LongType, "test") == StringType)
9 changes: 9 additions & 0 deletions src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
@@ -94,4 +94,13 @@ class TypeCastSuite extends FunSuite {
assert(TypeCast.castTo("1,00", FloatType) == 1.0)
assert(TypeCast.castTo("1,00", DoubleType) == 1.0)
}

test("Can handle mapping user specified nullValues") {
assert(TypeCast.castTo("null", StringType, true, false, "null") == null)
assert(TypeCast.castTo("\\N", ByteType, true, false, "\\N") == null)
assert(TypeCast.castTo("", ShortType, true, false) == null)
assert(TypeCast.castTo("null", StringType, true, true, "null") == null)
assert(TypeCast.castTo("", StringType, true, false, "") == "")
assert(TypeCast.castTo("", StringType, true, true, "") == null)
}
}