Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-802269 Added regexp_extract,signum,substring_index,collect_list #135

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -3880,6 +3880,116 @@ public static Column listagg(Column col) {
return new Column(com.snowflake.snowpark.functions.listagg(col.toScalaColumn()));
}

/**
* Signature - snowflake.snowpark.functions.regexp_extract (value: Union[Column, str], regexp:
* Union[Column, str], idx: int) Column Extract a specific group matched by a regex, from the
* specified string column. If the regex did not match, or the specified group did not match, an
* empty string is returned.
* Example:
* <pre>{@code
* from snowflake.snowpark.functions import regexp_extract
* df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]], ["id", "age"])
* df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show()
* ---------
* |"RES" |
* ---------
* |20 |
* |40 |
* ---------
* }</pr>
*
* @since 1.12.1
* @return Column object.
*/
public static Column regexp_extract(
Column col, String exp, Integer position, Integer Occurences, Integer grpIdx) {
return new Column(
com.snowflake.snowpark.functions.regexp_extract(
col.toScalaColumn(), exp, position, Occurences, grpIdx));
}

/**
* Returns the sign of its argument:
*
* <p>- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0.
*
* <p>Args: col: The column to evaluate its sign
* Example::
* * <pre>{@code df =
* session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
* df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* }</pr>
*
* @since 1.12.1
* @param e Column to calculate the sign.
* @return Column object.
*/
public static Column signum(Column col) {
return new Column(com.snowflake.snowpark.functions.signum(col.toScalaColumn()));
}

/**
* Returns the sign of its argument:
*
* <p>- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0.
*
* <p>Args: col: The column to evaluate its sign
* Example::
* <pre>{@code df =
* session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
* df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* }</pr>
*
* @since 1.12.1
* @param e Column to calculate the sign.
* @return Column object.
*/
public static Column sign(Column col) {
return new Column(com.snowflake.snowpark.functions.sign(col.toScalaColumn()));
}

/**
* Returns the substring from string str before count occurrences of the delimiter delim. If count
* is positive, everything the left of the final delimiter (counting from left) is returned. If
* count is negative, every to the right of the final delimiter (counting from the right) is
* returned. substring_index performs a case-sensitive match when searching for delim.
*
* @since 1.12.1
*/
public static Column substring_index(Column col, String delim, Integer count) {
return new Column(
com.snowflake.snowpark.functions.substring_index(col.toScalaColumn(), delim, count));
}

/**
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is
* returned.
* <p> Example::
* <pre>{@code
* df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* df.select(array_agg("a", True).alias("result")).show()
* "RESULT" [ 1, 2, 3 ]
* }</pre>
*
* @since 1.10.0
* @param c Column to be collect.
* @return The array.
*/
public static Column collect_list(Column col) {
return new Column(com.snowflake.snowpark.functions.collect_list(col.toScalaColumn()));
}
/**
* Calls a user-defined function (UDF) by name.
*
Expand Down
186 changes: 186 additions & 0 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3140,6 +3140,192 @@ object functions {
*/
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)

/**
* Signature - snowflake.snowpark.functions.regexp_extract
* (value: Union[Column, str], regexp: Union[Column, str], idx: int)
* Column
* Extract a specific group matched by a regex, from the specified string
* column. If the regex did not match, or the specified group did not match,
* an empty string is returned.
* <pr>Example:
* from snowflake.snowpark.functions import regexp_extract
* df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]],
* ["id", "age"])
* df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show()
*</pr>
*<pr>
* ---------
* |"RES" |
* ---------
* |20 |
* |40 |
* ---------
*</pr>
* Note: non-greedy tokens such as are not supported
sfc-gh-sjayabalan marked this conversation as resolved.
Show resolved Hide resolved
* @since 1.12.1
* @return Column object.
*/
def regexp_extract(
colName: Column,
exp: String,
position: Int,
Occurences: Int,
grpIdx: Int): Column = {
when(colName.is_null, lit(null))
.otherwise(
coalesce(
builtin("REGEX_SUBSTR")(
colName,
lit(exp),
lit(position),
lit(Occurences),
lit("ce"),
lit(grpIdx)),
lit("")))
}

/**
* Returns the sign of its argument:
*
* - -1 if the argument is negative.
* - 1 if it is positive.
* - 0 if it is 0.
*
* Args:
* col: The column to evaluate its sign
*<pr>
* Example::
* >>> df = session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"])
* >>> df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* </pr>
* @since 1.12.1
* @param e Column to calculate the sign.
* @return Column object.
*/
def sign(colName: Column): Column = {
builtin("SIGN")(colName)
}

