Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-802269 Add ordering and size function for scala and java modules #133

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static com.snowflake.snowpark.internal.OpenTelemetry.javaUDF;

import com.snowflake.snowpark.functions;
import com.snowflake.snowpark.internal.JavaUtils;
import com.snowflake.snowpark_java.types.DataType;
import com.snowflake.snowpark_java.udf.*;
Expand Down Expand Up @@ -3880,6 +3881,82 @@ public static Column listagg(Column col) {
return new Column(com.snowflake.snowpark.functions.listagg(col.toScalaColumn()));
}

/**
 * Builds a sort expression that orders the named column from the largest value to the smallest.
 *
 * <p>Example: sort a DataFrame by one column, descending
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(a)");
 * df.sort(Functions.desc("a")).show();
 * -------
 * |"A"  |
 * -------
 * |3    |
 * |2    |
 * |1    |
 * -------
 * }</pre>
 *
 * @since 1.14.0
 * @param name Name of the column to order by
 * @return A Column expression that carries a descending sort order.
 */
public static Column desc(String name) {
  // Delegate to the Scala implementation, then wrap the result in the Java Column type.
  final com.snowflake.snowpark.Column descending = functions.desc(name);
  return new Column(descending);
}

/**
 * Builds a sort expression that orders the named column from the smallest value to the largest.
 *
 * <p>Example: sort a DataFrame by one column, ascending
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
 * df.sort(Functions.asc("a")).show();
 * -------
 * |"A"  |
 * -------
 * |1    |
 * |2    |
 * |3    |
 * -------
 * }</pre>
 *
 * @since 1.14.0
 * @param name Name of the column to order by
 * @return A Column expression that carries an ascending sort order.
 */
public static Column asc(String name) {
  // Delegate to the Scala implementation, then wrap the result in the Java Column type.
  final com.snowflake.snowpark.Column ascending = functions.asc(name);
  return new Column(ascending);
}

/**
 * Returns the number of elements in the input ARRAY.
 *
 * <p>When the column holds a VARIANT value that itself contains an ARRAY, the size of that ARRAY
 * is returned. When the value is not an ARRAY, the result is NULL.
 *
 * <p>Example: compute the size of the array stored in a column
 *
 * <pre>{@code
 * DataFrame df = getSession().sql("select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)");
 * df.select(Functions.size(Functions.col("arr"))).show();
 * -------------------------
 * |"ARRAY_SIZE(""ARR"")"  |
 * -------------------------
 * |3                      |
 * -------------------------
 * }</pre>
 *
 * @since 1.14.0
 * @param col The input column holding the ARRAY
 * @return A Column expression evaluating to the size of the input ARRAY.
 */
public static Column size(Column col) {
  // size is an alias: the work is done by array_size in this same class.
  final Column arrayLength = array_size(col);
  return arrayLength;
}

/**
* Calls a user-defined function (UDF) by name.
*
Expand Down
67 changes: 67 additions & 0 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3140,6 +3140,73 @@ object functions {
*/
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)

/**
 * Builds a sort expression that orders the named column from the largest value to the smallest.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id")
 *   df.sort(desc("id")).show()
 *
 *   --------
 *   |"ID"  |
 *   --------
 *   |3     |
 *   |2     |
 *   |1     |
 *   --------
 * }}}
 *
 * @since 1.14.0
 * @param colName Name of the column to order by.
 * @return Column expression carrying a descending sort order.
 */
def desc(colName: String): Column = {
  // Resolve the name to a Column first, then attach the descending order to it.
  val target = col(colName)
  target.desc
}

/**
 * Builds a sort expression that orders the named column from the smallest value to the largest.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(3, 2, 1)).toDF("id")
 *   df.sort(asc("id")).show()
 *
 *   --------
 *   |"ID"  |
 *   --------
 *   |1     |
 *   |2     |
 *   |3     |
 *   --------
 * }}}
 *
 * @since 1.14.0
 * @param colName Name of the column to order by.
 * @return Column expression carrying an ascending sort order.
 */
def asc(colName: String): Column = {
  // Resolve the name to a Column first, then attach the ascending order to it.
  val target = col(colName)
  target.asc
}

