diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 74cc39a8..1c2e12c7 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -2365,14 +2365,13 @@ public static Column regexp_count(Column strExpr, Column pattern) { */ public static Column regexp_replace(Column strExpr, Column pattern) { return new Column( - com.snowflake.snowpark.functions.regexp_replace( - strExpr.toScalaColumn(), pattern.toScalaColumn())); + com.snowflake.snowpark.functions.regexp_replace( + strExpr.toScalaColumn(), pattern.toScalaColumn())); } /** - * Returns the subject with the specified pattern (or all occurrences of the pattern) - * replaced by a replacement string. If no matches are found, returns the original - * subject. + * Returns the subject with the specified pattern (or all occurrences of the pattern) replaced by + * a replacement string. If no matches are found, returns the original subject. * * @param strExpr The input string * @param pattern The pattern @@ -2382,8 +2381,8 @@ public static Column regexp_replace(Column strExpr, Column pattern) { */ public static Column regexp_replace(Column strExpr, Column pattern, Column replacement) { return new Column( - com.snowflake.snowpark.functions.regexp_replace( - strExpr.toScalaColumn(), pattern.toScalaColumn(), replacement.toScalaColumn())); + com.snowflake.snowpark.functions.regexp_replace( + strExpr.toScalaColumn(), pattern.toScalaColumn(), replacement.toScalaColumn())); } /** diff --git a/src/main/scala/com/snowflake/snowpark/Column.scala b/src/main/scala/com/snowflake/snowpark/Column.scala index c0cec2aa..4867071e 100644 --- a/src/main/scala/com/snowflake/snowpark/Column.scala +++ b/src/main/scala/com/snowflake/snowpark/Column.scala @@ -718,6 +718,216 @@ case class Column private[snowpark] (private[snowpark] val expr: Expression) ext } protected def withExpr(newExpr: Expression): Column = Column(newExpr) + + /** + * Function that validates if the value of the column is + * within the list of strings from parameter. + * @since 1.10.0 + * @param strings List of strings to compare with the value. + * @return Column object. + */ + def isin(strings: String*): Column = { + in(strings) + } + + /** + * Function that validates if the values + * of the column are not null. + * @since 1.10.0 + * @return Column object. + */ + def isNotNull(): Column = + is_not_null + + /** + * Function that validates if the values + * of the column are null. + * @since 1.10.0 + * @return Column object. + */ + def isNull(): Column = + is_null + + /** + * Function that validates if the values of the column start + * with the parameter expression. + * @since 1.10.0 + * @param expr Expression to validate with the column's values. + * @return Column object. + */ + def startsWith(expr: String): Column = + com.snowflake.snowpark.functions.startswith(self, lit(expr)) + + /** + * Function that validates if the values of the column contain + * the value from the paratemer expression. + * @since 1.10.0 + * @param expr Expression to validate with the column's values. + * @return Column object. + */ + def contains(expr: String): Column = + builtin("contains")(self, expr) + + /** + * Function that replaces column's values according to the regex + * pattern and replacement value parameters. + * @since 1.10.0 + * @param pattern Regex pattern to replace. + * @param replacement Value to replace matches with. + * @return Column object. + */ + def regexp_replace(pattern: String, replacement: String): Column = + builtin("regexp_replace")(self, pattern, replacement) + + /** + * Function that gives the column an alias using a symbol. + * @since 1.10.0 + * @param symbol Symbol name. + * @return Column object. + */ + def as(symbol: Symbol): Column = + as(symbol.name) + + /** + * Function that returns True if the current expression is NaN. + * @since 1.10.0 + * @return Column object. + */ + def isNaN(): Column = equal_nan + + /** + * Function that returns the portion of the string or binary value str, + * starting from the character/byte specified by pos, with limited length. + * @since 1.10.0 + * @param pos Start position. + * @param len Length of the substring. + * @return Column object. + */ + def substr(pos: Column, len: Column): Column = + substring(self, pos, len) + + /** + * Function that returns the portion of the string or binary value str, + * starting from the character/byte specified by pos, with limited length. + * @since 1.10.0 + * @param pos Start position. + * @param len Length of the substring. + * @return Column object. + */ + def substr(pos: Int, len: Int): Column = + substring(self, lit(pos), lit(len)) + + /** + * Function that returns the result of the comparison of two columns. + * @since 1.10.0 + * @param other Column to compare. + * @return Column object. + */ + def notEqual(other: Any): Column = + not_equal(lit(other)) + + /** + * Function that returns a boolean column based on a match. + * @since 1.10.0 + * @param literal expresion to match. + * @return Column object. + */ + def like(literal: String): Column = + like(lit(literal)) + + /** + * Function that returns a boolean column based on a regex match. + * @since 1.10.0 + * @param literal Regex expresion to match. + * @return Column object. + */ + def rlike(literal: String): Column = + regexp(lit(literal)) + + /** + * Function that computes bitwise AND of this expression with another expression. + * @since 1.10.0 + * @param other Expression to match. + * @return Column object. + */ + def bitwiseAND(other: Any): Column = + bitand(lit(other)) + + /** + * Function that computes bitwise OR of this expression with another expression. + * @since 1.10.0 + * @param other Expression to match. + * @return Column object. + */ + def bitwiseOR(other: Any): Column = + bitor(lit(other)) + + /** + * Function that computes bitwise XOR of this expression with another expression.] + * @since 1.10.0 + * @param other Expression to match. + * @return Column object. + */ + def bitwiseXOR(other: Any): Column = + bitxor(lit(other)) + + /** + * Function that gets an item at position ordinal out of an array, + * or gets a value by key in a object. + * NOTE: The function returns a Variant type. You might need to add a cast + * @since 1.10.0 + * @param key Key element to get. + * @return Column object. + */ + def getItem(key: Any): Column = + builtin("get")(self, key) + + /** + * Function that gets a value by field name or key in a object. + * NOTE: The function returns a Variant type you might need to add a cast + * @since 1.10.0 + * @param fieldName Field name. + * @return Column object. + */ + def getField(fieldName: String): Column = + builtin("get")(self, fieldName) + + /** + * Function that casts the column to a different data type, + * using the canonical string representation of the type. + * The supported types are: string, boolean, byte, short, int, + * long, float, double, decimal, date, timestamp. + * NOTE: If cast is not possible returns null + * @since 1.10.0 + * @param to String representation of the type. + * @return Column object. + */ + def cast(to: String): Column = { + to match { + case "string" => c.try_cast(StringType) + case "boolean" => c.try_cast(BooleanType) + case "byte" => c.try_cast(ByteType) + case "short" => c.try_cast(ShortType) + case "int" => c.try_cast(IntegerType) + case "long" => c.try_cast(LongType) + case "float" => c.try_cast(FloatType) + case "double" => c.try_cast(DoubleType) + case "decimal" => c.try_cast(DecimalType(38, 0)) + case "date" => c.try_cast(DateType) + case "timestamp" => c.try_cast(TimestampType) + case _ => lit(null) + } + } + + /** + * Function that performs quality test that is safe for null values. + * @since 1.10.0 + * @param other Value to compare + * @return Column object. + */ + def eqNullSafe(other: Any): Column = + c.equal_null(lit(other)) + } private[snowpark] object Column { @@ -735,8 +945,10 @@ private[snowpark] object Column { * [[https://docs.snowflake.com/en/sql-reference/functions/case.html CASE]] expression. * * To construct this object for a CASE expression, call the - * [[com.snowflake.snowpark.functions.when functions.when]]. specifying a condition and the - * corresponding result for that condition. Then, call the [[when]] and [[otherwise]] methods to + * [[com.snowflake.snowpark.functions.when functions.when]]. + * specifying a condition and the + * corresponding result for that condition. Then, + * call the [[when]] and [[otherwise]] methods to * specify additional conditions and results. * * For example: @@ -759,16 +971,54 @@ class CaseExpr private[snowpark] (branches: Seq[(Expression, Expression)]) * * @since 0.2.0 */ - def when(condition: Column, value: Column): CaseExpr = - new CaseExpr(branches :+ ((condition.expr, value.expr))) + def when(condition: Column, value: Any): CaseExpr = { + value match { + case columnValue: Column => new CaseExpr(branches :+ ((condition.expr, columnValue.expr))) + case intValue: Int => new CaseExpr(branches :+ ((condition.expr, lit(intValue).expr))) + case stringValue: String => + new CaseExpr(branches :+ ((condition.expr, lit(stringValue).expr))) + case booleanValue: Boolean => + new CaseExpr(branches :+ ((condition.expr, lit(booleanValue).expr))) + case floatValue: Float => new CaseExpr(branches :+ ((condition.expr, lit(floatValue).expr))) + case doubleValue: Double => + new CaseExpr(branches :+ ((condition.expr, lit(doubleValue).expr))) + case _ => throw new IllegalArgumentException("Unsupported value type") + } + } /** * Sets the default result for this CASE expression. * * @since 0.2.0 */ - def otherwise(value: Column): Column = withExpr { - CaseWhen(branches, Option(value.expr)) + def otherwise(value: Any): Column = { + value match { + case columnValue: Column => + withExpr { + CaseWhen(branches, Option(columnValue.expr)) + } + case intValue: Int => + withExpr { + CaseWhen(branches, Option(lit(intValue).expr)) + } + case stringValue: String => + withExpr { + CaseWhen(branches, Option(lit(stringValue).expr)) + } + case booleanValue: Boolean => + withExpr { + CaseWhen(branches, Option(lit(booleanValue).expr)) + } + case floatValue: Float => + withExpr { + CaseWhen(branches, Option(lit(floatValue).expr)) + } + case doubleValue: Double => + withExpr { + CaseWhen(branches, Option(lit(doubleValue).expr)) + } + case _ => throw new IllegalArgumentException("Unsupported value type") + } } /** @@ -776,5 +1026,15 @@ class CaseExpr private[snowpark] (branches: Seq[(Expression, Expression)]) * * @since 0.2.0 */ - def `else`(value: Column): Column = otherwise(value) + def `else`(value: Any): Column = { + value match { + case columnValue: Column => otherwise(columnValue) + case intValue: Int => otherwise(intValue) + case stringValue: String => otherwise(stringValue) + case booleanValue: Boolean => otherwise(booleanValue) + case floatValue: Float => otherwise(floatValue) + case doubleValue: Double => otherwise(doubleValue) + case _ => throw new IllegalArgumentException("Unsupported value type") + } + } } diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 56307cb6..941d701a 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -534,6 +534,20 @@ class DataFrame private[snowpark] ( select(first +: remaining) } + /** + * Selects columns based on the + * expressions specified. They could either be + * column names, or calls to other functions such as conversions, + * case expressions, among others. + * @since 1.10.0 + * @param exprs Expressions to apply to select from the DataFrame. + * @return DataFrame with the selected expressions as columns. + * Unspecified columns are not included. + */ + def selectExpr(exprs: String*): DataFrame = { + select(exprs.map(e => sqlExpr(e))) + } + /** * Returns a new DataFrame with the specified Column expressions as output * (similar to SELECT in SQL). Only the Columns specified as arguments will be present in @@ -826,6 +840,24 @@ class DataFrame private[snowpark] ( */ def filter(condition: Column): DataFrame = withPlan(Filter(condition.expr, plan)) + /** + * Filters rows based on the specified conditional expression (similar to WHERE in SQL). + * + * For example: + * + * {{{ + * val dfFiltered = df.filter($"colA > 1 and colB < 100") + * }}} + * + * @group transform + * @since 1.10.0 + * @param condition Filter condition defined as a SQL expression + * @return A filtered [[DataFrame]] + */ + def filter(conditionExpr: String): DataFrame = { + df.where(sqlExpr(conditionExpr)) + } + /** * Filters rows based on the specified conditional expression (similar to WHERE in SQL). * This is equivalent to calling [[filter]]. @@ -1321,6 +1353,20 @@ class DataFrame private[snowpark] ( } } + /** + * Overload of dropDuplicates. + * Unspecified columns + * from the dataframe will be preserved, but won't be + * considered to calculate duplicates. For rows with different + * values on unspecified columns, it will return the first row. + * @param columns List of columns to group by to detect the duplicates. + * @since 1.10.0 + * @return DataFrame without duplicates on the specified columns. + */ + def dropDuplicates(columns: Seq[String]): DataFrame = { + dropDuplicates(columns: _*) + } + /** * Rotates this DataFrame by turning the unique values from one column in the input * expression into multiple columns and aggregating results where required on any @@ -2780,6 +2826,98 @@ class DataFrame private[snowpark] ( } @inline protected def withPlan(plan: LogicalPlan): DataFrame = DataFrame(session, plan) + + /** + * Function that returns the dataframe with a column renamed. + * @since 1.10.0 + * @param existingName Name of the column to rename. + * @param newName New name to give to the column. + * @return DataFrame with the column renamed. + */ + def withColumnRenamed(existingName: String, newName: String): DataFrame = { + rename(newName, col(existingName)) + } + + /** + * Transforms the DataFrame according to the function from the parameter. + * @since 1.10.0 + * @param func Function to apply to the DataFrame. + * @return DataFrame with the transformation applied. + */ + def transform(func: DataFrame => DataFrame): DataFrame = func(self) + + /** + * Returns the first row. Since this is an Option element, a `.get` + * is required to get the actual row. + * @since 1.10.0 + * @return The first row of the DataFrame. + */ + def head(): Option[Row] = first() + + /** + * Returns the first N rows. + * @since 1.10.0 + * @param n Amount of rows to return. + * @return Array with the amount of rows specified in the parameter. + */ + def head(n: Int): Array[Row] = first(n) + + /** + * Returns the first N rows. + * @since 1.10.0 + * @param n Amount of rows to return. + * @return Array with the amount of rows specified in the parameter. + */ + def take(n: Int): Array[Row] = first(n) + + /** + * Caches the result of the DataFrame and creates a new Dataframe, + * whose operations won't affect the original DataFrame. + * @since 1.10.0 + * @return New cached DataFrame. + */ + def cache(): DataFrame = cacheResult() + + /** + * Alias for sort function. Receives columns or column expressions. + * @since 1.10.0 + * @param sortExprs Column expressions to order the dataset by. + * @return Returns the dataset ordered by the specified expressions + */ + def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs: _*) + + /** + * Alias for sort function. Receives column names + * @since 1.10.0 + * @param sortCol Column name 1 + * @param sortCols Variable column names + * @return DataFrame filtered on the variable names. + */ + def orderBy(sortCol: String, sortCols: String*): DataFrame = + sort((Seq(sortCol) ++ sortCols).map(s => col(s))) + + /** + * This is a shortcut to schema.printTreeString(). Prints the schema + * of the DataFrame in a tree format. + * Includes column names, data types and if they're nullable or not. + * @since 1.10.0 + */ + def printSchema(): Unit = schema.printTreeString() + + /** + * Converts each row into a JSON object and returns a DataFrame with a single column. + * @since 1.10.0 + * @return DataFrame with 1 column whose value corresponds to a JSON object of the row. + */ + def toJSON: DataFrame = select(object_construct(col("*")).cast(StringType).as("value")) + + /** + * @since 1.10.0 + * Collects the DataFrame and converts it to a java.util.List[Row] object. + * @return A java.util.List[Row] representation of the DataFrame. + */ + def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(collect(): _*) + } /** diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala index 41cc6fa2..1dcb8033 100644 --- a/src/main/scala/com/snowflake/snowpark/Session.scala +++ b/src/main/scala/com/snowflake/snowpark/Session.scala @@ -1424,6 +1424,16 @@ object Session extends Logging { createInternal(None) } + /** + * If there is an already existing session return it or create + * a new one and return it. + * @since 1.10.0 + * @return A [[Session]] + */ + def getOrCreate: Session = { + Session.getActiveSession.getOrElse(create) + } + private[snowpark] def createInternal(conn: Option[SnowflakeConnectionV1]): Session = { conn match { case Some(_) => diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index 49f38593..0388af76 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -1817,7 +1817,6 @@ object functions { def regexp_replace(strExpr: Column, pattern: Column, replacement: Column): Column = builtin("regexp_replace")(strExpr, pattern, replacement) - /** * Removes all occurrences of a specified strExpr, * and optionally replaces them with replacement. @@ -3709,4 +3708,835 @@ object functions { Column(FunctionExpression(name, exprs, isDistinct)) } + /** + * Function to convert a string into an SQL expression. + * @since 1.10.0 + * @param s SQL Expression as text. + * @return Converted SQL Expression. + */ + def expr(s: String): Column = sqlExpr(s) + + /** + * Function to convert column name into column and order in a descending manner. + * @since 1.10.0 + * @param c Column name. + * @return Column object ordered in a descending manner. + */ + def desc(c: String): Column = col(c).desc + + /** + * Function to convert column name into column and order in an ascending manner. + * @since 1.10.0 + * @param colname Column name. + * @return Column object ordered in an ascending manner. + */ + def asc(colname: String): Column = col(colname).asc + + /** + * Wrapper for Snowflake built-in size function. Gets the size of array column. + * @since 1.10.0 + * @param c Column to get the size. + * @return Size of array column. + */ + def size(c: Column): Column = array_size(c) + + /** + * Wrapper for Snowflake built-in array function. Create array from columns. + * @since 1.10.0 + * @param c Columns to build the array. + * @return The array. + */ + def array(c: Column*): Column = array_construct(c: _*) + + /** + * Wrapper for Snowflake built-in date_format function. Converts a date into a string using the specified format. + * @since 1.10.0 + * @param c Column to convert to string. + * @param s Date format. + * @return Column object. + */ + def date_format(c: Column, s: String): Column = + builtin("to_varchar")(c.cast(TimestampType), s.replace("mm", "mi")) + + /** + * Wrapper for Snowflake built-in last function. Gets the last value of a column according to its grouping. + * Functional difference with windows, In Snowpark is needed the order by. SQL doesn't guarantee the order. + * @since 1.10.0 + * @param c Column to obtain last value. + * @return Column object. + */ + def last(c: Column): Column = + builtin("LAST_VALUE")(c) + + /** + * Formats the arguments in printf-style and returns the result as a string column. + * @since 1.10.0 + * @note this function requires the format_string UDF to be previosly created + * @param format the printf-style format + * @param arguments arguments for the formatting string + * @return formatted string + */ + def format_string(format: String, arguments: Column*): Column = { + callBuiltin("format_string", lit(format), array_construct(arguments: _*)) + } + + /** + * Locate the position of the first occurrence of substr in a string column, after position pos. + * @note The position is not zero based, but 1 based index. returns 0 if substr + * could not be found in str. This function is just leverages the SF POSITION builtin + * @since 1.10.0 + * @param substr string to search + * @param str value where string will be searched + * @param pos index for starting the search + * @return Returns the position of the first occurrence + */ + def locate(substr: String, str: Column, pos: Int = 0): Column = + if (pos == 0) lit(0) else callBuiltin("POSITION", lit(substr), str, lit(pos)) + + /** + * + * Locate the position of the first occurrence of substr in a string column, after position pos. + * @since 1.10.0 + * @note The position is not zero based, but 1 based index. returns 0 if substr + * could not be found in str. This function is just leverages the SF POSITION builtin + * @param substr string to search + * @param str value where string will be searched + * @param pos index for starting the search + * @return returns the position of the first occurrence. + */ + def locate(substr: Column, str: Column, pos: Int): Column = + if (pos == 0) lit(0) else callBuiltin("POSITION", substr, str, pos) + + /** + * Computes the logarithm of the given column in base 10. + * @since 1.10.0 + * @param expr Column to apply this mathematical operation + * @return log2 of the given column + */ + def log10(expr: Column): Column = builtin("LOG")(10, expr) + + /** + * Computes the logarithm of the given column in base 10. + * @since 1.10.0 + * @param columnName Column to apply this mathematical operation + * @return log2 of the given column + */ + def log10(columnName: String): Column = builtin("LOG")(10, col(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + * @since 1.10.0 + * @param columnName the value to use + * @return the natural logarithm of the given value plus one. + */ + def log1p(columnName: String): Column = callBuiltin("ln", lit(1) + col(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + * @since 1.10.0 + * @param col the value to use + * @return the natural logarithm of the given value plus one. + */ + def log1p(col: Column): Column = callBuiltin("ln", lit(1) + col) + + /** + * Returns expr1 if it is not NaN, or expr2 if expr1 is NaN. + * @since 1.10.0 + * @param expr1 expression when value is NaN + * @param expr2 expression when value is not NaN + */ + def nanvl(expr1: Column, expr2: Column): Column = + callBuiltin("nanvl", expr1.cast(FloatType), expr2.cast(FloatType)).cast(FloatType) + + /** + * Computes the BASE64 encoding of a column + * @since 1.10.0 + * @param col + * @return the encoded column + */ + def base64(col: Column): Column = callBuiltin("BASE64_ENCODE", col) + + /** + * Decodes a BASE64 encoded string + * @since 1.10.0 + * @param col + * @return the decoded column + */ + def unbase64(col: Column): Column = callBuiltin("BASE64_DECODE_STRING", col) + + /** + * Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window + * partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second + * quarter will get 2, the third quarter will get 3, and the last quarter will get 4. + * + * This is equivalent to the NTILE function in SQL. + * @since 1.10.0 + * @param n number of groups + * @return retyr + */ + def ntile(n: Int): Column = callBuiltin("ntile", lit(n)) + + /** + * Alias for bitshiftleft + * @since 1.10.0 + * @param c Column to modify. + * @param numBits Number of bits to shift. + * @return Column object. + */ + def shiftleft(c: Column, numBits: Int): Column = + bitshiftleft(c, lit(numBits)) + + /** + * Alias for bitshiftright. + * @since 1.10.0 + * @param c Column to modify. + * @param numBits Number of bits to shift. + * @return Column object. + */ + def shiftright(c: Column, numBits: Int): Column = + bitshiftright(c, lit(numBits)) + + /** + * Wrapper for Snowflake built-in hex_encode function. Returns the hexadecimal representation of a string. + * @since 1.10.0 + * @param c Column to encode. + * @return Encoded string. + */ + def hex(c: Column): Column = + builtin("HEX_ENCODE")(c) + + /** + * Wrapper for Snowflake built-in hex_decode_string function. Returns the string representation of a hexadecimal value. + * @param c Column to encode. + * @since 1.10.0 + * @return Encoded string. + */ + def unhex(c: Column): Column = + builtin("HEX_DECODE_STRING")(c) + + /** + * Return a call to the Snowflake RANDOM function. + * NOTE: Snowflake returns integers of 17-19 digits. + * @since 1.10.0 + * @return Random number. + */ + def randn(): Column = + builtin("RANDOM")() + + /** + * Calls to the Snowflake RANDOM function. + * NOTE: Snowflake returns integers of 17-19 digits. + * @since 1.10.0 + * @param seed Seed to use in the random function. + * @return Random number. + */ + def randn(seed: Long): Column = + builtin("RANDOM")(seed) + + /** + * This leverages JSON_EXTRACT_PATH_TEXT and improves functionality by allowing multiple columns + * in a single call, whereas JSON_EXTRACT_PATH_TEXT must be called once for every column. + * + * NOTE: + *
+ * df = session.createDataFrame(Seq(("CR", "{\"id\": 5, \"name\": \"Jose\", \"age\": 29}"))).toDF(Seq("nationality", "json_string")) + *+ * When the result of this function is the only part of the select statement, no changes are needed: + *
+ * df.select(json_tuple(col("json_string"), "id", "name", "age")).show() + *+ * + *
+ * ---------------------- + * |"C0" |"C1" |"C2" | + * ---------------------- + * |5 |Jose |29 | + * ---------------------- + *+ * However, when specifying multiple columns, an expression like this is required: + *
+ * df.select( + * col("nationality") + * , json_tuple(col("json_string"), "id", "name", "age"):_* // Notice the :_* syntax. + * ).show() + *+ * + *
+ * ------------------------------------------------- + * |"NATIONALITY" |"C0" |"C1" |"C2" |"C3" | + * ------------------------------------------------- + * |CR |5 |Jose |29 |Mobilize | + * ------------------------------------------------- + *+ * @since 1.10.0 + * @param json Column containing the JSON string text. + * @param fields Fields to pull from the JSON file. + * @return Column sequence with the specified strings. + */ + def json_tuple(json: Column, fields: String*): Seq[Column] = { + var i = -1 + fields.map(f => { + i += 1 + builtin("JSON_EXTRACT_PATH_TEXT")(json, f).as(s"c$i") + }) + } + + /** + * Used to calculate the cubic root of a number. + * @since 1.10.0 + * @param column Column to calculate the cubic root. + * @return Column object. + */ + def cbrt(e: Column): Column = { + builtin("CBRT")(e) + } + + /** + * Used to calculate the cubic root of a number. There were slight differences found: + * @since 1.10.0 + * @param column Column to calculate the cubic root. + * @return Column object. + */ + def cbrt(columnName: String): Column = { + cbrt(col(columnName)) + } + + /** + * This function converts a JSON string to a variant in Snowflake. + * + * In Snowflake the values are converted automatically, however they're converted as variants, meaning that the printSchema function would return different datatypes. + * To convert the datatype and it to be printed as the expected datatype, it should be read on the selectExpr function as "json['relative']['age']::integer". + *
+ * val data_for_json = Seq( + * (1, "{\"id\": 172319, \"age\": 41, \"relative\": {\"id\": 885471, \"age\": 29}}"), + * (2, "{\"id\": 532161, \"age\": 17, \"relative\":{\"id\": 873513, \"age\": 47}}") + * ) + * val data_for_json_column = Seq("col1", "col2") + * val df_for_json = session.createDataFrame(data_for_json).toDF(data_for_json_column) + * + * val json_df = df_for_json.select( + * from_json(col("col2")).as("json") + * ) + * + * json_df.selectExpr( + * "json['id']::integer as id" + * , "json['age']::integer as age" + * , "json['relative']['id']::integer as rel_id" + * , "json['relative']['age']::integer as rel_age" + * ).show(10, 10000) + *+ * + *
+ * ----------------------------------------- + * |"ID" |"AGE" |"REL_ID" |"REL_AGE" | + * ----------------------------------------- + * |172319 |41 |885471 |29 | + * |532161 |17 |873513 |47 | + * ----------------------------------------- + *+ * @since 1.10.0 + * @param e String column to convert to variant. + * @return Column object. + */ + def from_json(e: Column): Column = { + builtin("TRY_PARSE_JSON")(e) + } + + /** + * This function receives a date or timestamp, as well as a properly formatted string and subtracts the specified + * amount of days from it. If receiving a string, this string is casted to date using try_cast and if it's not possible to cast, returns null. If receiving + * a timestamp it will be casted to date (removing its time). + * @since 1.10.0 + * @param start Date, Timestamp or String column to subtract days from. + * @param days Days to subtract. + * @return Column object. + */ + def date_sub(start: Column, days: Int): Column = { + dateadd("DAY", lit(days * -1), sqlExpr(s"try_cast(${start.getName.get} :: STRING as DATE)")) + } + + /** + * This function receives a column and extracts the groupIdx from the string + * after applying the exp regex. Returns empty string when the string doesn't match and null if the input is null. + * + * This function applies the `case sensitive` and `extract` flags. It doesn't apply multiline nor .* matches newlines. + * If these flags need to be applied, use `builtin("REGEXP_SUBSTR")` instead and apply the desired flags. + * + * Note: non-greedy tokens such as `.*?` are not supported + * @since 1.10.0 + * @param colName Column to apply regex. + * @param exp Regex expression to apply. + * @param grpIdx Group to extract. + * @return Column object. + */ + def regexp_extract(colName: Column, exp: String, grpIdx: Int): Column = { + when(colName.is_null, lit(null)) + .otherwise( + coalesce( + builtin("REGEXP_SUBSTR")(colName, lit(exp), lit(1), lit(1), lit("ce"), lit(grpIdx)), + lit(""))) + } + + /** + * Returns the sign of the given column. Returns either 1 for positive, 0 for 0 or + * NaN, -1 for negative and null for null. + * NOTE: if string values are provided snowflake will attempts to cast. If it casts correctly, returns the calculation, + * if not an error will be thrown + * @since 1.10.0 + * @param e Column to calculate the sign. + * @return Column object. + */ + def signum(colName: Column): Column = { + builtin("SIGN")(colName) + } + + /** + * Returns the sign of the given column. Returns either 1 for positive, 0 for 0 or + * NaN, -1 for negative and null for null. + * NOTE: if string values are provided snowflake will attempts to cast. If it casts correctly, returns the calculation, + * if not an error will be thrown + * @since 1.10.0 + * @param columnName Name of the column to calculate the sign. + * @return Column object. + */ + def signum(columnName: String): Column = { + signum(col(columnName)) + } + + def substring_index(str: String, delim: String, count: int): Column = { + when( + lit(count) < lit(0), + callBuiltin( + "substring", + lit(str), + callBuiltin("regexp_instr", reverse(lit(str), lit(delim), 1, abs(lit(count)), lit(0)))) + .otherwise( + callBuiltin( + "substring", + lit(str), + 1, + callBuiltin("regexp_instr", lit(str), lit(delim), 1, lit(count), 1)))) + } + + /** + * Wrapper for Snowflake built-in array function. Create array from columns names. + * @since 1.10.0 + * @param s Columns names to build the array. + * @return The array. + */ + def array(colName: String, colNames: String*): Column = + array_construct((colName +: colNames).map(col): _*) + + /** + * Wrapper for Snowflake built-in collect_list function. Get the values of array column. + * @since 1.10.0 + * @param c Column to be collect. + * @return The array. + */ + def collect_list(c: Column): Column = array_agg(c) + + /** + * Wrapper for Snowflake built-in collect_list function. Get the values of array column. + * @since 1.10.0 + * @param s Column name to be collected. + * @return The array. + */ + def collect_list(s: String): Column = array_agg(col(s)) + + /** + * Wrapper for Snowflake built-in reverse function. Gets the reversed string. + * @since 1.10.0 + * @param c Column to be reverse. + * @return Column object. + */ + def reverse(c: Column): Column = + builtin("reverse")(c) + + /** + * Wrapper for Snowflake built-in isnull function. Gets a boolean depending if value is NULL or not. + * @since 1.10.0 + * @param c Column to qnalize if it is null value. + * @return Column object. + */ + def isnull(c: Column): Column = is_null(c) + + /** + * Wrapper for Snowflake built-in last function. Gets the last value of a column according to its grouping. + * @since 1.10.0 + * @param c Column to obtain last value. + * @return Column object. + */ + def last(s: String): Column = + builtin("LAST_VALUE")(col(s)) + + /** + * Wrapper for Snowflake built-in conv function. Convert number with from and to base. + * @since 1.10.0 + * @param c Column to be converted. + * @param fromBase Column from base format. + * @param toBase Column to base format. + * @return Column object. + */ + def conv(c: Column, fromBase: Int, toBase: Int): Column = + callBuiltin("conv", c, fromBase, toBase) + + /** + * Wrapper for Snowflake built-in last function. Gets the last value of a column according to its grouping. + * Functional difference with windows, In Snowpark is needed the order by. SQL doesn't guarantee the order. + * @since 1.10.0 + * @param s Column name to get last value. + * @param nulls Consider null values or not. + * @return Column object. + */ + def last(s: String, nulls: Boolean): Column = { + if (nulls) { + sqlExpr(s"LAST_VALUE(${col(s).getName.get}) IGNORE NULLS") + } else { + sqlExpr(s"LAST_VALUE(${col(s).getName.get}) RESPECT NULLS") + } + } + + /** + * Wrapper for Snowflake built-in last function. Gets the last value of a column according to its grouping. + * Functional difference with windows, In Snowpark is needed the order by. SQL doesn't guarantee the order. + * @since 1.10.0 + * @param c Column to get last value. + * @param nulls Consider null values or not. + * @return Column object. + */ + def last(c: Column, nulls: Boolean): Column = { + if (nulls) { + sqlExpr(s"LAST_VALUE(${c.getName.get}) IGNORE NULLS") + } else { + sqlExpr(s"LAST_VALUE(${c.getName.get}) RESPECT NULLS") + } + } + + /** + * Wrapper for Snowflake built-in first function. Gets the first value of a column according to its grouping. + * @since 1.10.0 + * @param c Column to obtain first value. + * @return Column object. + */ + def first(s: String): Column = + builtin("FIRST_VALUE")(col(s)) + + /** + * Wrapper for Snowflake built-in first function. Gets the first value of a column according to its grouping. + * @since 1.10.0 + * @param s Column name to get first value. + * @param nulls Consider null values or not. + * @return Column object. + */ + def first(s: String, nulls: Boolean): Column = { + if (nulls) { + sqlExpr(s"FIRST_VALUE(${col(s).getName.get}) IGNORE NULLS") + } else { + sqlExpr(s"FIRST_VALUE(${col(s).getName.get}) RESPECT NULLS") + } + } + + /** + * Wrapper for Snowflake built-in last function. Gets the last value of a column according to its grouping. + * @since 1.10.0 + * @param c Column to get last value. + * @param nulls Consider null values or not. + * @return Column object. + */ + def first(c: Column, nulls: Boolean): Column = { + if (nulls) { + sqlExpr(s"FIRST_VALUE(${c.getName.get}) IGNORE NULLS") + } else { + sqlExpr(s"FIRST_VALUE(${c.getName.get}) RESPECT NULLS") + } + } + + /** + * Returns the current Unix timestamp (in seconds) as a long. + * @since 1.10.0 + * @note All calls of `unix_timestamp` within the same query return the same value + */ + def unix_timestamp(): Column = { + builtin("date_part")("epoch_second", current_timestamp()) + } + + /** + * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), + * using the default timezone and the default locale. + * @since 1.10.0 + * @param s A date, timestamp or string. If a string, the data must be in the + * `yyyy-MM-dd HH:mm:ss` format + * @return A long, or null if the input was a string not of the correct format + */ + def unix_timestamp(s: Column): Column = { + builtin("date_part")("epoch_second", s) + } + + /** + * Converts time string with given pattern to Unix timestamp (in seconds). + * @since 1.10.0 + * @param s A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param p A date time pattern detailing the format of `s` when `s` is a string + * @return A long, or null if `s` was a string that could not be cast to a date or `p` was + * an invalid format + */ + def unix_timestamp(s: Column, p: String): Column = { + builtin("date_part")("epoch_second", to_timestamp(s, lit(p))) + } + + /** + * Wrapper for Snowflake built-in regexp_replace function. Replaces parts of a string with the specified replacement value, based on a regular expression. + * @since 1.10.0 + * @param strExpr String to apply replacement. + * @param pattern Regex pattern to find in the expression. + * @param replacement Column to replace within the string. + * @return Column object. + */ + def regexp_replace(strExpr: Column, pattern: Column, replacement: Column): Column = + builtin("regexp_replace")(strExpr, pattern, replacement) + + /** + * Wrapper for Snowflake built-in regexp_replace function. Replaces parts of a string with the specified replacement value, based on a regular expression. + * @since 1.10.0 + * @param strExpr String to apply replacement. + * @param pattern Regex pattern to find in the expression. + * @param replacement Column to replace within the string. + * @return Column object. + */ + def regexp_replace(strExpr: Column, pattern: String, replacement: String): Column = { + builtin("regexp_replace")(strExpr, pattern, replacement) + } + + /** + * Returns the date that is `days` days after `start` + * @since 1.10.0 + * @param start A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days The number of days to add to `start`, can be negative to subtract days + * @return A date, or null if `start` was a string that could not be cast to a date + */ + def date_add(start: Column, days: Int): Column = dateadd("day", lit(days), start) + + /** + * Returns the date that is `days` days after `start` + * @since 1.10.0 + * @param start A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days The number of days to add to `start`, can be negative to subtract days + * @return A date, or null if `start` was a string that could not be cast to a date + */ + def date_add(start: Column, days: Column): Column = dateadd("day", days, start) + + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * @since 1.10.0 + * @param e The column to collect the list values + * @return A list with unique values + */ + def collect_set(e: Column): Column = sqlExpr(s"array_agg(distinct ${e.getName.get})") + + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * @since 1.10.0 + * @param e The column to collect the list values + * @return A list with unique values + */ + def collect_set(e: String): Column = sqlExpr(s"array_agg(distinct ${e})") + + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the + * yyyy-MM-dd HH:mm:ss format. + * @since 1.10.0 + * @param ut A number of a type that is castable to a long, such as string or integer. Can be + * negative for timestamps before the unix epoch + * @return A string, or null if the input was a string that could not be cast to a long + */ + def from_unixtime(ut: Column): Column = + ut.cast(LongType).cast(TimestampType).cast(StringType) + + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. + * @since 1.10.0 + * @param ut A number of a type that is castable to a long, such as string or integer. Can be + * negative for timestamps before the unix epoch + * @param f A date time pattern that the input will be formatted to + * @return A string, or null if `ut` was a string that could not be cast to a long or `f` was + * an invalid date time pattern + */ + def from_unixtime(ut: Column, f: String): Column = + date_format(ut.cast(LongType).cast(TimestampType), f) + + /** + * A column expression that generates monotonically increasing 64-bit integers. + * @since 1.10.0 + */ + def monotonically_increasing_id(): Column = builtin("seq8")() + + /** + * Returns number of months between dates `start` and `end`. + * + * A whole number is returned if both inputs have the same day of month or both are the last day + * of their respective months. Otherwise, the difference is calculated assuming 31 days per month. + * + * For example: + * {{{ + * months_between("2017-11-14", "2017-07-14") // returns 4.0 + * months_between("2017-01-01", "2017-01-10") // returns 0.29032258 + * months_between("2017-06-01", "2017-06-16 12:00:00") // returns -0.5 + * }}} + * @since 1.10.0 + * @param end A date, timestamp or string. If a string, the data must be in a format that can + * be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param start A date, timestamp or string. If a string, the data must be in a format that can + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A double, or null if either `end` or `start` were strings that could not be cast to a + * timestamp. Negative if `end` is before `start` + */ + def months_between(end: Column, start: Column): Column = builtin("MONTHS_BETWEEN")(start, end) + + /** + * Locate the position of the first occurrence of substr column in the given string. + * Returns null if either of the arguments are null. + * @since 1.10.0 + * @note The position is not zero based, but 1 based index. Returns 0 if substr + * could not be found in str. + */ + def instr(str: Column, substring: String): Column = builtin("REGEXP_INSTR")(str, substring) + + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield + * '2017-07-14 03:40:00.0'. + * @since 1.10.0 + * @param ts A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz A string detailing the time zone ID that the input should be adjusted to. It should + * be in the format of either region-based zone IDs or zone offsets. Region IDs must + * have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in + * the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are + * supported as aliases of '+00:00'. Other short names are not recommended to use + * because they can be ambiguous. + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or + * `tz` was an invalid value + */ + def from_utc_timestamp(ts: Column, tz: String): Column = + builtin("TO_TIMESTAMP_TZ")(ts, tz) + + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield + * '2017-07-14 03:40:00.0'. + * @since 1.10.0 + * @param ts A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz A string detailing the time zone ID that the input should be adjusted to. It should + * be in the format of either region-based zone IDs or zone offsets. Region IDs must + * have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in + * the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are + * supported as aliases of '+00:00'. Other short names are not recommended to use + * because they can be ambiguous. + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or + * `tz` was an invalid value + */ + def from_utc_timestamp(ts: Column, tz: Column): Column = + builtin("TO_TIMESTAMP_TZ")(ts, tz) + + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time + * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield + * '2017-07-14 01:40:00.0'. + * @since 1.10.0 + * @param ts A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz A string detailing the time zone ID that the input should be adjusted to. It should + * be in the format of either region-based zone IDs or zone offsets. Region IDs must + * have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in + * the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are + * supported as aliases of '+00:00'. Other short names are not recommended to use + * because they can be ambiguous. + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or + * `tz` was an invalid value + */ + def to_utc_timestamp(ts: Column, tz: String): Column = builtin("TO_TIMESTAMP_TZ")(ts, tz) + + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time + * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield + * '2017-07-14 01:40:00.0'. + * @since 1.10.0 + * @param ts A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz A string detailing the time zone ID that the input should be adjusted to. It should + * be in the format of either region-based zone IDs or zone offsets. Region IDs must + * have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in + * the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are + * supported as aliases of '+00:00'. Other short names are not recommended to use + * because they can be ambiguous. + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or + * `tz` was an invalid value + */ + def to_utc_timestamp(ts: Column, tz: Column): Column = builtin("TO_TIMESTAMP_TZ")(ts, tz) + + /** + * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places + * with HALF_EVEN round mode, and returns the result as a string column. + * @since 1.10.0 + * If d is 0, the result has no decimal point or fractional part. + * If d is less than 0, the result will be null. + * + * @param x numeric column to be transformed + * @param d Amount of decimal for the number format + * + * @return Number casted to the specific string format + */ + def format_number(x: Column, d: Int): Column = { + if (d < 0) { + lit(null) + } else { + builtin("TO_VARCHAR")(x, if (d > 0) s"999,999.${"0" * d}" else "999,999") + } + } + + /** + * Computes the logarithm of the given column in base 2. + * @since 1.10.0 + * @param expr Column to apply this mathematical operation + * @return log2 of the given column + */ + def log2(expr: Column): Column = builtin("LOG")(2, expr) + + /** + * Computes the logarithm of the given column in base 2. + * + * @param columnName Column to apply this mathematical operation + * + * @return log2 of the given column + */ + def log2(columnName: String): Column = builtin("LOG")(2, col(columnName)) + + /** + * Returns element of array at given index in value if column is array. Mostly and overload for snowpark get_path + * @see Snowpark get_path + */ + def element_at(column: Column, index: int): Column = { + com.snowflake.snowpark.functions.get_path(column, lit(i)) + } + + /** + * Returns element of array at given index in value if column is array. Mostly and overload for snowpark get_path + * @see Snowpark get_path + */ + def element_at(column: Column, index: Column): Column = { + com.snowflake.snowpark.functions.get_path(column, c) + } + }