diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java index 74475e2d..87cfa088 100644 --- a/src/main/java/com/snowflake/snowpark_java/Row.java +++ b/src/main/java/com/snowflake/snowpark_java/Row.java @@ -483,6 +483,38 @@ public T getAs(int index, Class clazz) return (T) get(index); } + /** + * Returns the field value for the specified field name and casts it to the desired type {@code + * T}. + * + *

Example: + * + *

{@code
+   * StructType schema =
+   *     StructType.create(
+   *        new StructField("name", DataTypes.StringType),
+   *        new StructField("val", DataTypes.IntegerType));
+   * Row[] data = { Row.create("Alice", 1) };
+   * DataFrame df = session.createDataFrame(data, schema);
+   * Row row = df.collect()[0];
+   *
+   * row.getAs("name", String.class); // Returns "Alice" as a String
+   * row.getAs("val", Integer.class); // Returns 1 as an Int
+   * }
+ * + * @param fieldName the name of the field within the row. + * @param clazz the {@code Class} object representing the type {@code T}. + * @param the expected type of the value for the specified field name. + * @return the field value for the specified field name cast to type {@code T}. + * @throws ClassCastException if the value of the field cannot be cast to type {@code T}. + * @throws IllegalArgumentException if the name of the field is not part of the row schema. + * @throws UnsupportedOperationException if the schema information is not available. + * @since 1.15.0 + */ + public T getAs(String fieldName, Class clazz) { + return this.getAs(this.scalaRow.fieldIndex(fieldName), clazz); + } + /** * Generates a string value to represent the content of this row. * diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala index 34eb9bbf..c3810f19 100644 --- a/src/main/scala/com/snowflake/snowpark/Row.scala +++ b/src/main/scala/com/snowflake/snowpark/Row.scala @@ -2,7 +2,7 @@ package com.snowflake.snowpark import java.sql.{Date, Time, Timestamp} import com.snowflake.snowpark.internal.ErrorMessage -import com.snowflake.snowpark.types.{Geography, Geometry, Variant} +import com.snowflake.snowpark.types.{Geography, Geometry, StructType, Variant} import scala.reflect.ClassTag import scala.util.hashing.MurmurHash3 @@ -16,19 +16,22 @@ object Row { * Returns a [[Row]] based on the given values. * @since 0.1.0 */ - def apply(values: Any*): Row = new Row(values.toArray) + def apply(values: Any*): Row = new Row(values.toArray, None) /** * Return a [[Row]] based on the values in the given Seq. * @since 0.1.0 */ - def fromSeq(values: Seq[Any]): Row = new Row(values.toArray) + def fromSeq(values: Seq[Any]): Row = new Row(values.toArray, None) /** * Return a [[Row]] based on the values in the given Array. * @since 0.2.0 */ - def fromArray(values: Array[Any]): Row = new Row(values) + def fromArray(values: Array[Any]): Row = new Row(values, None) + + private[snowpark] def fromSeqWithSchema(values: Seq[Any], schema: Option[StructType]): Row = + new Row(values.toArray, schema) private[snowpark] def fromMap(map: Map[String, Any]): Row = new SnowflakeObject(map) @@ -36,7 +39,7 @@ object Row { private[snowpark] class SnowflakeObject private[snowpark] ( private[snowpark] val map: Map[String, Any]) - extends Row(map.values.toArray) { + extends Row(map.values.toArray, None) { override def toString: String = convertValueToString(this) } @@ -47,7 +50,7 @@ private[snowpark] class SnowflakeObject private[snowpark] ( * @groupname utl Utility Functions * @since 0.1.0 */ -class Row protected (values: Array[Any]) extends Serializable { +class Row protected (values: Array[Any], schema: Option[StructType]) extends Serializable { /** * Converts this [[Row]] to a Seq @@ -89,7 +92,7 @@ class Row protected (values: Array[Any]) extends Serializable { * @since 0.1.0 * @group utl */ - def copy(): Row = new Row(values) + def copy(): Row = new Row(values, schema) /** * Returns a clone of this row object. Alias of [[copy]] @@ -367,6 +370,48 @@ class Row protected (values: Array[Any]) extends Serializable { getAs[Map[T, U]](index) } + /** + * Returns the index of the field with the specified name. + * + * @param name the name of the field. + * @return the index of the specified field. + * @throws UnsupportedOperationException if schema information is not available. + * @since 1.15.0 + */ + def fieldIndex(name: String): Int = { + var schema = this.schema.getOrElse( + throw new UnsupportedOperationException("Cannot get field index for row without schema")) + schema.fieldIndex(name) + } + + /** + * Returns the value for the specified field name and casts it to the desired type `T`. + * + * Example: + * + * {{{ + * val schema = + * StructType(Seq(StructField("name", StringType), StructField("value", IntegerType))) + * val data = Seq(Row("Alice", 1)) + * val df = session.createDataFrame(data, schema) + * val row = df.collect()(0) + * + * row.getAs[String]("name") // Returns "Alice" as a String + * row.getAs[Int]("value") // Returns 1 as an Int + * }}} + * + * @param fieldName the name of the field within the row. + * @tparam T the expected type of the value for the specified field name. + * @return the value for the specified field name cast to type `T`. + * @throws ClassCastException if the value of the field cannot be cast to type `T`. + * @throws IllegalArgumentException if the name of the field is not part of the row schema. + * @throws UnsupportedOperationException if the schema information is not available. + * @group getter + * @since 1.15.0 + */ + def getAs[T](fieldName: String)(implicit classTag: ClassTag[T]): T = + getAs[T](fieldIndex(fieldName)) + /** * Returns the value at the specified column index and casts it to the desired type `T`. * diff --git a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala index 92728eaf..a2281925 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala @@ -330,6 +330,7 @@ private[snowpark] class ServerConnection( val data = statement.getResultSet val schema = ServerConnection.convertResultMetaToAttribute(data.getMetaData) + val schemaOption = Some(StructType.fromAttributes(schema)) lazy val geographyOutputFormat = getParameterValue(ParameterUtils.GeographyOutputFormat) lazy val geometryOutputFormat = getParameterValue(ParameterUtils.GeometryOutputFormat) @@ -343,53 +344,55 @@ private[snowpark] class ServerConnection( private def readNext(): Unit = { _hasNext = data.next() _currentRow = if (_hasNext) { - Row.fromSeq(schema.zipWithIndex.map { - case (attribute, index) => - val resultIndex: Int = index + 1 - val resultSetExt = SnowflakeResultSetExt(data) - if (resultSetExt.isNull(resultIndex)) { - null - } else { - attribute.dataType match { - case VariantType => data.getString(resultIndex) - case _: StructuredArrayType | _: StructuredMapType | _: StructType => - resultSetExt.getObject(resultIndex) - case ArrayType(StringType) => data.getString(resultIndex) - case MapType(StringType, StringType) => data.getString(resultIndex) - case StringType => data.getString(resultIndex) - case _: DecimalType => data.getBigDecimal(resultIndex) - case DoubleType => data.getDouble(resultIndex) - case FloatType => data.getFloat(resultIndex) - case BooleanType => data.getBoolean(resultIndex) - case BinaryType => data.getBytes(resultIndex) - case DateType => data.getDate(resultIndex) - case TimeType => data.getTime(resultIndex) - case ByteType => data.getByte(resultIndex) - case IntegerType => data.getInt(resultIndex) - case LongType => data.getLong(resultIndex) - case TimestampType => data.getTimestamp(resultIndex) - case ShortType => data.getShort(resultIndex) - case GeographyType => - geographyOutputFormat match { - case "GeoJSON" => Geography.fromGeoJSON(data.getString(resultIndex)) - case _ => - throw ErrorMessage.MISC_UNSUPPORTED_GEOGRAPHY_FORMAT( - geographyOutputFormat) - } - case GeometryType => - geometryOutputFormat match { - case "GeoJSON" => Geometry.fromGeoJSON(data.getString(resultIndex)) - case _ => - throw ErrorMessage.MISC_UNSUPPORTED_GEOMETRY_FORMAT( - geometryOutputFormat) - } - case _ => - // ArrayType, StructType, MapType - throw new UnsupportedOperationException( - s"Unsupported type: ${attribute.dataType}") + Row.fromSeqWithSchema( + schema.zipWithIndex.map { + case (attribute, index) => + val resultIndex: Int = index + 1 + val resultSetExt = SnowflakeResultSetExt(data) + if (resultSetExt.isNull(resultIndex)) { + null + } else { + attribute.dataType match { + case VariantType => data.getString(resultIndex) + case _: StructuredArrayType | _: StructuredMapType | _: StructType => + resultSetExt.getObject(resultIndex) + case ArrayType(StringType) => data.getString(resultIndex) + case MapType(StringType, StringType) => data.getString(resultIndex) + case StringType => data.getString(resultIndex) + case _: DecimalType => data.getBigDecimal(resultIndex) + case DoubleType => data.getDouble(resultIndex) + case FloatType => data.getFloat(resultIndex) + case BooleanType => data.getBoolean(resultIndex) + case BinaryType => data.getBytes(resultIndex) + case DateType => data.getDate(resultIndex) + case TimeType => data.getTime(resultIndex) + case ByteType => data.getByte(resultIndex) + case IntegerType => data.getInt(resultIndex) + case LongType => data.getLong(resultIndex) + case TimestampType => data.getTimestamp(resultIndex) + case ShortType => data.getShort(resultIndex) + case GeographyType => + geographyOutputFormat match { + case "GeoJSON" => Geography.fromGeoJSON(data.getString(resultIndex)) + case _ => + throw ErrorMessage.MISC_UNSUPPORTED_GEOGRAPHY_FORMAT( + geographyOutputFormat) + } + case GeometryType => + geometryOutputFormat match { + case "GeoJSON" => Geometry.fromGeoJSON(data.getString(resultIndex)) + case _ => + throw ErrorMessage.MISC_UNSUPPORTED_GEOMETRY_FORMAT( + geometryOutputFormat) + } + case _ => + // ArrayType, StructType, MapType + throw new UnsupportedOperationException( + s"Unsupported type: ${attribute.dataType}") + } } - } - }) + }, + schemaOption) } else { // After all rows are consumed, close the statement to release resource close() diff --git a/src/main/scala/com/snowflake/snowpark/types/StructType.scala b/src/main/scala/com/snowflake/snowpark/types/StructType.scala index ff8869df..a8d77023 100644 --- a/src/main/scala/com/snowflake/snowpark/types/StructType.scala +++ b/src/main/scala/com/snowflake/snowpark/types/StructType.scala @@ -40,6 +40,10 @@ case class StructType(fields: Array[StructField] = Array()) extends DataType with Seq[StructField] { + private lazy val fieldPositions = scala.collection.immutable + .SortedMap(fields.zipWithIndex.map(tuple => (tuple._1.name -> tuple._2)): _*)( + scala.math.Ordering.comparatorToOrdering(String.CASE_INSENSITIVE_ORDER)) + /** * Returns the total number of [[StructField]] * @since 0.1.0 @@ -101,6 +105,20 @@ case class StructType(fields: Array[StructField] = Array()) nameToField(name).getOrElse( throw new IllegalArgumentException(s"$name does not exits. Names: ${names.mkString(", ")}")) + /** + * Return the index of the specified field. + * + * @param fieldName the name of the field. + * @returns the index of the field with the specified name. + * @throws IllegalArgumentException if the given field name does not exist in the schema. + * @since 1.15.0 + */ + def fieldIndex(fieldName: String): Int = { + fieldPositions.getOrElse( + fieldName, + throw new IllegalArgumentException("Field " + fieldName + " does not exist")) + } + protected[snowpark] def toAttributes: Seq[Attribute] = { /* * When user provided schema is used in a SnowflakePlan, we have to diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index f8918292..bb072675 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -602,4 +602,31 @@ public void getAsWithStructuredArray() { }, getSession()); } + + @Test + public void getAsWithFieldName() { + StructType schema = + StructType.create( + new StructField("EmpName", DataTypes.StringType), + new StructField("NumVal", DataTypes.IntegerType)); + + Row[] data = {Row.create("abcd", 10), Row.create("efgh", 20)}; + + DataFrame df = getSession().createDataFrame(data, schema); + Row row = df.collect()[0]; + + assert (row.getAs("EmpName", String.class) == row.getAs(0, String.class)); + assert (row.getAs("EmpName", String.class).charAt(3) == 'd'); + assert (row.getAs("NumVal", Integer.class) == row.getAs(1, Integer.class)); + + assert (row.getAs("EMPNAME", String.class) == row.getAs(0, String.class)); + + assertThrows( + IllegalArgumentException.class, () -> row.getAs("NonExistingColumn", Integer.class)); + + Row rowWithoutSchema = Row.create(40, "Alice"); + assertThrows( + UnsupportedOperationException.class, + () -> rowWithoutSchema.getAs("NonExistingColumn", Integer.class)); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala index df87666f..2ebafd0c 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala @@ -404,6 +404,25 @@ class RowSuite extends SNTestBase { } } + test("getAs with field name") { + val schema = + StructType(Seq(StructField("EmpName", StringType), StructField("NumVal", IntegerType))) + val df = session.createDataFrame(Seq(Row("abcd", 10), Row("efgh", 20)), schema) + val row = df.collect()(0) + + assert(row.getAs[String]("EmpName") == row.getAs[String](0)) + assert(row.getAs[String]("EmpName").charAt(3) == 'd') + assert(row.getAs[Int]("NumVal") == row.getAs[Int](1)) + + assert(row.getAs[String]("EMPNAME") == row.getAs[String](0)) + + assertThrows[IllegalArgumentException](row.getAs[String]("NonExistingColumn")) + + val rowWithoutSchema = Row(40, "Alice") + assertThrows[UnsupportedOperationException]( + rowWithoutSchema.getAs[Integer]("NonExistingColumn")); + } + test("hashCode") { val row1 = Row(1, 2, 3) val row2 = Row("str", null, 3)