/**
 * Returns the number of elements in the input ARRAY.
 *
 * When the column holds a VARIANT value that itself contains an ARRAY, the size of that ARRAY
 * is returned. When the value is not an ARRAY, the result is NULL.
 *
 * Example:
 * {{{
 *   val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id")
 *   df.select(size(col("id"))).show()
 *
 *   ------------------------
 *   |"ARRAY_SIZE(""ID"")"  |
 *   ------------------------
 *   |3                     |
 *   ------------------------
 * }}}
 *
 * @since 1.14.0
 * @param c Column whose array size is requested.
 * @return Column expression evaluating to the size of the array.
 */
def size(c: Column): Column = {
  // size is an alias: the work is done by array_size.
  val elementCount = array_size(c)
  elementCount
}

/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
Expand Down
26 changes: 26 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -2764,4 +2764,30 @@ public void any_value() {
assert result.length == 1;
assert result[0].getInt(0) == 1 || result[0].getInt(0) == 2 || result[0].getInt(0) == 3;
}

@Test
public void test_asc() {
  // The input rows are deliberately out of order so the check depends on the sort itself.
  DataFrame unsorted = getSession().sql("select * from values(3),(1),(2) as t(a)");
  DataFrame ascending = unsorted.sort(Functions.asc("a"));
  Row[] expectedRows = {Row.create(1), Row.create(2), Row.create(3)};

  // NOTE(review): the trailing false presumably mirrors `sort = false` in the Scala suite,
  // i.e. rows are compared in the order returned — confirm against checkAnswer.
  checkAnswer(ascending, expectedRows, false);
}

@Test
public void test_desc() {
  // The input rows are deliberately out of order so the check depends on the sort itself.
  DataFrame unsorted = getSession().sql("select * from values(2),(1),(3) as t(a)");
  DataFrame descending = unsorted.sort(Functions.desc("a"));
  Row[] expectedRows = {Row.create(3), Row.create(2), Row.create(1)};

  // NOTE(review): the trailing false presumably mirrors `sort = false` in the Scala suite,
  // i.e. rows are compared in the order returned — confirm against checkAnswer.
  checkAnswer(descending, expectedRows, false);
}

@Test
public void test_size() {
  // A single row whose only column is a three-element array built from a, b and c.
  String query = "select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)";
  DataFrame arrays = getSession().sql(query);
  DataFrame sizes = arrays.select(Functions.size(Functions.col("arr")));
  Row[] expectedRows = {Row.create(3)};

  checkAnswer(sizes, expectedRows, false);
}
}
29 changes: 29 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2178,6 +2178,35 @@ trait FunctionSuite extends TestData {
sort = false)
}

test("desc column order") {
  // Numeric column: 1, 2, 3 must come back as 3, 2, 1.
  val numbers = Seq(1, 2, 3).toDF("data")
  val numbersDescending = Seq(3, 2, 1).toDF("data")
  checkAnswer(numbers.sort(desc("data")), numbersDescending, sort = false)

  // String column: "a", "b", "c" must come back as "c", "b", "a".
  val letters = Seq("a", "b", "c").toDF("dataStr")
  val lettersDescending = Seq("c", "b", "a").toDF("dataStr")
  checkAnswer(letters.sort(desc("dataStr")), lettersDescending, sort = false)
}

test("asc column order") {
  // Numeric column: 3, 2, 1 must come back as 1, 2, 3.
  val numbers = Seq(3, 2, 1).toDF("data")
  val numbersAscending = Seq(1, 2, 3).toDF("data")
  checkAnswer(numbers.sort(asc("data")), numbersAscending, sort = false)

  // String column: "c", "b", "a" must come back as "a", "b", "c".
  val letters = Seq("c", "b", "a").toDF("dataStr")
  val lettersAscending = Seq("a", "b", "c").toDF("dataStr")
  checkAnswer(letters.sort(asc("dataStr")), lettersAscending, sort = false)
}

test("column array size") {
  // One row holding a three-element array; size must report 3 for it.
  val arrays = Seq(Array(1, 2, 3)).toDF("size")
  val expectedSize = Seq(3).toDF("size")
  val actualSize = arrays.select(size(col("size")))
  checkAnswer(actualSize, expectedSize, sort = false)
}

}

// Concrete suite: runs the tests declared in the FunctionSuite trait with EagerSession mixed in.
// NOTE(review): EagerSession is defined elsewhere; presumably it supplies an eagerly-analyzing
// session for these tests — confirm against its definition.
class EagerFunctionSuite extends FunctionSuite with EagerSession
Expand Down
Loading