[SIT-2214] Add support for com.snowflake.snowpark.Row.getAs(String)
Adds an implementation of `Row.getAs(String)` for retrieving values from a `Row` by field name.
sfc-gh-lfallasavendano committed Sep 26, 2024
1 parent 6a9edb1 commit 152b98c
Showing 6 changed files with 197 additions and 53 deletions.
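For orientation, here is a minimal usage sketch of the API this commit adds (illustrative only; it assumes an existing Snowpark `session`, and the column names are made up):

import com.snowflake.snowpark.Row
import com.snowflake.snowpark.types._

// Rows collected from a DataFrame carry their schema, so fields can be read by name.
val schema = StructType(Seq(StructField("NAME", StringType), StructField("VAL", IntegerType)))
val df = session.createDataFrame(Seq(Row("Alice", 1)), schema)
val row = df.collect()(0)

row.getAs[String]("NAME") // "Alice"
row.getAs[Int]("VAL")     // 1

// Rows built directly with Row.apply carry no schema, so name-based access
// throws UnsupportedOperationException.
Row("Alice", 1).getAs[Int]("VAL") // throws
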
32 changes: 32 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Row.java
@@ -483,6 +483,38 @@ public <T> T getAs(int index, Class<T> clazz)
return (T) get(index);
}

/**
* Returns the field value for the specified field name and casts it to the desired type {@code
* T}.
*
* <p>Example:
*
* <pre>{@code
* StructType schema =
* StructType.create(
* new StructField("name", DataTypes.StringType),
* new StructField("val", DataTypes.IntegerType));
* Row[] data = { Row.create("Alice", 1) };
* DataFrame df = session.createDataFrame(data, schema);
* Row row = df.collect()[0];
*
* row.getAs("name", String.class); // Returns "Alice" as a String
* row.getAs("val", Integer.class); // Returns 1 as an Int
* }</pre>
*
* @param fieldName the name of the field within the row.
* @param clazz the {@code Class} object representing the type {@code T}.
* @param <T> the expected type of the value for the specified field name.
* @return the field value for the specified field name cast to type {@code T}.
* @throws ClassCastException if the value of the field cannot be cast to type {@code T}.
* @throws IllegalArgumentException if the name of the field is not part of the row schema.
* @throws UnsupportedOperationException if the schema information is not available.
* @since 1.15.0
*/
public <T> T getAs(String fieldName, Class<T> clazz) {
return this.getAs(this.scalaRow.fieldIndex(fieldName), clazz);
}

/**
* Generates a string value to represent the content of this row.
*
59 changes: 52 additions & 7 deletions src/main/scala/com/snowflake/snowpark/Row.scala
@@ -2,7 +2,7 @@ package com.snowflake.snowpark

import java.sql.{Date, Time, Timestamp}
import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.types.{Geography, Geometry, Variant}
import com.snowflake.snowpark.types.{Geography, Geometry, StructType, Variant}

import scala.reflect.ClassTag
import scala.util.hashing.MurmurHash3
@@ -16,27 +16,30 @@ object Row {
* Returns a [[Row]] based on the given values.
* @since 0.1.0
*/
def apply(values: Any*): Row = new Row(values.toArray)
def apply(values: Any*): Row = new Row(values.toArray, None)

/**
* Return a [[Row]] based on the values in the given Seq.
* @since 0.1.0
*/
def fromSeq(values: Seq[Any]): Row = new Row(values.toArray)
def fromSeq(values: Seq[Any]): Row = new Row(values.toArray, None)

/**
* Return a [[Row]] based on the values in the given Array.
* @since 0.2.0
*/
def fromArray(values: Array[Any]): Row = new Row(values)
def fromArray(values: Array[Any]): Row = new Row(values, None)

private[snowpark] def fromSeqWithSchema(values: Seq[Any], schema: Option[StructType]): Row =
new Row(values.toArray, schema)

private[snowpark] def fromMap(map: Map[String, Any]): Row =
new SnowflakeObject(map)
}

private[snowpark] class SnowflakeObject private[snowpark] (
private[snowpark] val map: Map[String, Any])
extends Row(map.values.toArray) {
extends Row(map.values.toArray, None) {
override def toString: String = convertValueToString(this)
}

@@ -47,7 +50,7 @@ private[snowpark] class SnowflakeObject private[snowpark] (
* @groupname utl Utility Functions
* @since 0.1.0
*/
class Row protected (values: Array[Any]) extends Serializable {
class Row protected (values: Array[Any], schema: Option[StructType]) extends Serializable {

/**
* Converts this [[Row]] to a Seq
@@ -89,7 +92,7 @@ class Row protected (values: Array[Any]) extends Serializable {
* @since 0.1.0
* @group utl
*/
def copy(): Row = new Row(values)
def copy(): Row = new Row(values, schema)

/**
* Returns a clone of this row object. Alias of [[copy]]
@@ -367,6 +370,48 @@ class Row protected (values: Array[Any]) extends Serializable {
getAs[Map[T, U]](index)
}

/**
* Returns the index of the field with the specified name.
*
* @param name the name of the field.
* @return the index of the specified field.
* @throws UnsupportedOperationException if schema information is not available.
* @since 1.15.0
*/
def fieldIndex(name: String): Int = {
val schema = this.schema.getOrElse(
throw new UnsupportedOperationException("Cannot get field index for row without schema"))
schema.fieldIndex(name)
}

/**
* Returns the value for the specified field name and casts it to the desired type `T`.
*
* Example:
*
* {{{
* val schema =
* StructType(Seq(StructField("name", StringType), StructField("value", IntegerType)))
* val data = Seq(Row("Alice", 1))
* val df = session.createDataFrame(data, schema)
* val row = df.collect()(0)
*
* row.getAs[String]("name") // Returns "Alice" as a String
* row.getAs[Int]("value") // Returns 1 as an Int
* }}}
*
* @param fieldName the name of the field within the row.
* @tparam T the expected type of the value for the specified field name.
* @return the value for the specified field name cast to type `T`.
* @throws ClassCastException if the value of the field cannot be cast to type `T`.
* @throws IllegalArgumentException if the name of the field is not part of the row schema.
* @throws UnsupportedOperationException if the schema information is not available.
* @group getter
* @since 1.15.0
*/
def getAs[T](fieldName: String)(implicit classTag: ClassTag[T]): T =
getAs[T](fieldIndex(fieldName))

/**
* Returns the value at the specified column index and casts it to the desired type `T`.
*
95 changes: 49 additions & 46 deletions src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala
@@ -330,6 +330,7 @@ private[snowpark] class ServerConnection(
val data = statement.getResultSet

val schema = ServerConnection.convertResultMetaToAttribute(data.getMetaData)
val schemaOption = Some(StructType.fromAttributes(schema))

lazy val geographyOutputFormat = getParameterValue(ParameterUtils.GeographyOutputFormat)
lazy val geometryOutputFormat = getParameterValue(ParameterUtils.GeometryOutputFormat)
@@ -343,53 +344,55 @@
private def readNext(): Unit = {
_hasNext = data.next()
_currentRow = if (_hasNext) {
Row.fromSeq(schema.zipWithIndex.map {
case (attribute, index) =>
val resultIndex: Int = index + 1
val resultSetExt = SnowflakeResultSetExt(data)
if (resultSetExt.isNull(resultIndex)) {
null
} else {
attribute.dataType match {
case VariantType => data.getString(resultIndex)
case _: StructuredArrayType | _: StructuredMapType | _: StructType =>
resultSetExt.getObject(resultIndex)
case ArrayType(StringType) => data.getString(resultIndex)
case MapType(StringType, StringType) => data.getString(resultIndex)
case StringType => data.getString(resultIndex)
case _: DecimalType => data.getBigDecimal(resultIndex)
case DoubleType => data.getDouble(resultIndex)
case FloatType => data.getFloat(resultIndex)
case BooleanType => data.getBoolean(resultIndex)
case BinaryType => data.getBytes(resultIndex)
case DateType => data.getDate(resultIndex)
case TimeType => data.getTime(resultIndex)
case ByteType => data.getByte(resultIndex)
case IntegerType => data.getInt(resultIndex)
case LongType => data.getLong(resultIndex)
case TimestampType => data.getTimestamp(resultIndex)
case ShortType => data.getShort(resultIndex)
case GeographyType =>
geographyOutputFormat match {
case "GeoJSON" => Geography.fromGeoJSON(data.getString(resultIndex))
case _ =>
throw ErrorMessage.MISC_UNSUPPORTED_GEOGRAPHY_FORMAT(
geographyOutputFormat)
}
case GeometryType =>
geometryOutputFormat match {
case "GeoJSON" => Geometry.fromGeoJSON(data.getString(resultIndex))
case _ =>
throw ErrorMessage.MISC_UNSUPPORTED_GEOMETRY_FORMAT(
geometryOutputFormat)
}
case _ =>
// ArrayType, StructType, MapType
throw new UnsupportedOperationException(
s"Unsupported type: ${attribute.dataType}")
Row.fromSeqWithSchema(
schema.zipWithIndex.map {
case (attribute, index) =>
val resultIndex: Int = index + 1
val resultSetExt = SnowflakeResultSetExt(data)
if (resultSetExt.isNull(resultIndex)) {
null
} else {
attribute.dataType match {
case VariantType => data.getString(resultIndex)
case _: StructuredArrayType | _: StructuredMapType | _: StructType =>
resultSetExt.getObject(resultIndex)
case ArrayType(StringType) => data.getString(resultIndex)
case MapType(StringType, StringType) => data.getString(resultIndex)
case StringType => data.getString(resultIndex)
case _: DecimalType => data.getBigDecimal(resultIndex)
case DoubleType => data.getDouble(resultIndex)
case FloatType => data.getFloat(resultIndex)
case BooleanType => data.getBoolean(resultIndex)
case BinaryType => data.getBytes(resultIndex)
case DateType => data.getDate(resultIndex)
case TimeType => data.getTime(resultIndex)
case ByteType => data.getByte(resultIndex)
case IntegerType => data.getInt(resultIndex)
case LongType => data.getLong(resultIndex)
case TimestampType => data.getTimestamp(resultIndex)
case ShortType => data.getShort(resultIndex)
case GeographyType =>
geographyOutputFormat match {
case "GeoJSON" => Geography.fromGeoJSON(data.getString(resultIndex))
case _ =>
throw ErrorMessage.MISC_UNSUPPORTED_GEOGRAPHY_FORMAT(
geographyOutputFormat)
}
case GeometryType =>
geometryOutputFormat match {
case "GeoJSON" => Geometry.fromGeoJSON(data.getString(resultIndex))
case _ =>
throw ErrorMessage.MISC_UNSUPPORTED_GEOMETRY_FORMAT(
geometryOutputFormat)
}
case _ =>
// ArrayType, StructType, MapType
throw new UnsupportedOperationException(
s"Unsupported type: ${attribute.dataType}")
}
}
}
})
},
schemaOption)
} else {
// After all rows are consumed, close the statement to release resource
close()
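One design note on the `ServerConnection` change above: the `StructType` is materialized once per result set (from the JDBC metadata) and the same `Option` instance is handed to every `Row` the iterator produces, so name-based access adds no per-row schema construction. A simplified sketch of that shape (not the actual `ServerConnection` code; it assumes placement inside the `com.snowflake.snowpark` package, since `fromSeqWithSchema` is package-private):

package com.snowflake.snowpark

import com.snowflake.snowpark.types.StructType

object RowsWithSchemaSketch {
  // Build the schema wrapper once and share it across all rows produced from raw values.
  def rowsWithSchema(values: Iterator[Seq[Any]], schema: StructType): Iterator[Row] = {
    val schemaOption = Some(schema) // constructed once, reused for every row
    values.map(v => Row.fromSeqWithSchema(v, schemaOption))
  }
}
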
18 changes: 18 additions & 0 deletions src/main/scala/com/snowflake/snowpark/types/StructType.scala
@@ -40,6 +40,10 @@ case class StructType(fields: Array[StructField] = Array())
extends DataType
with Seq[StructField] {

private lazy val fieldPositions = scala.collection.immutable
.SortedMap(fields.zipWithIndex.map(tuple => (tuple._1.name -> tuple._2)): _*)(
scala.math.Ordering.comparatorToOrdering(String.CASE_INSENSITIVE_ORDER))

/**
* Returns the total number of [[StructField]]
* @since 0.1.0
@@ -101,6 +105,20 @@ case class StructType(fields: Array[StructField] = Array())
nameToField(name).getOrElse(
throw new IllegalArgumentException(s"$name does not exist. Names: ${names.mkString(", ")}"))

/**
* Returns the index of the specified field.
*
* @param fieldName the name of the field.
* @return the index of the field with the specified name.
* @throws IllegalArgumentException if the given field name does not exist in the schema.
* @since 1.15.0
*/
def fieldIndex(fieldName: String): Int = {
fieldPositions.getOrElse(
fieldName,
throw new IllegalArgumentException("Field " + fieldName + " does not exist"))
}

protected[snowpark] def toAttributes: Seq[Attribute] = {
/*
* When user provided schema is used in a SnowflakePlan, we have to
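The `fieldPositions` map above is a `SortedMap` ordered by `String.CASE_INSENSITIVE_ORDER`, which is what makes lookups such as `"EMPNAME"` resolve a field declared as `"EmpName"` in the tests below. A small standalone sketch of that lookup behavior (illustrative, with made-up field names):

import scala.collection.immutable.SortedMap

// Case-insensitive name -> position lookup, mirroring StructType.fieldPositions.
val ord = scala.math.Ordering.comparatorToOrdering(String.CASE_INSENSITIVE_ORDER)
val positions = SortedMap("EmpName" -> 0, "NumVal" -> 1)(ord)

positions("EMPNAME")     // 0: the ordering ignores case, so the key still matches
positions.get("Missing") // None: fieldIndex maps a miss to IllegalArgumentException
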
27 changes: 27 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
@@ -602,4 +602,31 @@ public void getAsWithStructuredArray() {
},
getSession());
}

@Test
public void getAsWithFieldName() {
StructType schema =
StructType.create(
new StructField("EmpName", DataTypes.StringType),
new StructField("NumVal", DataTypes.IntegerType));

Row[] data = {Row.create("abcd", 10), Row.create("efgh", 20)};

DataFrame df = getSession().createDataFrame(data, schema);
Row row = df.collect()[0];

assert (row.getAs("EmpName", String.class) == row.getAs(0, String.class));
assert (row.getAs("EmpName", String.class).charAt(3) == 'd');
assert (row.getAs("NumVal", Integer.class) == row.getAs(1, Integer.class));

assert (row.getAs("EMPNAME", String.class) == row.getAs(0, String.class));

assertThrows(
IllegalArgumentException.class, () -> row.getAs("NonExistingColumn", Integer.class));

Row rowWithoutSchema = Row.create(40, "Alice");
assertThrows(
UnsupportedOperationException.class,
() -> rowWithoutSchema.getAs("NonExistingColumn", Integer.class));
}
}
19 changes: 19 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
@@ -404,6 +404,25 @@ class RowSuite extends SNTestBase {
}
}

test("getAs with field name") {
val schema =
StructType(Seq(StructField("EmpName", StringType), StructField("NumVal", IntegerType)))
val df = session.createDataFrame(Seq(Row("abcd", 10), Row("efgh", 20)), schema)
val row = df.collect()(0)

assert(row.getAs[String]("EmpName") == row.getAs[String](0))
assert(row.getAs[String]("EmpName").charAt(3) == 'd')
assert(row.getAs[Int]("NumVal") == row.getAs[Int](1))

assert(row.getAs[String]("EMPNAME") == row.getAs[String](0))

assertThrows[IllegalArgumentException](row.getAs[String]("NonExistingColumn"))

val rowWithoutSchema = Row(40, "Alice")
assertThrows[UnsupportedOperationException](
rowWithoutSchema.getAs[Integer]("NonExistingColumn"))
}

test("hashCode") {
val row1 = Row(1, 2, 3)
val row2 = Row("str", null, 3)