Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extending the Snowpark Scala APIs #56

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -2365,14 +2365,13 @@ public static Column regexp_count(Column strExpr, Column pattern) {
*/
public static Column regexp_replace(Column strExpr, Column pattern) {
return new Column(
com.snowflake.snowpark.functions.regexp_replace(
strExpr.toScalaColumn(), pattern.toScalaColumn()));
com.snowflake.snowpark.functions.regexp_replace(
strExpr.toScalaColumn(), pattern.toScalaColumn()));
}

/**
* Returns the subject with the specified pattern (or all occurrences of the pattern)
* replaced by a replacement string. If no matches are found, returns the original
* subject.
* Returns the subject with the specified pattern (or all occurrences of the pattern) replaced by
* a replacement string. If no matches are found, returns the original subject.
*
* @param strExpr The input string
* @param pattern The pattern
Expand All @@ -2382,8 +2381,8 @@ public static Column regexp_replace(Column strExpr, Column pattern) {
*/
public static Column regexp_replace(Column strExpr, Column pattern, Column replacement) {
return new Column(
com.snowflake.snowpark.functions.regexp_replace(
strExpr.toScalaColumn(), pattern.toScalaColumn(), replacement.toScalaColumn()));
com.snowflake.snowpark.functions.regexp_replace(
strExpr.toScalaColumn(), pattern.toScalaColumn(), replacement.toScalaColumn()));
}

/**
Expand Down
274 changes: 267 additions & 7 deletions src/main/scala/com/snowflake/snowpark/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,216 @@ case class Column private[snowpark] (private[snowpark] val expr: Expression) ext
}

protected def withExpr(newExpr: Expression): Column = Column(newExpr)

/**
* Function that validates if the value of the column is
* within the list of strings from parameter.
* @since 1.10.0
* @param strings List of strings to compare with the value.
* @return Column object.
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved
*/
def isin(strings: String*): Column = {
in(strings)
}

/**
* Function that validates if the values
* of the column are not null.
* @since 1.10.0
* @return Column object.
*/
def isNotNull(): Column =
is_not_null

/**
* Function that validates if the values
* of the column are null.
* @since 1.10.0
* @return Column object.
*/
def isNull(): Column =
is_null

/**
* Function that validates if the values of the column start
* with the parameter expression.
* @since 1.10.0
* @param expr Expression to validate with the column's values.
* @return Column object.
*/
def startsWith(expr: String): Column =
com.snowflake.snowpark.functions.startswith(self, lit(expr))

/**
* Function that validates if the values of the column contain
* the value from the paratemer expression.
* @since 1.10.0
* @param expr Expression to validate with the column's values.
* @return Column object.
*/
def contains(expr: String): Column =
builtin("contains")(self, expr)

/**
* Function that replaces column's values according to the regex
* pattern and replacement value parameters.
* @since 1.10.0
* @param pattern Regex pattern to replace.
* @param replacement Value to replace matches with.
* @return Column object.
*/
def regexp_replace(pattern: String, replacement: String): Column =
builtin("regexp_replace")(self, pattern, replacement)

/**
* Function that gives the column an alias using a symbol.
* @since 1.10.0
* @param symbol Symbol name.
* @return Column object.
*/
def as(symbol: Symbol): Column =
as(symbol.name)

/**
* Function that returns True if the current expression is NaN.
* @since 1.10.0
* @return Column object.
*/
def isNaN(): Column = equal_nan

/**
* Function that returns the portion of the string or binary value str,
* starting from the character/byte specified by pos, with limited length.
* @since 1.10.0
* @param pos Start position.
* @param len Length of the substring.
* @return Column object.
*/
def substr(pos: Column, len: Column): Column =
substring(self, pos, len)

/**
* Function that returns the portion of the string or binary value str,
* starting from the character/byte specified by pos, with limited length.
* @since 1.10.0
* @param pos Start position.
* @param len Length of the substring.
* @return Column object.
*/
def substr(pos: Int, len: Int): Column =
substring(self, lit(pos), lit(len))

/**
* Function that returns the result of the comparison of two columns.
* @since 1.10.0
* @param other Column to compare.
* @return Column object.
*/
def notEqual(other: Any): Column =
not_equal(lit(other))

/**
* Function that returns a boolean column based on a match.
* @since 1.10.0
* @param literal expresion to match.
* @return Column object.
*/
def like(literal: String): Column =
like(lit(literal))

/**
* Function that returns a boolean column based on a regex match.
* @since 1.10.0
* @param literal Regex expresion to match.
* @return Column object.
*/
def rlike(literal: String): Column =
regexp(lit(literal))

/**
* Function that computes bitwise AND of this expression with another expression.
* @since 1.10.0
* @param other Expression to match.
* @return Column object.
*/
def bitwiseAND(other: Any): Column =
bitand(lit(other))

/**
* Function that computes bitwise OR of this expression with another expression.
* @since 1.10.0
* @param other Expression to match.
* @return Column object.
*/
def bitwiseOR(other: Any): Column =
bitor(lit(other))

/**
* Function that computes bitwise XOR of this expression with another expression.]
* @since 1.10.0
* @param other Expression to match.
* @return Column object.
*/
def bitwiseXOR(other: Any): Column =
bitxor(lit(other))

