Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-802269 Added regexp_extract,signum,substring_index,collect_list #135

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3140,6 +3140,112 @@ object functions {
*/
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)

/**
* This function receives a column and extracts the groupIdx from the string
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved
* after applying the exp regex. Returns empty string when the string doesn't
* match and null if the input is null.
*
* This function applies the `case sensitive` and `extract` flags.
* It doesn't apply multiline nor .* matches newlines.
* If these flags need to be applied, use `builtin("REGEXP_SUBSTR")`
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved
* instead and apply the desired flags.
*
* Note: non-greedy tokens such as `.*?` are not supported
* @since 1.12.1
* @param colName Column to apply regex.
* @param exp Regex expression to apply.
* @param grpIdx Group to extract.
* @return Column object.
*/
def regexp_extract(
colName: Column,
exp: String,
position: Int,
Occurences: Int,
grpIdx: Int): Column = {
when(colName.is_null, lit(null))
.otherwise(
coalesce(
builtin("REGEXP_SUBSTR")(
colName,
lit(exp),
lit(position),
lit(Occurences),
lit("ce"),
lit(grpIdx)),
lit("")))
}

/**
* Returns the sign of the given column. Returns either 1 for positive,
* 0 for 0 or
* NaN, -1 for negative and null for null.
* NOTE: if string values are provided snowflake will attempts to cast.
* If it casts correctly, returns the calculation,
* if not an error will be thrown
* @since 1.12.1
* @param e Column to calculate the sign.
* @return Column object.
*/
def signum(colName: Column): Column = {
builtin("SIGN")(colName)
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Returns the sign of the given column. Returns either 1 for positive,
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved
* 0 for 0 or
* NaN, -1 for negative and null for null.
* NOTE: if string values are provided snowflake will attempts to cast.
* If it casts correctly, returns the calculation,
* if not an error will be thrown
* @since 1.12.1
* @param columnName Name of the column to calculate the sign.
* @return Column object.
*/
def signum(columnName: String): Column = {
signum(col(columnName))
}

/**
* Returns the substring from string str before count occurrences
* of the delimiter delim. If count is positive,
* everything the left of the final delimiter (counting from left)
* is returned. If count is negative, every to the right of the
* final delimiter (counting from the right) is returned.
* substring_index performs a case-sensitive match when searching for delim.
* @since 1.12.1
*/
def substring_index(str: Column, delim: String, count: Int): Column = {
when(
lit(count) < lit(0),
callBuiltin(
"substring",
lit(str),
callBuiltin("regexp_instr", sqlExpr(s"reverse(${str}, ${delim}, 1, abs(${count}), 0"))))
.otherwise(
callBuiltin(
"substring",
lit(str),
1,
callBuiltin("regexp_instr", col("str"), lit(delim), 1, lit(count), 1)))
}

/**
* Wrapper for Snowflake built-in collect_list function. Get the values of array column.
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved
* @since 1.10.0
* @param c Column to be collect.
* @return The array.
*/
def collect_list(c: Column): Column = array_agg(c)
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved

/**
* Wrapper for Snowflake built-in collect_list function. Get the values of array column.
* @since 1.10.0
* @param s Column name to be collected.
* @return The array.
*/
def collect_list(s: String): Column = array_agg(col(s))

/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
Expand Down
38 changes: 38 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2177,7 +2177,45 @@ trait FunctionSuite extends TestData {
expected,
sort = false)
}
test("regexp_extract") {
val data = Seq("A MAN A PLAN A CANAL").toDF("a")
var expected = Seq(Row("MAN"))
checkAnswer(
data.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 1, 1)),
expected,
sort = false)
expected = Seq(Row("PLAN"))
checkAnswer(
data.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 2, 1)),
expected,
sort = false)
expected = Seq(Row("CANAL"))
checkAnswer(
data.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 3, 1)),
expected,
sort = false)

expected = Seq(Row(null))
checkAnswer(
data.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 4, 1)),
expected,
sort = false)
}
test("signum") {
val df = Seq(1, -2, 0).toDF("a")
checkAnswer(df.select(signum(col("a"))), Seq(Row(1), Row(-1), Row(0)), sort = false)
}

test("collect_list") {
assert(monthlySales.select(collect_list(col("amount"))).collect()(0).get(0).toString ==
"[\n 10000,\n 400,\n 4500,\n 35000,\n 5000,\n 3000,\n 200,\n 90500,\n 6000,\n " +
"5000,\n 2500,\n 9500,\n 8000,\n 10000,\n 800,\n 4500\n]")

}
test("substring_index") {
val df = Seq("It was the best of times, it was the worst of times").toDF("a")
checkAnswer(df.select(substring_index(col("a"), "was", 1)), Seq(Row(7)), sort = false)
}
}

class EagerFunctionSuite extends FunctionSuite with EagerSession
Expand Down
Loading