Fix SparkSQL failures caused by presence of non-selected columns of UDT type in the table.

Refactor the CassandraRelation class: less String/Regex magic, more type safety.
pkolaczk committed Nov 22, 2014
1 parent f684f83 commit cd28871
Showing 3 changed files with 54 additions and 31 deletions.
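For context, the failure being fixed is a plain Map lookup miss during schema conversion: the old code keyed Catalyst types by the column's Scala type name as a String, and conversion ran over every column of the table, so any column whose type was missing from the map (such as a UDT column) threw even when the query never selected it. A minimal sketch of that failure mode follows; the object name OldStyleLookup is hypothetical, the map is abbreviated, and the exact UDT type name is an assumption.

object OldStyleLookup {
  // Abbreviated stand-in for the old String-keyed map shown in the diff below.
  private val primitiveTypeMap = Map[String, String](
    "String" -> "StringType",
    "Int"    -> "IntegerType"
    // ... no entry for any user-defined type
  )

  // Map.apply throws NoSuchElementException when the key is missing.
  def scalaDataType(scalaTypeName: String): String =
    primitiveTypeMap(scalaTypeName)

  def main(args: Array[String]): Unit = {
    // A UDT column (the name "UDTValue" here is a hypothetical stand-in)
    // broke the lookup even when the query selected only non-UDT columns:
    scalaDataType("UDTValue") // java.util.NoSuchElementException: key not found: UDTValue
  }
}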
1 change: 1 addition & 0 deletions CHANGES.txt
@@ -1,4 +1,5 @@
* Added JavaTypeConverter to make it easy to implement custom TypeConverter in Java (#429)
* Fix SparkSQL failures caused by presence of non-selected columns of UDT type in the table.

1.1.0 rc 1
* Fixed problem with setting a batch size in bytes (#435)

CassandraSQLSpec.scala
@@ -62,6 +62,9 @@ class CassandraSQLSpec extends FlatSpec with Matchers with SharedEmbeddedCassand
session.execute("CREATE TABLE IF NOT EXISTS sql_test.test_collection (a INT, b SET<INT>, c MAP<INT, INT>, PRIMARY KEY (a))")
session.execute("INSERT INTO sql_test.test_collection (a, b, c) VALUES (1, {1,2,3}, {1:2, 2:3})")

session.execute("CREATE TYPE sql_test.address (street text, city text, zip int)")
session.execute("CREATE TABLE IF NOT EXISTS sql_test.udts(key INT PRIMARY KEY, name text, addr frozen<address>)")
session.execute("INSERT INTO sql_test.udts(key, name, addr) VALUES (1, 'name', {street: 'Some Street', city: 'Paris', zip: 11120})")
}

it should "allow to select all rows" in {
@@ -329,4 +332,14 @@ class CassandraSQLSpec extends FlatSpec with Matchers with SharedEmbeddedCassand
val result1 = cc.sql("SELECT * FROM test_data_type1").collect()
result1 should have length 1
}

it should "allow to select specified non-UDT columns from a table containing some UDT columns" in {
val cc = new CassandraSQLContext(sc)
cc.setKeyspace("sql_test")
val result = cc.sql("SELECT key, name FROM udts").collect()
result should have length 1
val row = result.head
row.getInt(0) should be(1)
row.getString(1) should be("name")
}
}
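Outside the test harness, the same query pattern applies. A hedged usage sketch, assuming an already-configured SparkContext named sc and the sql_test schema created above:

import org.apache.spark.sql.cassandra.CassandraSQLContext

val cc = new CassandraSQLContext(sc)
cc.setKeyspace("sql_test")

// Selecting only the non-UDT columns now succeeds even though the
// `udts` table also contains the UDT column `addr`:
val rows = cc.sql("SELECT key, name FROM udts").collect()
rows.foreach(row => println(s"${row.getInt(0)} -> ${row.getString(1)}"))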

CassandraRelation.scala
@@ -1,11 +1,11 @@
package org.apache.spark.sql.cassandra

import com.datastax.spark.connector
import com.datastax.spark.connector.cql.{ColumnDef, TableDef}
import org.apache.hadoop.hive.ql.stats.StatsSetupConst
import org.apache.spark.sql.SQLContext

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.sql.catalyst

private[cassandra] case class CassandraRelation
(tableDef: TableDef, alias: Option[String])(@transient cc: CassandraSQLContext)
@@ -20,12 +20,12 @@ private[cassandra] case class CassandraRelation
val columnNameByLowercase = allColumns.map(c => (c.columnName.toLowerCase, c.columnName)).toMap
var projectAttributes = tableDef.allColumns.map(columnToAttribute)

def columnToAttribute(column: ColumnDef): AttributeReference = new AttributeReference(
column.columnName,
ColumnDataType.scalaDataType(column.columnType.scalaTypeName, true),
// Since data can be dumped in randomly with no validation, everything is nullable.
nullable = true
)(qualifiers = tableDef.tableName +: alias.toSeq)
def columnToAttribute(column: ColumnDef): AttributeReference = {
// Since data can be dumped in randomly with no validation, everything is nullable.
val catalystType = ColumnDataType.catalystDataType(column.columnType, nullable = true)
val qualifiers = tableDef.tableName +: alias.toSeq
new AttributeReference(column.columnName, catalystType, nullable = true)(qualifiers = qualifiers)
}

override def output: Seq[Attribute] = projectAttributes

@@ -39,31 +39,40 @@

object ColumnDataType {

  implicit class Regex(sc: StringContext) {
    def regex = new scala.util.matching.Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*)
  }

  private val primitiveTypeMap = Map[String, String](
    "String" -> "StringType",
    "Int" -> "IntegerType",
    "Long" -> "LongType",
    "Float" -> "FloatType",
    "Double" -> "DoubleType",
    "Boolean" -> "BooleanType",
    "BigInt" -> "LongType",
    "BigDecimal" -> "DecimalType",
    "java.util.Date" -> "TimestampType",
    "java.net.InetAddress" -> "StringType",
    "java.util.UUID" -> "StringType",
    "java.nio.ByteBuffer" -> "ByteType"
  )

  def scalaDataType(scalaType: String, containNull: Boolean): DataType = {
    scalaType match {
      case regex"Set\[(\w+)$dt\]" => DataType("ArrayType(" + primitiveTypeMap(dt) + ", " + containNull + ")")
      case regex"Vector\[(\w+)$dt\]" => DataType("ArrayType(" + primitiveTypeMap(dt) + ", " + containNull + ")")
      case regex"Map\[(\w+)$key,(\w+)$value\]" => DataType("MapType(" + primitiveTypeMap(key) + "," + primitiveTypeMap(value) + ", " + containNull + ")")
      case _ => DataType(primitiveTypeMap(scalaType))
    }
  }

  private val primitiveTypeMap = Map[connector.types.ColumnType[_], catalyst.types.DataType](

    connector.types.TextType -> catalyst.types.StringType,
    connector.types.AsciiType -> catalyst.types.StringType,
    connector.types.VarCharType -> catalyst.types.StringType,

    connector.types.BooleanType -> catalyst.types.BooleanType,

    connector.types.IntType -> catalyst.types.IntegerType,
    connector.types.BigIntType -> catalyst.types.LongType,
    connector.types.CounterType -> catalyst.types.LongType,
    connector.types.FloatType -> catalyst.types.FloatType,
    connector.types.DoubleType -> catalyst.types.DoubleType,

    connector.types.VarIntType -> catalyst.types.DecimalType, // no native arbitrary-size integer type
    connector.types.DecimalType -> catalyst.types.DecimalType,

    connector.types.TimestampType -> catalyst.types.TimestampType,
    connector.types.InetType -> catalyst.types.StringType,
    connector.types.UUIDType -> catalyst.types.StringType,
    connector.types.TimeUUIDType -> catalyst.types.StringType,
    connector.types.BlobType -> catalyst.types.ByteType,

    // TODO: This mapping is useless, it is here only to avoid lookup failure if a table contains a UDT column.
    // It is not possible to read UDT columns in SparkSQL now.
    connector.types.UserDefinedTypeStub -> catalyst.types.StructType(Seq.empty)
  )

  def catalystDataType(cassandraType: connector.types.ColumnType[_], nullable: Boolean): catalyst.types.DataType = {
    cassandraType match {
      case connector.types.SetType(et) => catalyst.types.ArrayType(primitiveTypeMap(et), nullable)
      case connector.types.ListType(et) => catalyst.types.ArrayType(primitiveTypeMap(et), nullable)
      case connector.types.MapType(kt, vt) => catalyst.types.MapType(primitiveTypeMap(kt), primitiveTypeMap(vt), nullable)
      case _ => primitiveTypeMap(cassandraType)
    }
  }
}
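To illustrate how the new typed lookup behaves, here is a hedged sketch; it assumes the connector and catalyst type objects referenced above are on the classpath and that ColumnDataType is visible from the caller.

import com.datastax.spark.connector
import org.apache.spark.sql.catalyst

// A set column becomes a Catalyst array of the element's Catalyst type:
ColumnDataType.catalystDataType(connector.types.SetType(connector.types.IntType), nullable = true)
// => catalyst.types.ArrayType(catalyst.types.IntegerType, true)

// A map column becomes a Catalyst map keyed and valued by the mapped types:
ColumnDataType.catalystDataType(
  connector.types.MapType(connector.types.IntType, connector.types.TextType), nullable = true)
// => catalyst.types.MapType(catalyst.types.IntegerType, catalyst.types.StringType, true)

// A UDT column no longer fails the lookup; it maps to an empty struct,
// which is enough for SparkSQL to ignore it when it is not selected:
ColumnDataType.catalystDataType(connector.types.UserDefinedTypeStub, nullable = true)
// => catalyst.types.StructType(Seq.empty)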
