From c0f9785688937c789472a3b47380fd25821412c6 Mon Sep 17 00:00:00 2001 From: Bing Li <63471091+sfc-gh-bli@users.noreply.github.com> Date: Mon, 13 May 2024 14:20:17 -0700 Subject: [PATCH] SNOW-1333980 Support Read Structured Type Values (#104) * fix type name * array v2 * map type * object * update JDBC * support structed array * map type * support map type * struct type * structure type * tmp * tmp * support date * time * complete show * decimal * reorg * add test * add test * show string * result test * handle object * show string * tmp * fix test * fix getObject * fix test * row get array * support map * add scala doc * java array * java map * java api * fix checker * remove useless code * fix java test * fix time zone * fix time zone --- .../java/com/snowflake/snowpark_java/Row.java | 75 +++++++-- .../com/snowflake/snowpark/DataFrame.scala | 26 +++- .../scala/com/snowflake/snowpark/Row.scala | 73 ++++++++- .../snowpark/internal/JavaDataTypeUtils.scala | 2 + .../snowpark/internal/ServerConnection.scala | 145 +++++++++++++++++- .../snowflake/snowpark_test/JavaRowSuite.java | 102 ++++++++++-- .../code_verification/JavaScalaAPISuite.scala | 7 +- .../snowflake/snowpark/APIInternalSuite.scala | 114 ++++++++++++++ .../snowpark_test/DataTypeSuite.scala | 137 ++++++++++++++++- 9 files changed, 633 insertions(+), 48 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java index cff3489e..0921a0d6 100644 --- a/src/main/java/com/snowflake/snowpark_java/Row.java +++ b/src/main/java/com/snowflake/snowpark_java/Row.java @@ -132,26 +132,42 @@ public int hashCode() { * @return The value of the column at the given index */ public Object get(int index) { - Object result = scalaRow.get(index); - if (result instanceof com.snowflake.snowpark.types.Variant) { - return InternalUtils.createVariant((com.snowflake.snowpark.types.Variant) result); - } else if (result instanceof com.snowflake.snowpark.types.Geography) { - return Geography.fromGeoJSON(((com.snowflake.snowpark.types.Geography) result).asGeoJSON()); - } else if (result instanceof com.snowflake.snowpark.types.Geometry) { - return Geometry.fromGeoJSON(result.toString()); - } else if (result instanceof com.snowflake.snowpark.types.Variant[]) { + return toJavaValue(scalaRow.get(index)); + } + + private static Object toJavaValue(Object value) { + if (value instanceof com.snowflake.snowpark.types.Variant) { + return InternalUtils.createVariant((com.snowflake.snowpark.types.Variant) value); + } else if (value instanceof com.snowflake.snowpark.types.Geography) { + return Geography.fromGeoJSON(((com.snowflake.snowpark.types.Geography) value).asGeoJSON()); + } else if (value instanceof com.snowflake.snowpark.types.Geometry) { + return Geometry.fromGeoJSON(value.toString()); + } else if (value instanceof com.snowflake.snowpark.types.Variant[]) { com.snowflake.snowpark.types.Variant[] scalaVariantArray = - (com.snowflake.snowpark.types.Variant[]) result; + (com.snowflake.snowpark.types.Variant[]) value; Variant[] resultArray = new Variant[scalaVariantArray.length]; for (int idx = 0; idx < scalaVariantArray.length; idx++) { resultArray[idx] = InternalUtils.createVariant(scalaVariantArray[idx]); } return resultArray; - } else if (result instanceof scala.collection.immutable.Map) { - return JavaUtils.scalaMapToJavaWithVariantConversion( - (scala.collection.immutable.Map) result); + } else if (value instanceof scala.collection.immutable.Map) { + scala.collection.immutable.Map input = (scala.collection.immutable.Map) value; + Map result = new HashMap<>(); + // key is either Long or String, no need to convert values + input.foreach(x -> result.put(x._1, toJavaValue(x._2))); + return result; + } else if (value instanceof Object[]) { + Object[] arr = (Object[]) value; + List result = new ArrayList<>(arr.length); + for (Object x : arr) { + result.add(toJavaValue(x)); + } + return result; + } else if (value instanceof com.snowflake.snowpark.Row) { + return new Row((com.snowflake.snowpark.Row) value); + } else { + return value; } - return result; } /** @@ -376,6 +392,39 @@ public Map getMapOfVariant(int index) { return result; } + /** + * Retrieves the value of the column at the given index as a list of Object. + * + * @param index The index of target column + * @return A list of Object + * @since 1.13.0 + */ + public List getList(int index) { + return (List) get(index); + } + + /** + * Retrieves the value of the column at the given index as a Java Map + * + * @param index The index of target column + * @return A Java Map + * @since 1.13.0 + */ + public Map getMap(int index) { + return (Map) get(index); + } + + /** + * Retrieves the value of the column at the given index as a Row + * + * @param index The index of target column + * @return A Row + * @since 1.13.0 + */ + public Row getObject(int index) { + return (Row) get(index); + } + /** * Generates a string value to represent the content of this row. * diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 7601f286..91f0021b 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -2369,17 +2369,33 @@ class DataFrame private[snowpark] ( lines } + def convertValueToString(value: Any): String = + value match { + case map: Map[_, _] => + map + .map { + case (key, value) => s"${convertValueToString(key)}:${convertValueToString(value)}" + } + .mkString("{", ",", "}") + case ba: Array[Byte] => s"'${DatatypeConverter.printHexBinary(ba)}'" + case bytes: Array[java.lang.Byte] => + s"'${DatatypeConverter.printHexBinary(bytes.map(_.toByte))}'" + case arr: Array[String] => + arr.mkString("[", ",", "]") + case arr: Array[_] => + arr.map(convertValueToString).mkString("[", ",", "]") + case arr: java.sql.Array => + arr.getArray().asInstanceOf[Array[_]].map(convertValueToString).mkString("[", ",", "]") + case _ => value.toString + } + val body: Seq[Seq[String]] = result.flatMap(row => { // Value may contain multiple lines val lines: Seq[Seq[String]] = row.toSeq.zipWithIndex.map { case (value, index) => val texts: Seq[String] = if (value != null) { - val str = value match { - case ba: Array[Byte] => s"'${DatatypeConverter.printHexBinary(ba)}'" - case _ => value.toString - } // if the result contains multiple lines, split result string - splitLines(str) + splitLines(convertValueToString(value)) } else { Seq("NULL") } diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala index 84419555..a1dc5aef 100644 --- a/src/main/scala/com/snowflake/snowpark/Row.scala +++ b/src/main/scala/com/snowflake/snowpark/Row.scala @@ -4,6 +4,7 @@ import java.sql.{Date, Time, Timestamp} import com.snowflake.snowpark.internal.ErrorMessage import com.snowflake.snowpark.types.{Geography, Geometry, Variant} +import scala.reflect.ClassTag import scala.util.hashing.MurmurHash3 /** @@ -28,6 +29,15 @@ object Row { * @since 0.2.0 */ def fromArray(values: Array[Any]): Row = new Row(values) + + private[snowpark] def fromMap(map: Map[String, Any]): Row = + new SnowflakeObject(map) +} + +private[snowpark] class SnowflakeObject private[snowpark] ( + private[snowpark] val map: Map[String, Any]) + extends Row(map.values.toArray) { + override def toString: String = convertValueToString(this) } /** @@ -37,7 +47,7 @@ object Row { * @groupname utl Utility Functions * @since 0.1.0 */ -class Row private (values: Array[Any]) extends Serializable { +class Row protected (values: Array[Any]) extends Serializable { /** * Converts this [[Row]] to a Seq @@ -325,6 +335,61 @@ class Row private (values: Array[Any]) extends Serializable { def getMapOfVariant(index: Int): Map[String, Variant] = new Variant(getString(index)).asMap() + /** + * Returns the Snowflake Object value at the given index as a Row value. + * + * @since 1.13.0 + * @group getter + */ + def getObject(index: Int): Row = + getAs[Row](index) + + /** + * Returns the value of the column at the given index as a Seq value. + * + * @since 1.13.0 + * @group getter + */ + def getSeq[T](index: Int): Seq[T] = { + val result = getAs[Array[_]](index) + result.map { + case x: T => x + } + } + + /** + * Returns the value of the column at the given index as a Map value. + * + * @since 1.13.0 + * @group getter + */ + def getMap[T, U](index: Int): Map[T, U] = { + getAs[Map[T, U]](index) + } + + protected def convertValueToString(value: Any): String = + value match { + case null => "null" + case map: Map[_, _] => + map + .map { + case (key, value) => s"${convertValueToString(key)}:${convertValueToString(value)}" + } + .mkString("Map(", ",", ")") + case binary: Array[Byte] => s"Binary(${binary.mkString(",")})" + case strValue: String => s""""$strValue"""" + case arr: Array[_] => + arr.map(convertValueToString).mkString("Array(", ",", ")") + case obj: SnowflakeObject => + obj.map + .map { + case (key, value) => + s"$key:${convertValueToString(value)}" + } + .mkString("Object(", ",", ")") + case other => other.toString + } + /** * Returns a string value to represent the content of this row * @since 0.1.0 @@ -332,11 +397,7 @@ class Row private (values: Array[Any]) extends Serializable { */ override def toString: String = values - .map { - case null => "null" - case binary: Array[Byte] => s"Binary(${binary.mkString(",")})" - case other => other.toString - } + .map(convertValueToString) .mkString("Row[", ",", "]") private def getAs[T](index: Int): T = get(index).asInstanceOf[T] diff --git a/src/main/scala/com/snowflake/snowpark/internal/JavaDataTypeUtils.scala b/src/main/scala/com/snowflake/snowpark/internal/JavaDataTypeUtils.scala index 79f24e2d..1ac271c9 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/JavaDataTypeUtils.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/JavaDataTypeUtils.scala @@ -47,6 +47,8 @@ object JavaDataTypeUtils { case TimestampType => JDataTypes.TimestampType case TimeType => JDataTypes.TimeType case VariantType => JDataTypes.VariantType + case st: StructType => + com.snowflake.snowpark_java.types.InternalUtils.createStructType(st) } def javaTypeToScalaType(jDataType: JDataType): DataType = diff --git a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala index 0c6769c9..92728eaf 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala @@ -1,7 +1,7 @@ package com.snowflake.snowpark.internal import java.io.{Closeable, InputStream} -import java.sql.{PreparedStatement, ResultSetMetaData, SQLException, Statement} +import java.sql.{PreparedStatement, ResultSet, ResultSetMetaData, SQLException, Statement} import java.time.LocalDateTime import com.snowflake.snowpark.{ MergeBuilder, @@ -27,16 +27,30 @@ import com.snowflake.snowpark.internal.Utils.PackageNameDelimiter import com.snowflake.snowpark.internal.analyzer.{Attribute, Query, SnowflakePlan} import net.snowflake.client.jdbc.{ FieldMetadata, + SnowflakeBaseResultSet, SnowflakeConnectString, SnowflakeConnectionV1, SnowflakeReauthenticationRequest, SnowflakeResultSet, SnowflakeResultSetMetaData, - SnowflakeStatement + SnowflakeResultSetV1, + SnowflakeStatement, + SnowflakeUtil } import com.snowflake.snowpark.types._ -import net.snowflake.client.core.QueryStatus +import net.snowflake.client.core.{ + ArrowSqlInput, + ColumnTypeHelper, + QueryStatus, + SFArrowResultSet, + SFBaseResultSet +} +import net.snowflake.client.jdbc.internal.apache.arrow.vector.util.{ + JsonStringArrayList, + JsonStringHashMap +} +import java.util import scala.collection.mutable import scala.reflect.runtime.universe.TypeTag import scala.collection.JavaConverters._ @@ -314,6 +328,7 @@ private[snowpark] class ServerConnection( statement: Statement): (CloseableIterator[Row], StructType) = withValidConnection { val data = statement.getResultSet + val schema = ServerConnection.convertResultMetaToAttribute(data.getMetaData) lazy val geographyOutputFormat = getParameterValue(ParameterUtils.GeographyOutputFormat) @@ -331,12 +346,14 @@ private[snowpark] class ServerConnection( Row.fromSeq(schema.zipWithIndex.map { case (attribute, index) => val resultIndex: Int = index + 1 - data.getObject(resultIndex) // check null value, JDBC standard - if (data.wasNull()) { + val resultSetExt = SnowflakeResultSetExt(data) + if (resultSetExt.isNull(resultIndex)) { null } else { attribute.dataType match { case VariantType => data.getString(resultIndex) + case _: StructuredArrayType | _: StructuredMapType | _: StructType => + resultSetExt.getObject(resultIndex) case ArrayType(StringType) => data.getString(resultIndex) case MapType(StringType, StringType) => data.getString(resultIndex) case StringType => data.getString(resultIndex) @@ -999,5 +1016,123 @@ private[snowpark] class ServerConnection( throw e } } +} +private[snowflake] object SnowflakeResultSetExt { + def apply(data: ResultSet): SnowflakeResultSetExt = + data match { + case sfResultSet: SnowflakeResultSetV1 => new SnowflakeResultSetExt(sfResultSet) + case other => + throw new IllegalArgumentException( + s"Unsupported JDBC ResultSet Object: ${other.getClass.getSimpleName}") + } +} +// Extends the Snowflake ResultSet to access private fields +private[snowflake] class SnowflakeResultSetExt(data: SnowflakeResultSetV1) { + // used by structured types + // the baseResultSet is not always SFArrowResultSet + lazy val baseResultSet: SFBaseResultSet = { + val sfResultSet = data.asInstanceOf[SnowflakeBaseResultSet] + val baseResultSetField = + classOf[SnowflakeBaseResultSet].getDeclaredField("sfBaseResultSet") + baseResultSetField.setAccessible(true) + baseResultSetField.get(sfResultSet).asInstanceOf[SFBaseResultSet] + } + + lazy val arrowResultSet: SFArrowResultSet = + baseResultSet.asInstanceOf[SFArrowResultSet] + + private def getObjectInternal(index: Int): Any = { + SnowflakeUtil.mapSFExceptionToSQLException(() => baseResultSet.getObject(index)) + } + + def isNull(index: Int): Boolean = + getObjectInternal(index) == null + + def getObject(index: Int): Any = { + val meta = data.getMetaData + // convert meta to field meta + val field = new FieldMetadata( + meta.getColumnName(index), + meta.getColumnTypeName(index), + meta.getColumnType(index), + true, + 0, + 0, + 0, + false, + null, + meta + .asInstanceOf[SnowflakeResultSetMetaData] + .getColumnFields(index)) + convertToSnowparkValue(getObjectInternal(index), field) + } + + private def convertToSnowparkValue(value: Any, meta: FieldMetadata): Any = { + meta.getTypeName match { + // semi structured + case "ARRAY" if meta.getFields.isEmpty => value.toString + // structured array + case "ARRAY" if meta.getFields.size() == 1 => + value + .asInstanceOf[util.ArrayList[_]] + .toArray + .map(v => convertToSnowparkValue(v, meta.getFields.get(0))) + // semi-structured + case "OBJECT" if meta.getFields.isEmpty => value.toString + // structured map, Map type has two fields, and both field names are empty + case "OBJECT" if meta.getFields.size() == 2 && meta.getFields.get(0).getName.isEmpty => + value match { + // nested structured maps are JsonStringArrayValues + case subMap: JsonStringArrayList[_] => + subMap.asScala.map { + case mapValue: JsonStringHashMap[_, _] => + convertToSnowparkValue(mapValue.get("key"), meta.getFields.get(0)) -> + convertToSnowparkValue(mapValue.get("value"), meta.getFields.get(1)) + }.toMap + case map: util.HashMap[_, _] => + map.asScala.map { + case (key, value) => + convertToSnowparkValue(key, meta.getFields.get(0)) -> + convertToSnowparkValue(value, meta.getFields.get(1)) + }.toMap + } + // object, object's field name can't be empty + case "OBJECT" => + value match { + case arrowSqlInput: ArrowSqlInput => + convertToSnowparkValue(arrowSqlInput.getInput, meta) + case map: java.util.Map[String, _] => + Row.fromMap( + map.asScala.toList + .zip(meta.getFields.asScala) + .map { + case ((key, value), metadata) => + key -> convertToSnowparkValue(value, metadata) + } + .toMap) + } + + case "NUMBER" if meta.getType == java.sql.Types.BIGINT => + value match { + case str: String => str.toLong // number key in structured map + case bd: java.math.BigDecimal => bd.toBigInteger.longValue() + } + case "DOUBLE" | "BOOLEAN" | "BINARY" | "NUMBER" => value + case "VARCHAR" | "VARIANT" => value.toString // Text to String + case "DATE" => + arrowResultSet.convertToDate(value, null) + case "TIME" => + arrowResultSet.convertToTime(value, meta.getScale) + case _ + if meta.getType == java.sql.Types.TIMESTAMP || + meta.getType == java.sql.Types.TIMESTAMP_WITH_TIMEZONE => + val columnSubType = meta.getType + val columnType = ColumnTypeHelper + .getColumnType(columnSubType, arrowResultSet.getSession) + arrowResultSet.convertToTimestamp(value, columnType, columnSubType, null, meta.getScale) + case _ => + throw new UnsupportedOperationException(s"Unsupported type: ${meta.getTypeName}") + } + } } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index bdde050e..98349a9d 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -1,9 +1,8 @@ package com.snowflake.snowpark_test; +import com.snowflake.snowpark_java.DataFrame; import com.snowflake.snowpark_java.Row; -import com.snowflake.snowpark_java.types.Geography; -import com.snowflake.snowpark_java.types.Geometry; -import com.snowflake.snowpark_java.types.Variant; +import com.snowflake.snowpark_java.types.*; import java.math.BigDecimal; import java.sql.Date; import java.sql.Time; @@ -14,7 +13,7 @@ import java.util.Map; import org.junit.Test; -public class JavaRowSuite { +public class JavaRowSuite extends TestBase { @Test public void createList() { @@ -77,7 +76,7 @@ public void getters1() { assert row.getDouble(7) == 6.6; assert row.getString(8).equals("a"); - assert row.toString().equals("Row[null,true,1,2,3,4,5.5,6.6,a]"); + assert row.toString().equals("Row[null,true,1,2,3,4,5.5,6.6,\"a\"]"); } @Test @@ -122,7 +121,7 @@ public void getter3() { assert map.get("a").asInt() == 1; assert map.get("b").asInt() == 2; - assert row.toString().equals("Row[[1,2,3],{\"a\":1,\"b\":2}]"); + assert row.toString().equals("Row[\"[1,2,3]\",\"{\"a\":1,\"b\":2}\"]"); } @Test @@ -161,10 +160,6 @@ public void testArray() { assert values[0].asString().equals("a") && values[1].asString().equals("b") && values[2].asString().equals("null"); - // get() - String[] getValues = (String[]) row.get(0); - assert getValues.length == 3; - assert getValues[0].equals("a") && getValues[1].equals("b") && getValues[2] == null; // Variant Array Variant[] variantArray = {new Variant("a"), new Variant("b"), null}; @@ -191,7 +186,7 @@ public void testEmptyArray() { // Empty String Array String[] emptyStringArray = new String[0]; row = Row.create((Object) emptyStringArray); - assert ((String[]) row.get(0)).length == 0; + assert row.getList(0).isEmpty(); assert row.getVariant(0).asArray().length == 0; // Empty Variant Array @@ -217,11 +212,6 @@ public void testSpecialArray() { assert values[0].asString().equals("null") && values[1].asString().equals("null") && values[2].asString().equals("null"); - // get() - String[] getValues = (String[]) row.get(0); - assert getValues.length == 3; - assert getValues[0] == null && getValues[1] == null && getValues[2] == null; - // Variant Array with all values to be null Variant[] variantArrayAllNull = new Variant[3]; variantArrayAllNull[0] = null; @@ -346,4 +336,84 @@ public void testSpecialMap() { && getValues2.get("b") == null && getValues2.get("c") == null; } + + @Test + public void testGetList() { + DataFrame df = getSession().sql("select [[1, 2], [3]]::ARRAY(ARRAY(NUMBER)) AS arr1"); + StructType schema = df.schema(); + assert schema.get(0).dataType() instanceof ArrayType; + assert ((ArrayType) schema.get(0).dataType()).getElementType() instanceof ArrayType; + + List list = df.collect()[0].getList(0); + assert list.size() == 2; + + List list1 = (List) list.get(0); + List list2 = (List) list.get(1); + + assert list1.size() == 2; + assert list2.size() == 1; + + assert (Long) list1.get(0) == 1; + assert (Long) list1.get(1) == 2; + assert (Long) list2.get(0) == 3; + } + + @Test + public void testGetMap() { + DataFrame df = + getSession() + .sql( + "select {'1':{'a':1,'b':2},'2':{'c':3}} :: MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map"); + StructType schema = df.schema(); + assert schema.get(0).dataType() instanceof MapType; + assert ((MapType) schema.get(0).dataType()).getKeyType() instanceof LongType; + assert ((MapType) schema.get(0).dataType()).getValueType() instanceof MapType; + + Map map = df.collect()[0].getMap(0); + Map map1 = (Map) map.get(1L); + assert map1.size() == 2; + assert (Long) map1.get("a") == 1; + assert (Long) map1.get("b") == 2; + + Map map2 = (Map) map.get(2L); + assert map2.size() == 1; + assert (Long) map2.get("c") == 3; + } + + @Test + public void testGetRow() { + DataFrame df = + getSession() + .sql( + "select {'a': {'b': {'d':10,'c': 'txt'}}} :: OBJECT(a OBJECT(b OBJECT(c VARCHAR, d NUMBER))) as obj1"); + StructType schema = df.schema(); + schema.printTreeString(); + assert schema.get(0).dataType() instanceof StructType; + assert schema.get(0).name().equals("OBJ1"); + StructType sub1 = (StructType) schema.get(0).dataType(); + assert sub1.size() == 1; + assert sub1.get(0).dataType() instanceof StructType; + assert sub1.get(0).name().equals("A"); + StructType sub2 = (StructType) sub1.get(0).dataType(); + assert sub2.size() == 1; + assert sub2.get(0).dataType() instanceof StructType; + assert sub2.get(0).name().equals("B"); + StructType sub3 = (StructType) sub2.get(0).dataType(); + assert sub3.size() == 2; + assert sub3.get(0).dataType() instanceof StringType; + assert sub3.get(0).name().equals("C"); + assert sub3.get(1).dataType() instanceof LongType; + assert sub3.get(1).name().equals("D"); + + Row[] rows1 = df.collect(); + assert rows1.length == 1; + Row row1 = rows1[0].getObject(0); + assert row1.size() == 1; + Row row2 = row1.getObject(0); + assert row2.size() == 1; + Row row3 = row2.getObject(0); + assert row3.size() == 2; + assert row3.getString(0).equals("txt"); + assert row3.getLong(1) == 10; + } } diff --git a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala index d4394054..c0f1ed3e 100644 --- a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala +++ b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala @@ -317,8 +317,11 @@ class JavaScalaAPISuite extends FunSuite { class1Only = Set(), class2Only = Set("fromArray", "fromSeq", "length" // Java API has "size" ) ++ scalaCaseClassFunctions, - class1To2NameMap = - Map("toList" -> "toSeq", "create" -> "apply", "getListOfVariant" -> "getSeqOfVariant"))) + class1To2NameMap = Map( + "toList" -> "toSeq", + "create" -> "apply", + "getListOfVariant" -> "getSeqOfVariant", + "getList" -> "getSeq"))) } // Java SaveMode is an Enum, diff --git a/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala b/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala index 29111bf9..9dadb8d8 100644 --- a/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala @@ -393,6 +393,120 @@ class APIInternalSuite extends TestData { |""".stripMargin) } + test("show structured types mix") { + val query = + // scalastyle:off + """SELECT + | NULL :: OBJECT(a VARCHAR, b NUMBER) as object1, + | 1 as NUM1, + | 'abc' as STR1, + | {'a':1,'b':2} :: MAP(VARCHAR, NUMBER) as map1, + | {'1':'a','2':'b'} :: MAP(NUMBER, VARCHAR) as map2, + | {'a': 1, 'b': [1,2,3,4]} :: OBJECT(a NUMBER, b ARRAY(NUMBER)) as object2, + | [1, 2, 3]::ARRAY(NUMBER) AS arr1, + | [1.1, 2.2, 3.3]::ARRAY(FLOAT) AS arr2, + | {'a1':{'b':2}, 'a2':{'b':3}} :: MAP(VARCHAR, OBJECT(b NUMBER)) as map1 + |""".stripMargin + // scalastyle:on + val df = session.sql(query) + // scalastyle:off + assert( + df.showString(10) == + """-------------------------------------------------------------------------------------------------------------------------------------------------- + ||"OBJECT1" |"NUM1" |"STR1" |"MAP1" |"MAP2" |"OBJECT2" |"ARR1" |"ARR2" |"MAP1" | + |-------------------------------------------------------------------------------------------------------------------------------------------------- + ||NULL |1 |abc |{b:2,a:1} |{2:b,1:a} |Object(a:1,b:Array(1,2,3,4)) |[1,2,3] |[1.1,2.2,3.3] |{a1:Object(b:2),a2:Object(b:3)} | + |-------------------------------------------------------------------------------------------------------------------------------------------------- + |""".stripMargin) + // scalastyle:on + } + + test("show object") { + val query = + // scalastyle:off + """SELECT + | {'b': 1, 'a': '22'} :: OBJECT(a VARCHAR, b NUMBER) as object1, + | {'a': 1, 'b': [1,2,3,4]} :: OBJECT(a NUMBER, b ARRAY(NUMBER)) as object2, + | {'a': 1, 'b': [1,2,3,4], 'c': {'1':'a'}} :: OBJECT(a VARCHAR, b ARRAY(NUMBER), c MAP(NUMBER, VARCHAR)) as object3, + | {'a': {'b': {'a':10,'c': 1}}} :: OBJECT(a OBJECT(b OBJECT(c NUMBER, a NUMBER))) as object4, + | [{'a':1,'b':2},{'b':3,'a':4}] :: ARRAY(OBJECT(a NUMBER, b NUMBER)) as arr1, + | {'a1':{'b':2}, 'a2':{'b':3}} :: MAP(VARCHAR, OBJECT(b NUMBER)) as map1 + |""".stripMargin + // scalastyle:on + + val df = session.sql(query) + // scalastyle:off + assert( + df.showString(10) == + """---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + ||"OBJECT1" |"OBJECT2" |"OBJECT3" |"OBJECT4" |"ARR1" |"MAP1" | + |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + ||Object(a:"22",b:1) |Object(a:1,b:Array(1,2,3,4)) |Object(a:"1",b:Array(1,2,3,4),c:Map(1:"a")) |Object(a:Object(b:Object(c:1,a:10))) |[Object(a:1,b:2),Object(a:4,b:3)] |{a1:Object(b:2),a2:Object(b:3)} | + |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + |""".stripMargin) + // scalastyle:on + + } + + test("show structured map") { + val query = + """SELECT + | {'a':1,'b':2} :: MAP(VARCHAR, NUMBER) as map1, + | {'1':'a','2':'b'} :: MAP(NUMBER, VARCHAR) as map2, + | {'1':[1,2,3],'2':[4,5,6]} :: MAP(NUMBER, ARRAY(NUMBER)) as map3, + | {'1':{'a':1,'b':2},'2':{'c':3}} :: MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map4, + | [{'a':1,'b':2},{'c':3}] :: ARRAY(MAP(VARCHAR, NUMBER)) as map5, + | {'a':1,'b':2} :: OBJECT as map0 + |""".stripMargin + + val df = session.sql(query) + // scalastyle:off + assert( + df.showString(10) == + """--------------------------------------------------------------------------------------------------------- + ||"MAP1" |"MAP2" |"MAP3" |"MAP4" |"MAP5" |"MAP0" | + |--------------------------------------------------------------------------------------------------------- + ||{b:2,a:1} |{2:b,1:a} |{2:[4,5,6],1:[1,2,3]} |{2:{c:3},1:{a:1,b:2}} |[{a:1,b:2},{c:3}] |{ | + || | | | | | "a": 1, | + || | | | | | "b": 2 | + || | | | | |} | + |--------------------------------------------------------------------------------------------------------- + |""".stripMargin) + // scalastyle:on + } + + test("show structured array") { + val query = + """SELECT + | [1, 2, 3]::ARRAY(NUMBER) AS arr1, + | [1.1, 2.2, 3.3]::ARRAY(FLOAT) AS arr2, + | [true, false]::ARRAY(BOOLEAN) AS arr3, + | ['a', 'b']::ARRAY(VARCHAR) AS arr4, + | [parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr5, + | [TO_BINARY('SNOW', 'utf-8')]::ARRAY(BINARY) AS arr6, + | [TO_DATE('2013-05-17')]::ARRAY(DATE) AS arr7, + | [[1,2]]::ARRAY(ARRAY) AS arr9, + | [OBJECT_CONSTRUCT('name', 1)]::ARRAY(OBJECT) AS arr10, + | [[1, 2], [3, 4]]::ARRAY(ARRAY(NUMBER)) AS arr11, + | [1.234::DECIMAL(13, 5)]::ARRAY(DECIMAL(13,5)) as arr12, + | [time '10:03:56']::ARRAY(TIME) as arr21 + |""".stripMargin + val df = session.sql(query) + // scalastyle:off + assert( + df.showString(10) == + """--------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + ||"ARR1" |"ARR2" |"ARR3" |"ARR4" |"ARR5" |"ARR6" |"ARR7" |"ARR9" |"ARR10" |"ARR11" |"ARR12" |"ARR21" | + |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + ||[1,2,3] |[1.1,2.2,3.3] |[true,false] |[a,b] |[1970-12-25 11:06:40.0] |['534E4F57'] |[2013-05-17] |[[ |[{ |[[1,2],[3,4]] |[1.23400] |[10:03:56] | + || | | | | | | | 1, | "name": 1 | | | | + || | | | | | | | 2 |}] | | | | + || | | | | | | |]] | | | | | + |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + |""".stripMargin) + // scalastyle:on + } + // dataframe test("withColumn function uses * instead of full column name list") { import session.implicits._ diff --git a/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala index b3803ac3..ad7cb57a 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala @@ -4,6 +4,9 @@ import com.snowflake.snowpark.{Row, SNTestBase, TestUtils} import com.snowflake.snowpark.types._ import com.snowflake.snowpark.functions._ +import java.sql.{Date, Time, Timestamp} +import java.util.TimeZone + // Test DataTypes out of com.snowflake.snowpark package. class DataTypeSuite extends SNTestBase { test("IntegralType") { @@ -181,6 +184,138 @@ class DataTypeSuite extends SNTestBase { |""".stripMargin) } + test("read Structured Array") { + val oldTimeZone = TimeZone.getDefault + try { + // Need to set default time zone because the expected result has timestamp data + TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")) + val query = + """SELECT + | [1, 2, 3]::ARRAY(NUMBER) AS arr1, + | [1.1, 2.2, 3.3]::ARRAY(FLOAT) AS arr2, + | [true, false]::ARRAY(BOOLEAN) AS arr3, + | ['a', 'b']::ARRAY(VARCHAR) AS arr4, + | [parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr5, + | [TO_BINARY('SNOW', 'utf-8')]::ARRAY(BINARY) AS arr6, + | [TO_DATE('2013-05-17')]::ARRAY(DATE) AS arr7, + | [[1,2]]::ARRAY(ARRAY) AS arr9, + | [OBJECT_CONSTRUCT('name', 1)]::ARRAY(OBJECT) AS arr10, + | [[1, 2], [3, 4]]::ARRAY(ARRAY(NUMBER)) AS arr11, + | [1.234::DECIMAL(13, 5)]::ARRAY(DECIMAL(13,5)) as arr12, + | [time '10:03:56']::ARRAY(TIME) as arr21 + |""".stripMargin + val df = session.sql(query) + assert(df.collect().head.getSeq[Double](1).isInstanceOf[Seq[Double]]) + checkAnswer( + df, + Row( + Array(1L, 2L, 3L), + Array(1.1, 2.2, 3.3), + Array(true, false), + Array("a", "b"), + Array(new Timestamp(31000000000L)), + Array(Array(83.toByte, 78.toByte, 79.toByte, 87.toByte)), + Array(Date.valueOf("2013-05-17")), + Array("[\n 1,\n 2\n]"), + Array("{\n \"name\": 1\n}"), + Array(Array(1L, 2L), Array(3L, 4L)), + Array(java.math.BigDecimal.valueOf(1.234)), + Array(Time.valueOf("10:03:56")))) + } finally { + TimeZone.setDefault(oldTimeZone) + } + } + + test("read Structured Map") { + val query = + """SELECT + | {'a':1,'b':2} :: MAP(VARCHAR, NUMBER) as map1, + | {'1':'a','2':'b'} :: MAP(NUMBER, VARCHAR) as map2, + | {'1':[1,2,3],'2':[4,5,6]} :: MAP(NUMBER, ARRAY(NUMBER)) as map3, + | {'1':{'a':1,'b':2},'2':{'c':3}} :: MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map4, + | [{'a':1,'b':2},{'c':3}] :: ARRAY(MAP(VARCHAR, NUMBER)) as map5, + | {'a':1,'b':2} :: OBJECT as map0 + |""".stripMargin + val df = session.sql(query) + checkAnswer( + df, + Row( + Map("b" -> 2, "a" -> 1), + Map(2 -> "b", 1 -> "a"), + Map(2 -> Array(4L, 5L, 6L), 1 -> Array(1L, 2L, 3L)), + Map(2 -> Map("c" -> 3), 1 -> Map("a" -> 1, "b" -> 2)), + Array(Map("a" -> 1, "b" -> 2), Map("c" -> 3)), + "{\n \"a\": 1,\n \"b\": 2\n}")) + } + + test("read object") { + val query = + // scalastyle:off + """SELECT + | {'b': 1, 'a': '22'} :: OBJECT(a VARCHAR, b NUMBER) as object1, + | {'a': 1, 'b': [1,2,3,4], 'c': true} :: OBJECT(a NUMBER, b ARRAY(NUMBER), c BOOLEAN) as object2, + | {'a': 1, 'b': [1,2,3,4], 'c': {'1':'a'}} :: OBJECT(a NUMBER, b ARRAY(NUMBER), c MAP(NUMBER, VARCHAR)) as object3, + | {'a': {'b': {'a':10,'c': 1}}} :: OBJECT(a OBJECT(b OBJECT(c NUMBER, a NUMBER))) as object4, + | [{'a':1,'b':2},{'b':3,'a':4}] :: ARRAY(OBJECT(a NUMBER, b NUMBER)) as arr1, + | {'a1':{'b':2}, 'a2':{'b':3}} :: MAP(VARCHAR, OBJECT(b NUMBER)) as map1 + |""".stripMargin + // scalastyle:on + + val df = session.sql(query) + val result = df.collect() + assert(result.length == 1) + val row = result.head + assert(row.getObject(0).length == 2) + assert(row.getObject(0).getString(0) == "22") + assert(row.getObject(0).getLong(1) == 1L) + + assert(row.getObject(1).length == 3) + assert(row.getObject(1).getLong(0) == 1L) + assert(row.getObject(1).getSeq(1).length == 4) + val arr1 = row.getObject(1).getSeq[Long](1) + assert(arr1.isInstanceOf[Seq[Long]]) + assert(arr1.sameElements(Array(1L, 2L, 3L, 4L))) + assert(row.getObject(1).getBoolean(2)) + + assert(row.getObject(2).length == 3) + assert(row.getObject(2).getLong(0) == 1L) + assert(row.getObject(2).getSeq(1).length == 4) + val arr2 = row.getObject(2).getSeq[Long](1) + assert(arr2.isInstanceOf[Seq[Long]]) + assert(arr2.sameElements(Array(1L, 2L, 3L, 4L))) + val map1 = row.getObject(2).getMap[Long, String](2) + assert(map1 == Map(1L -> "a")) + + assert(row.getObject(3).length == 1) + val row1 = row.getObject(3).getObject(0) + assert(row1.length == 1) + val row2 = row1.getObject(0) + assert(row2.length == 2) + assert(row2.getInt(0) == 1) + assert(row2.getInt(1) == 10) + + assert(row.getSeq[Row](4).length == 2) + val arr3 = row.getSeq[Row](4) + val row3 = arr3.head + val row4 = arr3(1) + assert(row3.length == 2) + assert(row3.getInt(0) == 1) + assert(row3.getInt(1) == 2) + assert(row4.length == 2) + assert(row4.getInt(0) == 4) + assert(row4.getInt(1) == 3) + + assert(row.getMap[String, Row](5).size == 2) + val map2 = row.getMap[String, Row](5) + val row5 = map2("a1") + val row6 = map2("a2") + assert(row5.length == 1) + assert(row5.getInt(0) == 2) + assert(row6.length == 1) + assert(row6.getInt(0) == 3) + + } + test("ArrayType v2") { val query = """SELECT | [1, 2, 3]::ARRAY(NUMBER) AS arr1, @@ -194,7 +329,7 @@ class DataTypeSuite extends SNTestBase { | [[1,2]]::ARRAY(ARRAY) AS arr9, | [OBJECT_CONSTRUCT('name', 1)]::ARRAY(OBJECT) AS arr10, | [[1, 2], [3, 4]]::ARRAY(ARRAY(NUMBER)) AS arr11, - | [1, 2, 3] AS arr0;""".stripMargin + | [1, 2, 3] AS arr0""".stripMargin val df = session.sql(query) assert( TestUtils.treeString(df.schema, 0) ==