Skip to content

Commit

Permalink
Create an overload for com.snowflake.snowpark.functions.sum (#84)
Browse files Browse the repository at this point in the history
* Create an overload for com.snowflake.snowpark.functions.sum

* Update comments and rename param

* Fix comments
  • Loading branch information
sfc-gh-aherreraaguilar authored Mar 8, 2024
1 parent e968927 commit 55aaf09
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 4 deletions.
13 changes: 11 additions & 2 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,7 @@ public static Column stddev_pop(Column col) {
}

/**
* Returns the sum of non-NULL records in a group. You can use the DISTINCT keyword to compute the
* sum of unique non-null values. If all records inside a group are NULL, the function returns
* Returns the sum of non-NULL records in a group. If all records inside a group are NULL, the function returns
* NULL.
*
* @since 0.9.0
Expand All @@ -372,6 +371,16 @@ public static Column sum(Column col) {
return new Column(com.snowflake.snowpark.functions.sum(col.toScalaColumn()));
}

/**
* Returns the sum of non-NULL records in a group. If all records inside a group are NULL, the function returns
* NULL.
*
* @since 1.12.0
* @param colName The input column name
* @return The result column
*/
public static Column sum(String colName) { return sum(col(colName)); }

/**
* Returns the sum of non-NULL distinct records in a group. You can use the DISTINCT keyword to
* compute the sum of unique non-null values. If all records inside a group are NULL, the function
Expand Down
14 changes: 12 additions & 2 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -348,15 +348,25 @@ object functions {
def stddev_pop(e: Column): Column = builtin("stddev_pop")(e)

/**
* Returns the sum of non-NULL records in a group. You can use the DISTINCT keyword to compute
* the sum of unique non-null values. If all records inside a group are NULL,
* Returns the sum of non-NULL records in a group. If all records inside a group are NULL,
* the function returns NULL.
*
* @group agg_func
* @since 0.1.0
*/
def sum(e: Column): Column = builtin("sum")(e)

/**
* Returns the sum of non-NULL records in a group. If all records inside a group are NULL,
* the function returns NULL.
*
* @group agg_func
* @since 1.12.0
* @param colName The input column name
* @return The result column
*/
def sum(colName: String): Column = sum(col(colName))

/**
* Returns the sum of non-NULL distinct records in a group. You can use the DISTINCT keyword to
* compute the sum of unique non-null values. If all records inside a group are NULL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ public void sum() {

checkAnswer(df.groupBy(df.col("a")).agg(Functions.sum(df.col("a"))), expected, false);

checkAnswer(df.groupBy(df.col("a")).agg(Functions.sum("a")), expected, false);

Row[] expected1 = {Row.create(3, 3), Row.create(2, 2), Row.create(1, 1)};
checkAnswer(df.groupBy(df.col("a")).agg(Functions.sum_distinct(df.col("a"))), expected1, false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ trait FunctionSuite extends TestData {
Seq(Row(3, 6), Row(2, 4), Row(1, 1)),
sort = false)

checkAnswer(
duplicatedNumbers.groupBy("A").agg(sum("A")),
Seq(Row(3, 6), Row(2, 4), Row(1, 1)),
sort = false)

checkAnswer(
duplicatedNumbers.groupBy("A").agg(sum_distinct(col("A"))),
Seq(Row(3, 3), Row(2, 2), Row(1, 1)),
Expand Down

0 comments on commit 55aaf09

Please sign in to comment.