/**
* Function that gets an item at position ordinal out of an array,
* or gets a value by key in a object.
* NOTE: The function returns a Variant type. You might need to add a cast
* @since 1.10.0
* @param key Key element to get.
* @return Column object.
*/
def getItem(key: Any): Column =
builtin("get")(self, key)

/**
* Function that gets a value by field name or key in a object.
* NOTE: The function returns a Variant type you might need to add a cast
* @since 1.10.0
* @param fieldName Field name.
* @return Column object.
*/
def getField(fieldName: String): Column =
builtin("get")(self, fieldName)

/**
* Function that casts the column to a different data type,
* using the canonical string representation of the type.
* The supported types are: string, boolean, byte, short, int,
* long, float, double, decimal, date, timestamp.
* NOTE: If cast is not possible returns null
* @since 1.10.0
* @param to String representation of the type.
* @return Column object.
*/
def cast(to: String): Column = {
to match {
case "string" => c.try_cast(StringType)
case "boolean" => c.try_cast(BooleanType)
case "byte" => c.try_cast(ByteType)
case "short" => c.try_cast(ShortType)
case "int" => c.try_cast(IntegerType)
case "long" => c.try_cast(LongType)
case "float" => c.try_cast(FloatType)
case "double" => c.try_cast(DoubleType)
case "decimal" => c.try_cast(DecimalType(38, 0))
case "date" => c.try_cast(DateType)
case "timestamp" => c.try_cast(TimestampType)
case _ => lit(null)
}
}

/**
* Function that performs quality test that is safe for null values.
* @since 1.10.0
* @param other Value to compare
* @return Column object.
*/
def eqNullSafe(other: Any): Column =
c.equal_null(lit(other))

}

private[snowpark] object Column {
Expand All @@ -735,8 +945,10 @@ private[snowpark] object Column {
* [[https://docs.snowflake.com/en/sql-reference/functions/case.html CASE]] expression.
*
* To construct this object for a CASE expression, call the
* [[com.snowflake.snowpark.functions.when functions.when]]. specifying a condition and the
* corresponding result for that condition. Then, call the [[when]] and [[otherwise]] methods to
* [[com.snowflake.snowpark.functions.when functions.when]].
* specifying a condition and the
* corresponding result for that condition. Then,
* call the [[when]] and [[otherwise]] methods to
* specify additional conditions and results.
*
* For example:
Expand All @@ -759,22 +971,70 @@ class CaseExpr private[snowpark] (branches: Seq[(Expression, Expression)])
*
* @since 0.2.0
*/
def when(condition: Column, value: Column): CaseExpr =
new CaseExpr(branches :+ ((condition.expr, value.expr)))
def when(condition: Column, value: Any): CaseExpr = {
value match {
case columnValue: Column => new CaseExpr(branches :+ ((condition.expr, columnValue.expr)))
case intValue: Int => new CaseExpr(branches :+ ((condition.expr, lit(intValue).expr)))
case stringValue: String =>
new CaseExpr(branches :+ ((condition.expr, lit(stringValue).expr)))
case booleanValue: Boolean =>
new CaseExpr(branches :+ ((condition.expr, lit(booleanValue).expr)))
case floatValue: Float => new CaseExpr(branches :+ ((condition.expr, lit(floatValue).expr)))
case doubleValue: Double =>
new CaseExpr(branches :+ ((condition.expr, lit(doubleValue).expr)))
case _ => throw new IllegalArgumentException("Unsupported value type")
}
}

/**
* Sets the default result for this CASE expression.
*
* @since 0.2.0
*/
def otherwise(value: Column): Column = withExpr {
CaseWhen(branches, Option(value.expr))
def otherwise(value: Any): Column = {
value match {
case columnValue: Column =>
withExpr {
CaseWhen(branches, Option(columnValue.expr))
}
case intValue: Int =>
withExpr {
CaseWhen(branches, Option(lit(intValue).expr))
}
case stringValue: String =>
withExpr {
CaseWhen(branches, Option(lit(stringValue).expr))
}
case booleanValue: Boolean =>
withExpr {
CaseWhen(branches, Option(lit(booleanValue).expr))
}
case floatValue: Float =>
withExpr {
CaseWhen(branches, Option(lit(floatValue).expr))
}
case doubleValue: Double =>
withExpr {
CaseWhen(branches, Option(lit(doubleValue).expr))
}
case _ => throw new IllegalArgumentException("Unsupported value type")
}
}

/**
* Sets the default result for this CASE expression. Alias for [[otherwise]].
*
* @since 0.2.0
*/
def `else`(value: Column): Column = otherwise(value)
def `else`(value: Any): Column = {
value match {
case columnValue: Column => otherwise(columnValue)
case intValue: Int => otherwise(intValue)
case stringValue: String => otherwise(stringValue)
case booleanValue: Boolean => otherwise(booleanValue)
case floatValue: Float => otherwise(floatValue)
case doubleValue: Double => otherwise(doubleValue)
case _ => throw new IllegalArgumentException("Unsupported value type")
}
}
}
Loading
Loading