From cbc92397f38aa3482bd54731c656fc60de095633 Mon Sep 17 00:00:00 2001 From: Luis Fallas Avendano Date: Thu, 26 Sep 2024 17:22:11 -0600 Subject: [PATCH] Add missing Java methods --- .../java/com/snowflake/snowpark_java/Row.java | 12 ++++++++++++ .../snowpark_java/types/StructType.java | 12 ++++++++++++ src/main/scala/com/snowflake/snowpark/Row.scala | 6 +++--- .../snowflake/snowpark/types/StructType.scala | 2 +- .../snowflake/snowpark_test/JavaRowSuite.java | 17 +++++++++++++++++ .../com/snowflake/snowpark_test/RowSuite.scala | 8 ++++++++ 6 files changed, 53 insertions(+), 4 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java index 87cfa088..07403cb1 100644 --- a/src/main/java/com/snowflake/snowpark_java/Row.java +++ b/src/main/java/com/snowflake/snowpark_java/Row.java @@ -425,6 +425,18 @@ public Row getObject(int index) { return (Row) get(index); } + /** + * Returns the index of the field with the specified name. + * + * @param fieldName the name of the field. + * @return the index of the specified field. + * @throws UnsupportedOperationException if schema information is not available. + * @since 1.15.0 + */ + public int fieldIndex(String fieldName) { + return this.scalaRow.fieldIndex(fieldName); + } + /** * Returns the value at the specified column index and casts it to the desired type {@code T}. * diff --git a/src/main/java/com/snowflake/snowpark_java/types/StructType.java b/src/main/java/com/snowflake/snowpark_java/types/StructType.java index 63998e3d..2b738129 100644 --- a/src/main/java/com/snowflake/snowpark_java/types/StructType.java +++ b/src/main/java/com/snowflake/snowpark_java/types/StructType.java @@ -58,6 +58,18 @@ private static com.snowflake.snowpark.types.StructField[] toScalaFieldsArray( return result; } + /** + * Return the index of the specified field. + * + * @param fieldName the name of the field. + * @return the index of the field with the specified name. + * @throws IllegalArgumentException if the given field name does not exist in the schema. + * @since 1.15.0 + */ + public int fieldIndex(String fieldName) { + return this.scalaStructType.fieldIndex(fieldName); + } + /** * Retrieves the names of StructField. * diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala index c3810f19..8fb0df20 100644 --- a/src/main/scala/com/snowflake/snowpark/Row.scala +++ b/src/main/scala/com/snowflake/snowpark/Row.scala @@ -373,15 +373,15 @@ class Row protected (values: Array[Any], schema: Option[StructType]) extends Ser /** * Returns the index of the field with the specified name. * - * @param name the name of the field. + * @param fieldName the name of the field. * @return the index of the specified field. * @throws UnsupportedOperationException if schema information is not available. * @since 1.15.0 */ - def fieldIndex(name: String): Int = { + def fieldIndex(fieldName: String): Int = { var schema = this.schema.getOrElse( throw new UnsupportedOperationException("Cannot get field index for row without schema")) - schema.fieldIndex(name) + schema.fieldIndex(fieldName) } /** diff --git a/src/main/scala/com/snowflake/snowpark/types/StructType.scala b/src/main/scala/com/snowflake/snowpark/types/StructType.scala index a8d77023..d985a98d 100644 --- a/src/main/scala/com/snowflake/snowpark/types/StructType.scala +++ b/src/main/scala/com/snowflake/snowpark/types/StructType.scala @@ -109,7 +109,7 @@ case class StructType(fields: Array[StructField] = Array()) * Return the index of the specified field. * * @param fieldName the name of the field. - * @returns the index of the field with the specified name. + * @return the index of the field with the specified name. * @throws IllegalArgumentException if the given field name does not exist in the schema. * @since 1.15.0 */ diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index bb072675..88294511 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -629,4 +629,21 @@ public void getAsWithFieldName() { UnsupportedOperationException.class, () -> rowWithoutSchema.getAs("NonExistingColumn", Integer.class)); } + + @Test + public void fieldIndex() { + StructType schema = + StructType.create( + new StructField("EmpName", DataTypes.StringType), + new StructField("NumVal", DataTypes.IntegerType)); + + Row[] data = {Row.create("abcd", 10), Row.create("efgh", 20)}; + + DataFrame df = getSession().createDataFrame(data, schema); + Row row = df.collect()[0]; + + assert (row.fieldIndex("EmpName") == 0); + assert (row.fieldIndex("NumVal") == 1); + assertThrows(IllegalArgumentException.class, () -> row.fieldIndex("NonExistingColumn")); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala index 2ebafd0c..c92b4c7e 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala @@ -423,6 +423,14 @@ class RowSuite extends SNTestBase { rowWithoutSchema.getAs[Integer]("NonExistingColumn")); } + test("fieldIndex") { + val schema = + StructType(Seq(StructField("EmpName", StringType), StructField("NumVal", IntegerType))) + assert(schema.fieldIndex("EmpName") == 0) + assert(schema.fieldIndex("NumVal") == 1) + assertThrows[IllegalArgumentException](schema.fieldIndex("NonExistingColumn")) + } + test("hashCode") { val row1 = Row(1, 2, 3) val row2 = Row("str", null, 3)