/**
* Returns the sign of its argument:
*
* - -1 if the argument is negative.
* - 1 if it is positive.
* - 0 if it is 0.
*
* Args:
* col: The column to evaluate its sign
*<pr>
* Example::
* >>> df = session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"])
* >>> df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show()
* ----------------------------------
* |"A_SIGN" |"B_SIGN" |"C_SIGN" |
* ----------------------------------
* |-1 |1 |0 |
* ----------------------------------
* </pr>
* @since 1.12.1
* @param e Column to calculate the sign.
* @return Column object.
*/
def signum(colName: Column): Column = {
builtin("SIGN")(colName)
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Returns the sign of the given column. Returns either 1 for positive,
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved
* 0 for 0 or
* NaN, -1 for negative and null for null.
* NOTE: if string values are provided snowflake will attempts to cast.
* If it casts correctly, returns the calculation,
* if not an error will be thrown
* @since 1.12.1
* @param columnName Name of the column to calculate the sign.
* @return Column object.
*/
def signum(columnName: String): Column = {
signum(col(columnName))
}

/**
* Returns the substring from string str before count occurrences
* of the delimiter delim. If count is positive,
* everything the left of the final delimiter (counting from left)
* is returned. If count is negative, every to the right of the
* final delimiter (counting from the right) is returned.
* substring_index performs a case-sensitive match when searching for delim.
* @since 1.12.1
*/
def substring_index(str: Column, delim: String, count: Int): Column = {
when(
lit(count) < lit(0),
callBuiltin(
"substring",
lit(str),
callBuiltin("regexp_instr", sqlExpr(s"reverse(${str}, ${delim}, 1, abs(${count}), 0"))))
.otherwise(
callBuiltin(
"substring",
lit(str),
1,
callBuiltin("regexp_instr", col("str"), lit(delim), 1, lit(count), 1)))
}

/**
*
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty
* ARRAY is returned.
*<pr>
* Example::
* >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* >>> df.select(array_agg("a", True).alias("result")).show()
* ------------
* |"RESULT" |
* ------------
* |[ |
* | 1, |
* | 2, |
* | 3 |
* |] |
* ------------
* </pr>
* @since 1.10.0
* @param c Column to be collect.
* @return The array.
*/
def collect_list(c: Column): Column = array_agg(c)
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved

/**
*
* Returns the input values, pivoted into an ARRAY. If the input is empty, an empty
* ARRAY is returned.
*
* Example::
* >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* >>> df.select(array_agg("a", True).alias("result")).show()
* ------------
* |"RESULT" |
* ------------
* |[ |
* | 1, |
* | 2, |
* | 3 |
* |] |
* ------------
* @since 1.10.0
* @param s Column name to be collected.
* @return The array.
*/
def collect_list(s: String): Column = array_agg(col(s))

/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
Expand Down
50 changes: 50 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -2764,4 +2764,54 @@ public void any_value() {
assert result.length == 1;
assert result[0].getInt(0) == 1 || result[0].getInt(0) == 2 || result[0].getInt(0) == 3;
}

@Test
public void regexp_extract() {
DataFrame df = getSession().sql("select * from values('A MAN A PLAN A CANAL') as T(a)");
Row[] expected = {Row.create("MAN")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 1, 1)), expected, false);
Row[] expected2 = {Row.create("PLAN")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 2, 1)), expected2, false);
Row[] expected3 = {Row.create("CANAL")};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 2, 1)), expected3, false);
Row[] expected4 = {Row.create(null)};
checkAnswer(
df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 3, 1)), expected4, false);
}

@Test
public void signum() {
DataFrame df = getSession().sql("select * from values(1,-2,0) as T(a)");
checkAnswer(df.select(Functions.signum(df.col("a"))), new Row[] {Row.create(1, -1, 0)}, false);
}

@Test
public void sign() {
DataFrame df = getSession().sql("select * from values(1,-2,0) as T(a)");
checkAnswer(df.select(Functions.sign(df.col("a"))), new Row[] {Row.create(1, -1, 0)}, false);
}

@Test
public void collect_list() {
DataFrame df = getSession().sql("select * from values(10000,400,450) as T(a)");
checkAnswer(
df.select(Functions.collect_list(df.col("a"))),
new Row[] {Row.create("[\n \"10000,400,450\"\n]")},
false);
}

@Test
public void substring_index() {
DataFrame df =
getSession()
.sql(
"select * from values ('It was the best of times,it was the worst of times') as T(a)");
checkAnswer(
df.select(Functions.substring_index(df.col("a"), "was", 1)),
new Row[] {Row.create(7)},
false);
}
}
Loading