Skip to content

Commit

Permalink
Create an overload for com.snowflake.snowpark.functions.sum
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aherreraaguilar committed Feb 22, 2024
1 parent 15d19c7 commit 71c5824
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,17 @@ public static Column sum(Column col) {
return new Column(com.snowflake.snowpark.functions.sum(col.toScalaColumn()));
}

/**
* Returns the sum of non-NULL records in a group. You can use the DISTINCT keyword to compute the
* sum of unique non-null values. If all records inside a group are NULL, the function returns
* NULL.
*
* @since 0.9.0
* @param str The input string
* @return The result column
*/
public static Column sum(String str) { return sum(col(str)); }

/**
* Returns the sum of non-NULL distinct records in a group. You can use the DISTINCT keyword to
* compute the sum of unique non-null values. If all records inside a group are NULL, the function
Expand Down
12 changes: 12 additions & 0 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,18 @@ object functions {
*/
def sum(e: Column): Column = builtin("sum")(e)

/**
* Returns the sum of non-NULL records in a group. You can use the DISTINCT keyword to compute
* the sum of unique non-null values. If all records inside a group are NULL,
* the function returns NULL.
*
* @group agg_func
* @since 0.1.0
* @param e The input string
* @return The result column
*/
def sum(e: String): Column = sum(col(e))

/**
* Returns the sum of non-NULL distinct records in a group. You can use the DISTINCT keyword to
* compute the sum of unique non-null values. If all records inside a group are NULL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ public void sum() {

checkAnswer(df.groupBy(df.col("a")).agg(Functions.sum(df.col("a"))), expected, false);

checkAnswer(df.groupBy(df.col("a")).agg(Functions.sum("a")), expected, false);

Row[] expected1 = {Row.create(3, 3), Row.create(2, 2), Row.create(1, 1)};
checkAnswer(df.groupBy(df.col("a")).agg(Functions.sum_distinct(df.col("a"))), expected1, false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ trait FunctionSuite extends TestData {
Seq(Row(3, 6), Row(2, 4), Row(1, 1)),
sort = false)

checkAnswer(
duplicatedNumbers.groupBy("A").agg(sum("A")),
Seq(Row(3, 6), Row(2, 4), Row(1, 1)),
sort = false)

checkAnswer(
duplicatedNumbers.groupBy("A").agg(sum_distinct(col("A"))),
Seq(Row(3, 3), Row(2, 2), Row(1, 1)),
Expand Down

0 comments on commit 71c5824

Please sign in to comment.