diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 74cc39a8..1c2e12c7 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -2365,14 +2365,13 @@ public static Column regexp_count(Column strExpr, Column pattern) { */ public static Column regexp_replace(Column strExpr, Column pattern) { return new Column( - com.snowflake.snowpark.functions.regexp_replace( - strExpr.toScalaColumn(), pattern.toScalaColumn())); + com.snowflake.snowpark.functions.regexp_replace( + strExpr.toScalaColumn(), pattern.toScalaColumn())); } /** - * Returns the subject with the specified pattern (or all occurrences of the pattern) - * replaced by a replacement string. If no matches are found, returns the original - * subject. + * Returns the subject with the specified pattern (or all occurrences of the pattern) replaced by + * a replacement string. If no matches are found, returns the original subject. * * @param strExpr The input string * @param pattern The pattern @@ -2382,8 +2381,8 @@ public static Column regexp_replace(Column strExpr, Column pattern) { */ public static Column regexp_replace(Column strExpr, Column pattern, Column replacement) { return new Column( - com.snowflake.snowpark.functions.regexp_replace( - strExpr.toScalaColumn(), pattern.toScalaColumn(), replacement.toScalaColumn())); + com.snowflake.snowpark.functions.regexp_replace( + strExpr.toScalaColumn(), pattern.toScalaColumn(), replacement.toScalaColumn())); } /** diff --git a/src/main/scala/com/snowflake/snowpark/Column.scala b/src/main/scala/com/snowflake/snowpark/Column.scala index c0cec2aa..4867071e 100644 --- a/src/main/scala/com/snowflake/snowpark/Column.scala +++ b/src/main/scala/com/snowflake/snowpark/Column.scala @@ -718,6 +718,216 @@ case class Column private[snowpark] (private[snowpark] val expr: Expression) ext } protected def withExpr(newExpr: Expression): Column = Column(newExpr) + + /** + * Function that validates if the value of the column is + * within the list of strings from parameter. + * @since 1.10.0 + * @param strings List of strings to compare with the value. + * @return Column object. + */ + def isin(strings: String*): Column = { + in(strings) + } + + /** + * Function that validates if the values + * of the column are not null. + * @since 1.10.0 + * @return Column object. + */ + def isNotNull(): Column = + is_not_null + + /** + * Function that validates if the values + * of the column are null. + * @since 1.10.0 + * @return Column object. + */ + def isNull(): Column = + is_null + + /** + * Function that validates if the values of the column start + * with the parameter expression. + * @since 1.10.0 + * @param expr Expression to validate with the column's values. + * @return Column object. + */ + def startsWith(expr: String): Column = + com.snowflake.snowpark.functions.startswith(self, lit(expr)) + + /** + * Function that validates if the values of the column contain + * the value from the paratemer expression. + * @since 1.10.0 + * @param expr Expression to validate with the column's values. + * @return Column object. + */ + def contains(expr: String): Column = + builtin("contains")(self, expr) + + /** + * Function that replaces column's values according to the regex + * pattern and replacement value parameters. + * @since 1.10.0 + * @param pattern Regex pattern to replace. + * @param replacement Value to replace matches with. + * @return Column object. + */ + def regexp_replace(pattern: String, replacement: String): Column = + builtin("regexp_replace")(self, pattern, replacement) + + /** + * Function that gives the column an alias using a symbol. + * @since 1.10.0 + * @param symbol Symbol name. + * @return Column object. + */ + def as(symbol: Symbol): Column = + as(symbol.name) + + /** + * Function that returns True if the current expression is NaN. + * @since 1.10.0 + * @return Column object. + */ + def isNaN(): Column = equal_nan + + /** + * Function that returns the portion of the string or binary value str, + * starting from the character/byte specified by pos, with limited length. + * @since 1.10.0 + * @param pos Start position. + * @param len Length of the substring. + * @return Column object. + */ + def substr(pos: Column, len: Column): Column = + substring(self, pos, len) + + /** + * Function that returns the portion of the string or binary value str, + * starting from the character/byte specified by pos, with limited length. + * @since 1.10.0 + * @param pos Start position. + * @param len Length of the substring. + * @return Column object. + */ + def substr(pos: Int, len: Int): Column = + substring(self, lit(pos), lit(len)) + + /** + * Function that returns the result of the comparison of two columns. + * @since 1.10.0 + * @param other Column to compare. + * @return Column object. + */ + def notEqual(other: Any): Column = + not_equal(lit(other)) + + /** + * Function that returns a boolean column based on a match. + * @since 1.10.0 + * @param literal expresion to match. + * @return Column object. + */ + def like(literal: String): Column = + like(lit(literal)) + + /** + * Function that returns a boolean column based on a regex match. + * @since 1.10.0 + * @param literal Regex expresion to match. + * @return Column object. + */ + def rlike(literal: String): Column = + regexp(lit(literal)) + + /** + * Function that computes bitwise AND of this expression with another expression. + * @since 1.10.0 + * @param other Expression to match. + * @return Column object. + */ + def bitwiseAND(other: Any): Column = + bitand(lit(other)) + + /** + * Function that computes bitwise OR of this expression with another expression. + * @since 1.10.0 + * @param other Expression to match. + * @return Column object. + */ + def bitwiseOR(other: Any): Column = + bitor(lit(other)) + + /** + * Function that computes bitwise XOR of this expression with another expression.] + * @since 1.10.0 + * @param other Expression to match. + * @return Column object. + */ + def bitwiseXOR(other: Any): Column = + bitxor(lit(other)) + + /** + * Function that gets an item at position ordinal out of an array, + * or gets a value by key in a object. + * NOTE: The function returns a Variant type. You might need to add a cast + * @since 1.10.0 + * @param key Key element to get. + * @return Column object. + */ + def getItem(key: Any): Column = + builtin("get")(self, key) + + /** + * Function that gets a value by field name or key in a object. + * NOTE: The function returns a Variant type you might need to add a cast + * @since 1.10.0 + * @param fieldName Field name. + * @return Column object. + */ + def getField(fieldName: String): Column = + builtin("get")(self, fieldName) + + /** + * Function that casts the column to a different data type, + * using the canonical string representation of the type. + * The supported types are: string, boolean, byte, short, int, + * long, float, double, decimal, date, timestamp. + * NOTE: If cast is not possible returns null + * @since 1.10.0 + * @param to String representation of the type. + * @return Column object. + */ + def cast(to: String): Column = { + to match { + case "string" => c.try_cast(StringType) + case "boolean" => c.try_cast(BooleanType) + case "byte" => c.try_cast(ByteType) + case "short" => c.try_cast(ShortType) + case "int" => c.try_cast(IntegerType) + case "long" => c.try_cast(LongType) + case "float" => c.try_cast(FloatType) + case "double" => c.try_cast(DoubleType) + case "decimal" => c.try_cast(DecimalType(38, 0)) + case "date" => c.try_cast(DateType) + case "timestamp" => c.try_cast(TimestampType) + case _ => lit(null) + } + } + + /** + * Function that performs quality test that is safe for null values. + * @since 1.10.0 + * @param other Value to compare + * @return Column object. + */ + def eqNullSafe(other: Any): Column = + c.equal_null(lit(other)) + } private[snowpark] object Column { @@ -735,8 +945,10 @@ private[snowpark] object Column { * [[https://docs.snowflake.com/en/sql-reference/functions/case.html CASE]] expression. * * To construct this object for a CASE expression, call the - * [[com.snowflake.snowpark.functions.when functions.when]]. specifying a condition and the - * corresponding result for that condition. Then, call the [[when]] and [[otherwise]] methods to + * [[com.snowflake.snowpark.functions.when functions.when]]. + * specifying a condition and the + * corresponding result for that condition. Then, + * call the [[when]] and [[otherwise]] methods to * specify additional conditions and results. * * For example: @@ -759,16 +971,54 @@ class CaseExpr private[snowpark] (branches: Seq[(Expression, Expression)]) * * @since 0.2.0 */ - def when(condition: Column, value: Column): CaseExpr = - new CaseExpr(branches :+ ((condition.expr, value.expr))) + def when(condition: Column, value: Any): CaseExpr = { + value match { + case columnValue: Column => new CaseExpr(branches :+ ((condition.expr, columnValue.expr))) + case intValue: Int => new CaseExpr(branches :+ ((condition.expr, lit(intValue).expr))) + case stringValue: String => + new CaseExpr(branches :+ ((condition.expr, lit(stringValue).expr))) + case booleanValue: Boolean => + new CaseExpr(branches :+ ((condition.expr, lit(booleanValue).expr))) + case floatValue: Float => new CaseExpr(branches :+ ((condition.expr, lit(floatValue).expr))) + case doubleValue: Double => + new CaseExpr(branches :+ ((condition.expr, lit(doubleValue).expr))) + case _ => throw new IllegalArgumentException("Unsupported value type") + } + } /** * Sets the default result for this CASE expression. * * @since 0.2.0 */ - def otherwise(value: Column): Column = withExpr { - CaseWhen(branches, Option(value.expr)) + def otherwise(value: Any): Column = { + value match { + case columnValue: Column => + withExpr { + CaseWhen(branches, Option(columnValue.expr)) + } + case intValue: Int => + withExpr { + CaseWhen(branches, Option(lit(intValue).expr)) + } + case stringValue: String => + withExpr { + CaseWhen(branches, Option(lit(stringValue).expr)) + } + case booleanValue: Boolean => + withExpr { + CaseWhen(branches, Option(lit(booleanValue).expr)) + } + case floatValue: Float => + withExpr { + CaseWhen(branches, Option(lit(floatValue).expr)) + } + case doubleValue: Double => + withExpr { + CaseWhen(branches, Option(lit(doubleValue).expr)) + } + case _ => throw new IllegalArgumentException("Unsupported value type") + } } /** @@ -776,5 +1026,15 @@ class CaseExpr private[snowpark] (branches: Seq[(Expression, Expression)]) * * @since 0.2.0 */ - def `else`(value: Column): Column = otherwise(value) + def `else`(value: Any): Column = { + value match { + case columnValue: Column => otherwise(columnValue) + case intValue: Int => otherwise(intValue) + case stringValue: String => otherwise(stringValue) + case booleanValue: Boolean => otherwise(booleanValue) + case floatValue: Float => otherwise(floatValue) + case doubleValue: Double => otherwise(doubleValue) + case _ => throw new IllegalArgumentException("Unsupported value type") + } + } } diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 56307cb6..941d701a 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -534,6 +534,20 @@ class DataFrame private[snowpark] ( select(first +: remaining) } + /** + * Selects columns based on the + * expressions specified. They could either be + * column names, or calls to other functions such as conversions, + * case expressions, among others. + * @since 1.10.0 + * @param exprs Expressions to apply to select from the DataFrame. + * @return DataFrame with the selected expressions as columns. + * Unspecified columns are not included. + */ + def selectExpr(exprs: String*): DataFrame = { + select(exprs.map(e => sqlExpr(e))) + } + /** * Returns a new DataFrame with the specified Column expressions as output * (similar to SELECT in SQL). Only the Columns specified as arguments will be present in @@ -826,6 +840,24 @@ class DataFrame private[snowpark] ( */ def filter(condition: Column): DataFrame = withPlan(Filter(condition.expr, plan)) + /** + * Filters rows based on the specified conditional expression (similar to WHERE in SQL). + * + * For example: + * + * {{{ + * val dfFiltered = df.filter($"colA > 1 and colB < 100") + * }}} + * + * @group transform + * @since 1.10.0 + * @param condition Filter condition defined as a SQL expression + * @return A filtered [[DataFrame]] + */ + def filter(conditionExpr: String): DataFrame = { + df.where(sqlExpr(conditionExpr)) + } + /** * Filters rows based on the specified conditional expression (similar to WHERE in SQL). * This is equivalent to calling [[filter]]. @@ -1321,6 +1353,20 @@ class DataFrame private[snowpark] ( } } + /** + * Overload of dropDuplicates. + * Unspecified columns + * from the dataframe will be preserved, but won't be + * considered to calculate duplicates. For rows with different + * values on unspecified columns, it will return the first row. + * @param columns List of columns to group by to detect the duplicates. + * @since 1.10.0 + * @return DataFrame without duplicates on the specified columns. + */ + def dropDuplicates(columns: Seq[String]): DataFrame = { + dropDuplicates(columns: _*) + } + /** * Rotates this DataFrame by turning the unique values from one column in the input * expression into multiple columns and aggregating results where required on any @@ -2780,6 +2826,98 @@ class DataFrame private[snowpark] ( } @inline protected def withPlan(plan: LogicalPlan): DataFrame = DataFrame(session, plan) + + /** + * Function that returns the dataframe with a column renamed. + * @since 1.10.0 + * @param existingName Name of the column to rename. + * @param newName New name to give to the column. + * @return DataFrame with the column renamed. + */ + def withColumnRenamed(existingName: String, newName: String): DataFrame = { + rename(newName, col(existingName)) + } + + /** + * Transforms the DataFrame according to the function from the parameter. + * @since 1.10.0 + * @param func Function to apply to the DataFrame. + * @return DataFrame with the transformation applied. + */ + def transform(func: DataFrame => DataFrame): DataFrame = func(self) + + /** + * Returns the first row. Since this is an Option element, a `.get` + * is required to get the actual row. + * @since 1.10.0 + * @return The first row of the DataFrame. + */ + def head(): Option[Row] = first() + + /** + * Returns the first N rows. + * @since 1.10.0 + * @param n Amount of rows to return. + * @return Array with the amount of rows specified in the parameter. + */ + def head(n: Int): Array[Row] = first(n) + + /** + * Returns the first N rows. + * @since 1.10.0 + * @param n Amount of rows to return. + * @return Array with the amount of rows specified in the parameter. + */ + def take(n: Int): Array[Row] = first(n) + + /** + * Caches the result of the DataFrame and creates a new Dataframe, + * whose operations won't affect the original DataFrame. + * @since 1.10.0 + * @return New cached DataFrame. + */ + def cache(): DataFrame = cacheResult() + + /** + * Alias for sort function. Receives columns or column expressions. + * @since 1.10.0 + * @param sortExprs Column expressions to order the dataset by. + * @return Returns the dataset ordered by the specified expressions + */ + def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs: _*) + + /** + * Alias for sort function. Receives column names + * @since 1.10.0 + * @param sortCol Column name 1 + * @param sortCols Variable column names + * @return DataFrame filtered on the variable names. + */ + def orderBy(sortCol: String, sortCols: String*): DataFrame = + sort((Seq(sortCol) ++ sortCols).map(s => col(s))) + + /** + * This is a shortcut to schema.printTreeString(). Prints the schema + * of the DataFrame in a tree format. + * Includes column names, data types and if they're nullable or not. + * @since 1.10.0 + */ + def printSchema(): Unit = schema.printTreeString() + + /** + * Converts each row into a JSON object and returns a DataFrame with a single column. + * @since 1.10.0 + * @return DataFrame with 1 column whose value corresponds to a JSON object of the row. + */ + def toJSON: DataFrame = select(object_construct(col("*")).cast(StringType).as("value")) + + /** + * @since 1.10.0 + * Collects the DataFrame and converts it to a java.util.List[Row] object. + * @return A java.util.List[Row] representation of the DataFrame. + */ + def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(collect(): _*) + } /** diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala index 41cc6fa2..1dcb8033 100644 --- a/src/main/scala/com/snowflake/snowpark/Session.scala +++ b/src/main/scala/com/snowflake/snowpark/Session.scala @@ -1424,6 +1424,16 @@ object Session extends Logging { createInternal(None) } + /** + * If there is an already existing session return it or create + * a new one and return it. + * @since 1.10.0 + * @return A [[Session]] + */ + def getOrCreate: Session = { + Session.getActiveSession.getOrElse(create) + } + private[snowpark] def createInternal(conn: Option[SnowflakeConnectionV1]): Session = { conn match { case Some(_) => diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index 49f38593..0388af76 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -1817,7 +1817,6 @@ object functions { def regexp_replace(strExpr: Column, pattern: Column, replacement: Column): Column = builtin("regexp_replace")(strExpr, pattern, replacement) - /** * Removes all occurrences of a specified strExpr, * and optionally replaces them with replacement. @@ -3709,4 +3708,835 @@ object functions { Column(FunctionExpression(name, exprs, isDistinct)) } + /** + * Function to convert a string into an SQL expression. + * @since 1.10.0 + * @param s SQL Expression as text. + * @return Converted SQL Expression. + */ + def expr(s: String): Column = sqlExpr(s) + + /** + * Function to convert column name into column and order in a descending manner. + * @since 1.10.0 + * @param c Column name. + * @return Column object ordered in a descending manner. + */ + def desc(c: String): Column = col(c).desc + + /** + * Function to convert column name into column and order in an ascending manner. + * @since 1.10.0 + * @param colname Column name. + * @return Column object ordered in an ascending manner. + */ + def asc(colname: String): Column = col(colname).asc + + /** + * Wrapper for Snowflake built-in size function. Gets the size of array column. + * @since 1.10.0 + * @param c Column to get the size. + * @return Size of array column. + */ + def size(c: Column): Column = array_size(c) + + /** + * Wrapper for Snowflake built-in array function. Create array from columns. + * @since 1.10.0 + * @param c Columns to build the array. + * @return The array. + */ + def array(c: Column*): Column = array_construct(c: _*) + + /** + * Wrapper for Snowflake built-in date_format function. Converts a date into a string using the specified format. + * @since 1.10.0 + * @param c Column to convert to string. + * @param s Date format. + * @return Column object. + */ + def date_format(c: Column, s: String): Column = + builtin("to_varchar")(c.cast(TimestampType), s.replace("mm", "mi")) + + /** + * Wrapper for Snowflake built-in last function. Gets the last value of a column according to its grouping. + * Functional difference with windows, In Snowpark is needed the order by. SQL doesn't guarantee the order. + * @since 1.10.0 + * @param c Column to obtain last value. + * @return Column object. + */ + def last(c: Column): Column = + builtin("LAST_VALUE")(c) + + /** + * Formats the arguments in printf-style and returns the result as a string column. + * @since 1.10.0 + * @note this function requires the format_string UDF to be previosly created + * @param format the printf-style format + * @param arguments arguments for the formatting string + * @return formatted string + */ + def format_string(format: String, arguments: Column*): Column = { + callBuiltin("format_string", lit(format), array_construct(arguments: _*)) + } + + /** + * Locate the position of the first occurrence of substr in a string column, after position pos. + * @note The position is not zero based, but 1 based index. returns 0 if substr + * could not be found in str. This function is just leverages the SF POSITION builtin + * @since 1.10.0 + * @param substr string to search + * @param str value where string will be searched + * @param pos index for starting the search + * @return Returns the position of the first occurrence + */ + def locate(substr: String, str: Column, pos: Int = 0): Column = + if (pos == 0) lit(0) else callBuiltin("POSITION", lit(substr), str, lit(pos)) + + /** + * + * Locate the position of the first occurrence of substr in a string column, after position pos. + * @since 1.10.0 + * @note The position is not zero based, but 1 based index. returns 0 if substr + * could not be found in str. This function is just leverages the SF POSITION builtin + * @param substr string to search + * @param str value where string will be searched + * @param pos index for starting the search + * @return returns the position of the first occurrence. + */ + def locate(substr: Column, str: Column, pos: Int): Column = + if (pos == 0) lit(0) else callBuiltin("POSITION", substr, str, pos) + + /** + * Computes the logarithm of the given column in base 10. + * @since 1.10.0 + * @param expr Column to apply this mathematical operation + * @return log2 of the given column + */ + def log10(expr: Column): Column = builtin("LOG")(10, expr) + + /** + * Computes the logarithm of the given column in base 10. + * @since 1.10.0 + * @param columnName Column to apply this mathematical operation + * @return log2 of the given column + */ + def log10(columnName: String): Column = builtin("LOG")(10, col(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + * @since 1.10.0 + * @param columnName the value to use + * @return the natural logarithm of the given value plus one. + */ + def log1p(columnName: String): Column = callBuiltin("ln", lit(1) + col(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + * @since 1.10.0 + * @param col the value to use + * @return the natural logarithm of the given value plus one. + */ + def log1p(col: Column): Column = callBuiltin("ln", lit(1) + col) + + /** + * Returns expr1 if it is not NaN, or expr2 if expr1 is NaN. + * @since 1.10.0 + * @param expr1 expression when value is NaN + * @param expr2 expression when value is not NaN + */ + def nanvl(expr1: Column, expr2: Column): Column = + callBuiltin("nanvl", expr1.cast(FloatType), expr2.cast(FloatType)).cast(FloatType) + + /** + * Computes the BASE64 encoding of a column + * @since 1.10.0 + * @param col + * @return the encoded column + */ + def base64(col: Column): Column = callBuiltin("BASE64_ENCODE", col) + + /** + * Decodes a BASE64 encoded string + * @since 1.10.0 + * @param col + * @return the decoded column + */ + def unbase64(col: Column): Column = callBuiltin("BASE64_DECODE_STRING", col) + + /** + * Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window + * partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second + * quarter will get 2, the third quarter will get 3, and the last quarter will get 4. + * + * This is equivalent to the NTILE function in SQL. + * @since 1.10.0 + * @param n number of groups + * @return retyr + */ + def ntile(n: Int): Column = callBuiltin("ntile", lit(n)) + + /** + * Alias for bitshiftleft + * @since 1.10.0 + * @param c Column to modify. + * @param numBits Number of bits to shift. + * @return Column object. + */ + def shiftleft(c: Column, numBits: Int): Column = + bitshiftleft(c, lit(numBits)) + + /** + * Alias for bitshiftright. + * @since 1.10.0 + * @param c Column to modify. + * @param numBits Number of bits to shift. + * @return Column object. + */ + def shiftright(c: Column, numBits: Int): Column = + bitshiftright(c, lit(numBits)) + + /** + * Wrapper for Snowflake built-in hex_encode function. Returns the hexadecimal representation of a string. + * @since 1.10.0 + * @param c Column to encode. + * @return Encoded string. + */ + def hex(c: Column): Column = + builtin("HEX_ENCODE")(c) + + /** + * Wrapper for Snowflake built-in hex_decode_string function. Returns the string representation of a hexadecimal value. + * @param c Column to encode. + * @since 1.10.0 + * @return Encoded string. + */ + def unhex(c: Column): Column = + builtin("HEX_DECODE_STRING")(c) + + /** + * Return a call to the Snowflake RANDOM function. + * NOTE: Snowflake returns integers of 17-19 digits. + * @since 1.10.0 + * @return Random number. + */ + def randn(): Column = + builtin("RANDOM")() + + /** + * Calls to the Snowflake RANDOM function. + * NOTE: Snowflake returns integers of 17-19 digits. + * @since 1.10.0 + * @param seed Seed to use in the random function. + * @return Random number. + */ + def randn(seed: Long): Column = + builtin("RANDOM")(seed) + + /** + * This leverages JSON_EXTRACT_PATH_TEXT and improves functionality by allowing multiple columns + * in a single call, whereas JSON_EXTRACT_PATH_TEXT must be called once for every column. + * + * NOTE: + *