From fd8224e1d32282c1664575f32a2eb5d39ffaabba Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 9 Sep 2024 14:13:21 -0400 Subject: [PATCH] Update Aggregate functions to take builder parameters (#859) * Add NullTreatment enum wrapper and add filter option to approx_distinct * Small usability on aggregate * Adding documentation and additional unit test for approx_median * Update approx_percentil_cont with builder parameters it uses, which is filter but not distinct * Update approx_percentil_cont_with_weight with builder parameters it uses, which is filter but not distinct * Update array_agg to use aggregate options * Update builder options for avg aggregate function * move bit_and bit_or to use macro to generaty python fn * Update builder arguments for bitwise operators * Use macro for bool_and and bool_or * Update python wrapper for arguments appropriate to bool operators * Set corr to use macro for pyfunction * Update unit test to make it easier to debug * Update corr python wrapper to expose only builder parameters used * Update count and count_star to use macro for exposing * Update count and count_star with approprate aggregation options * Move covar_pop and covar_samp to use macro for aggregates * Updateing covar_pop and covar_samp with builder option * Use macro for last_value and move first_value to be near it * Update first_value and last_value with the builder parameters that are relevant * Remove grouping since it is not actually implemented upstream * Move median to use macro * Expose builder options for median * Expose nth value * Updating linear regression functions to use filter and macro * Update stddev and stddev_pop to use filter and macro * Expose string_agg * Add string_agg to python wrappers and add unit test * Switch sum to use macro in rust side and expose correct options in python wrapper * Use macro for exposing var_pop and var_samp * Add unit tests for filtering on var_pop and var_samp * Move approximation functions to use macro when possible * Update user documentation to explain in detail the options for aggregate functions * Update unit test to handle Python 3.10 * Clean up commented code --- .../common-operations/aggregations.rst | 206 ++++- python/datafusion/common.py | 15 +- python/datafusion/dataframe.py | 7 +- python/datafusion/expr.py | 4 +- python/datafusion/functions.py | 867 ++++++++++++++---- python/datafusion/tests/test_aggregation.py | 359 +++++++- python/datafusion/tests/test_functions.py | 99 +- .../datafusion/tests/test_wrapper_coverage.py | 11 + src/functions.rs | 521 ++++------- 9 files changed, 1470 insertions(+), 619 deletions(-) diff --git a/docs/source/user-guide/common-operations/aggregations.rst b/docs/source/user-guide/common-operations/aggregations.rst index 7ad40221..8fee26a1 100644 --- a/docs/source/user-guide/common-operations/aggregations.rst +++ b/docs/source/user-guide/common-operations/aggregations.rst @@ -20,43 +20,205 @@ Aggregation ============ -An aggregate or aggregation is a function where the values of multiple rows are processed together to form a single summary value. -For performing an aggregation, DataFusion provides the :py:func:`~datafusion.dataframe.DataFrame.aggregate` +An aggregate or aggregation is a function where the values of multiple rows are processed together +to form a single summary value. For performing an aggregation, DataFusion provides the +:py:func:`~datafusion.dataframe.DataFrame.aggregate` .. ipython:: python + import urllib.request from datafusion import SessionContext - from datafusion import column, lit + from datafusion import col, lit from datafusion import functions as f - import random - ctx = SessionContext() - df = ctx.from_pydict( - { - "a": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], - "b": ["one", "one", "two", "three", "two", "two", "one", "three"], - "c": [random.randint(0, 100) for _ in range(8)], - "d": [random.random() for _ in range(8)], - }, - name="foo_bar" + urllib.request.urlretrieve( + "https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv", + "pokemon.csv", ) - col_a = column("a") - col_b = column("b") - col_c = column("c") - col_d = column("d") + ctx = SessionContext() + df = ctx.read_csv("pokemon.csv") + + col_type_1 = col('"Type 1"') + col_type_2 = col('"Type 2"') + col_speed = col('"Speed"') + col_attack = col('"Attack"') - df.aggregate([], [f.approx_distinct(col_c), f.approx_median(col_d), f.approx_percentile_cont(col_d, lit(0.5))]) + df.aggregate([col_type_1], [ + f.approx_distinct(col_speed).alias("Count"), + f.approx_median(col_speed).alias("Median Speed"), + f.approx_percentile_cont(col_speed, 0.9).alias("90% Speed")]) -When the :code:`group_by` list is empty the aggregation is done over the whole :class:`.DataFrame`. For grouping -the :code:`group_by` list must contain at least one column +When the :code:`group_by` list is empty the aggregation is done over the whole :class:`.DataFrame`. +For grouping the :code:`group_by` list must contain at least one column. .. ipython:: python - df.aggregate([col_a], [f.sum(col_c), f.max(col_d), f.min(col_d)]) + df.aggregate([col_type_1], [ + f.max(col_speed).alias("Max Speed"), + f.avg(col_speed).alias("Avg Speed"), + f.min(col_speed).alias("Min Speed")]) More than one column can be used for grouping .. ipython:: python - df.aggregate([col_a, col_b], [f.sum(col_c), f.max(col_d), f.min(col_d)]) + df.aggregate([col_type_1, col_type_2], [ + f.max(col_speed).alias("Max Speed"), + f.avg(col_speed).alias("Avg Speed"), + f.min(col_speed).alias("Min Speed")]) + + + +Setting Parameters +------------------ + +Each of the built in aggregate functions provides arguments for the parameters that affect their +operation. These can also be overridden using the builder approach to setting any of the following +parameters. When you use the builder, you must call ``build()`` to finish. For example, these two +expressions are equivalent. + +.. ipython:: python + + first_1 = f.first_value(col("a"), order_by=[col("a")]) + first_2 = f.first_value(col("a")).order_by(col("a")).build() + +Ordering +^^^^^^^^ + +You can control the order in which rows are processed by window functions by providing +a list of ``order_by`` functions for the ``order_by`` parameter. In the following example, we +sort the Pokemon by their attack in increasing order and take the first value, which gives us the +Pokemon with the smallest attack value in each ``Type 1``. + +.. ipython:: python + + df.aggregate( + [col('"Type 1"')], + [f.first_value( + col('"Name"'), + order_by=[col('"Attack"').sort(ascending=True)] + ).alias("Smallest Attack") + ]) + +Distinct +^^^^^^^^ + +When you set the parameter ``distinct`` to ``True``, then unique values will only be evaluated one +time each. Suppose we want to create an array of all of the ``Type 2`` for each ``Type 1`` of our +Pokemon set. Since there will be many entries of ``Type 2`` we only one each distinct value. + +.. ipython:: python + + df.aggregate([col_type_1], [f.array_agg(col_type_2, distinct=True).alias("Type 2 List")]) + +In the output of the above we can see that there are some ``Type 1`` for which the ``Type 2`` entry +is ``null``. In reality, we probably want to filter those out. We can do this in two ways. First, +we can filter DataFrame rows that have no ``Type 2``. If we do this, we might have some ``Type 1`` +entries entirely removed. The second is we can use the ``filter`` argument described below. + +.. ipython:: python + + df.filter(col_type_2.is_not_null()).aggregate([col_type_1], [f.array_agg(col_type_2, distinct=True).alias("Type 2 List")]) + + df.aggregate([col_type_1], [f.array_agg(col_type_2, distinct=True, filter=col_type_2.is_not_null()).alias("Type 2 List")]) + +Which approach you take should depend on your use case. + +Null Treatment +^^^^^^^^^^^^^^ + +This option allows you to either respect or ignore null values. + +One common usage for handling nulls is the case where you want to find the first value within a +partition. By setting the null treatment to ignore nulls, we can find the first non-null value +in our partition. + + +.. ipython:: python + + from datafusion.common import NullTreatment + + df.aggregate([col_type_1], [ + f.first_value( + col_type_2, + order_by=[col_attack], + null_treatment=NullTreatment.RESPECT_NULLS + ).alias("Lowest Attack Type 2")]) + + df.aggregate([col_type_1], [ + f.first_value( + col_type_2, + order_by=[col_attack], + null_treatment=NullTreatment.IGNORE_NULLS + ).alias("Lowest Attack Type 2")]) + +Filter +^^^^^^ + +Using the filter option is useful for filtering results to include in the aggregate function. It can +be seen in the example above on how this can be useful to only filter rows evaluated by the +aggregate function without filtering rows from the entire DataFrame. + +Filter takes a single expression. + +Suppose we want to find the speed values for only Pokemon that have low Attack values. + +.. ipython:: python + + df.aggregate([col_type_1], [ + f.avg(col_speed).alias("Avg Speed All"), + f.avg(col_speed, filter=col_attack < lit(50)).alias("Avg Speed Low Attack")]) + + +Aggregate Functions +------------------- + +The available aggregate functions are: + +1. Comparison Functions + - :py:func:`datafusion.functions.min` + - :py:func:`datafusion.functions.max` +2. Math Functions + - :py:func:`datafusion.functions.sum` + - :py:func:`datafusion.functions.avg` + - :py:func:`datafusion.functions.median` +3. Array Functions + - :py:func:`datafusion.functions.array_agg` +4. Logical Functions + - :py:func:`datafusion.functions.bit_and` + - :py:func:`datafusion.functions.bit_or` + - :py:func:`datafusion.functions.bit_xor` + - :py:func:`datafusion.functions.bool_and` + - :py:func:`datafusion.functions.bool_or` +5. Statistical Functions + - :py:func:`datafusion.functions.count` + - :py:func:`datafusion.functions.corr` + - :py:func:`datafusion.functions.covar_samp` + - :py:func:`datafusion.functions.covar_pop` + - :py:func:`datafusion.functions.stddev` + - :py:func:`datafusion.functions.stddev_pop` + - :py:func:`datafusion.functions.var_samp` + - :py:func:`datafusion.functions.var_pop` +6. Linear Regression Functions + - :py:func:`datafusion.functions.regr_count` + - :py:func:`datafusion.functions.regr_slope` + - :py:func:`datafusion.functions.regr_intercept` + - :py:func:`datafusion.functions.regr_r2` + - :py:func:`datafusion.functions.regr_avgx` + - :py:func:`datafusion.functions.regr_avgy` + - :py:func:`datafusion.functions.regr_sxx` + - :py:func:`datafusion.functions.regr_syy` + - :py:func:`datafusion.functions.regr_slope` +7. Positional Functions + - :py:func:`datafusion.functions.first_value` + - :py:func:`datafusion.functions.last_value` + - :py:func:`datafusion.functions.nth_value` +8. String Functions + - :py:func:`datafusion.functions.string_agg` +9. Approximation Functions + - :py:func:`datafusion.functions.approx_distinct` + - :py:func:`datafusion.functions.approx_median` + - :py:func:`datafusion.functions.approx_percentile_cont` + - :py:func:`datafusion.functions.approx_percentile_cont_with_weight` + diff --git a/python/datafusion/common.py b/python/datafusion/common.py index 225e3330..7db8333f 100644 --- a/python/datafusion/common.py +++ b/python/datafusion/common.py @@ -17,13 +17,13 @@ """Common data types used throughout the DataFusion project.""" from ._internal import common as common_internal +from enum import Enum # TODO these should all have proper wrapper classes DFSchema = common_internal.DFSchema DataType = common_internal.DataType DataTypeMap = common_internal.DataTypeMap -NullTreatment = common_internal.NullTreatment PythonType = common_internal.PythonType RexType = common_internal.RexType SqlFunction = common_internal.SqlFunction @@ -47,3 +47,16 @@ "SqlStatistics", "SqlFunction", ] + + +class NullTreatment(Enum): + """Describe how null values are to be treated by functions. + + This is used primarily by aggregate and window functions. It can be set on + these functions using the builder approach described in + ref:`_window_functions` and ref:`_aggregation` in the online documentation. + + """ + + RESPECT_NULLS = common_internal.NullTreatment.RESPECT_NULLS + IGNORE_NULLS = common_internal.NullTreatment.IGNORE_NULLS diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 46b8fa1b..56dff22a 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -180,7 +180,9 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame: """ return DataFrame(self.df.with_column_renamed(old_name, new_name)) - def aggregate(self, group_by: list[Expr], aggs: list[Expr]) -> DataFrame: + def aggregate( + self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr + ) -> DataFrame: """Aggregates the rows of the current DataFrame. Args: @@ -190,6 +192,9 @@ def aggregate(self, group_by: list[Expr], aggs: list[Expr]) -> DataFrame: Returns: DataFrame after aggregation. """ + group_by = group_by if isinstance(group_by, list) else [group_by] + aggs = aggs if isinstance(aggs, list) else [aggs] + group_by = [e.expr for e in group_by] aggs = [e.expr for e in aggs] return DataFrame(self.df.aggregate(group_by, aggs)) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 7fa60803..bd6a86fb 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -473,7 +473,7 @@ def null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder: set parameters for either window or aggregate functions. If used on any other type of expression, an error will be generated when ``build()`` is called. """ - return ExprFuncBuilder(self.expr.null_treatment(null_treatment)) + return ExprFuncBuilder(self.expr.null_treatment(null_treatment.value)) def partition_by(self, *partition_by: Expr) -> ExprFuncBuilder: """Set the partitioning for a window function. @@ -518,7 +518,7 @@ def distinct(self) -> ExprFuncBuilder: def null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder: """Set how nulls are treated for either window or aggregate functions.""" - return ExprFuncBuilder(self.builder.null_treatment(null_treatment)) + return ExprFuncBuilder(self.builder.null_treatment(null_treatment.value)) def partition_by(self, *partition_by: Expr) -> ExprFuncBuilder: """Set partitioning for window functions.""" diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 97b4fe1d..163ff04e 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -18,9 +18,10 @@ from __future__ import annotations -from datafusion._internal import functions as f, common +from datafusion._internal import functions as f, expr as expr_internal from datafusion.expr import CaseBuilder, Expr, WindowFrame from datafusion.context import SessionContext +from datafusion.common import NullTreatment from typing import Any, Optional @@ -126,7 +127,6 @@ "floor", "from_unixtime", "gcd", - "grouping", "in_list", "initcap", "isnan", @@ -180,6 +180,7 @@ "named_struct", "nanvl", "now", + "nth_value", "nullif", "octet_length", "order_by", @@ -222,6 +223,7 @@ "stddev", "stddev_pop", "stddev_samp", + "string_agg", "strpos", "struct", "substr", @@ -244,6 +246,7 @@ "var", "var_pop", "var_samp", + "var_sample", "when", # Window Functions "window", @@ -258,6 +261,12 @@ ] +def expr_list_to_raw_expr_list( + expr_list: Optional[list[Expr]], +) -> Optional[list[expr_internal.Expr]]: + return [e.expr for e in expr_list] if expr_list is not None else None + + def isnan(expr: Expr) -> Expr: """Returns true if a given number is +NaN or -NaN otherwise returns false.""" return Expr(f.isnan(expr.expr)) @@ -358,9 +367,18 @@ def col(name: str) -> Expr: return Expr(f.col(name)) -def count_star() -> Expr: - """Create a COUNT(1) aggregate expression.""" - return Expr(f.count_star()) +def count_star(filter: Optional[Expr] = None) -> Expr: + """Create a COUNT(1) aggregate expression. + + This aggregate function will count all of the rows in the partition. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``distinct``, and ``null_treatment``. + + Args: + filter: If provided, only count rows for which the filter is True + """ + return count(Expr.literal(1), filter=filter) def case(expr: Expr) -> CaseBuilder: @@ -400,8 +418,8 @@ def window( df.select(functions.lag(col("a")).partition_by(col("b")).build()) """ args = [a.expr for a in args] - partition_by = [e.expr for e in partition_by] if partition_by is not None else None - order_by = [o.expr for o in order_by] if order_by is not None else None + partition_by = expr_list_to_raw_expr_list(partition_by) + order_by = expr_list_to_raw_expr_list(order_by) window_frame = window_frame.window_frame if window_frame is not None else None return Expr(f.window(name, args, partition_by, order_by, window_frame, ctx)) @@ -1486,291 +1504,788 @@ def flatten(array: Expr) -> Expr: # aggregate functions -def approx_distinct(expression: Expr) -> Expr: - """Returns the approximate number of distinct values.""" - return Expr(f.approx_distinct(expression.expr)) +def approx_distinct( + expression: Expr, + filter: Optional[Expr] = None, +) -> Expr: + """Returns the approximate number of distinct values. + + This aggregate function is similar to :py:func:`count` with distinct set, but it + will approximate the number of distinct entries. It may return significantly faster + than :py:func:`count` for some DataFrames. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + expression: Values to check for distinct entries + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + + return Expr(f.approx_distinct(expression.expr, filter=filter_raw)) + +def approx_median(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Returns the approximate median value. -def approx_median(arg: Expr, distinct: bool = False) -> Expr: - """Returns the approximate median value.""" - return Expr(f.approx_median(arg.expr, distinct=distinct)) + This aggregate function is similar to :py:func:`median`, but it will only + approximate the median. It may return significantly faster for some DataFrames. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by`` and ``null_treatment``, and ``distinct``. + + Args: + expression: Values to find the median for + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.approx_median(expression.expr, filter=filter_raw)) def approx_percentile_cont( expression: Expr, - percentile: Expr, - num_centroids: Expr | None = None, - distinct: bool = False, + percentile: float, + num_centroids: Optional[int] = None, + filter: Optional[Expr] = None, ) -> Expr: - """Returns the value that is approximately at a given percentile of ``expr``.""" - if num_centroids is None: - return Expr( - f.approx_percentile_cont( - expression.expr, percentile.expr, distinct=distinct, num_centroids=None - ) - ) + """Returns the value that is approximately at a given percentile of ``expr``. + + This aggregate function assumes the input values form a continuous distribution. + Suppose you have a DataFrame which consists of 100 different test scores. If you + called this function with a percentile of 0.9, it would return the value of the + test score that is above 90% of the other test scores. The returned value may be + between two of the values. + + This function uses the [t-digest](https://arxiv.org/abs/1902.04023) algorithm to + compute the percentil. You can limit the number of bins used in this algorithm by + setting the ``num_centroids`` parameter. + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + expression: Values for which to find the approximate percentile + percentile: This must be between 0.0 and 1.0, inclusive + num_centroids: Max bin size for the t-digest algorithm + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None return Expr( f.approx_percentile_cont( - expression.expr, - percentile.expr, - distinct=distinct, - num_centroids=num_centroids.expr, + expression.expr, percentile, num_centroids=num_centroids, filter=filter_raw ) ) def approx_percentile_cont_with_weight( - arg: Expr, weight: Expr, percentile: Expr, distinct: bool = False + expression: Expr, weight: Expr, percentile: float, filter: Optional[Expr] = None ) -> Expr: - """Returns the value of the approximate percentile. + """Returns the value of the weighted approximate percentile. + + This aggregate function is similar to :py:func:`approx_percentile_cont` except that + it uses the associated associated weights. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + expression: Values for which to find the approximate percentile + weight: Relative weight for each of the values in ``expression`` + percentile: This must be between 0.0 and 1.0, inclusive + filter: If provided, only compute against rows for which the filter is True - This function is similar to :py:func:`approx_percentile_cont` except that it uses - the associated associated weights. """ + filter_raw = filter.expr if filter is not None else None return Expr( f.approx_percentile_cont_with_weight( - arg.expr, weight.expr, percentile.expr, distinct=distinct + expression.expr, weight.expr, percentile, filter=filter_raw ) ) -def array_agg(arg: Expr, distinct: bool = False) -> Expr: - """Aggregate values into an array.""" - return Expr(f.array_agg(arg.expr, distinct=distinct)) +def array_agg( + expression: Expr, + distinct: bool = False, + filter: Optional[Expr] = None, + order_by: Optional[list[Expr]] = None, +) -> Expr: + """Aggregate values into an array. + Currently ``distinct`` and ``order_by`` cannot be used together. As a work around, + consider :py:func:`array_sort` after aggregation. + [Issue Tracker](https://github.com/apache/datafusion/issues/12371) -def avg(arg: Expr, distinct: bool = False) -> Expr: - """Returns the average value.""" - return Expr(f.avg(arg.expr, distinct=distinct)) + If using the builder functions described in ref:`_aggregation` this function ignores + the option ``null_treatment``. + Args: + expression: Values to combine into an array + distinct: If True, a single entry for each distinct value will be in the result + filter: If provided, only compute against rows for which the filter is True + order_by: Order the resultant array values + """ + order_by_raw = expr_list_to_raw_expr_list(order_by) + filter_raw = filter.expr if filter is not None else None + + return Expr( + f.array_agg( + expression.expr, distinct=distinct, filter=filter_raw, order_by=order_by_raw + ) + ) -def corr(value1: Expr, value2: Expr, distinct: bool = False) -> Expr: - """Returns the correlation coefficient between ``value1`` and ``value2``.""" - return Expr(f.corr(value1.expr, value2.expr, distinct=distinct)) +def avg( + expression: Expr, + filter: Optional[Expr] = None, +) -> Expr: + """Returns the average value. -def count(args: Expr | list[Expr] | None = None, distinct: bool = False) -> Expr: - """Returns the number of rows that match the given arguments.""" - if args is None: - return count(Expr.literal(1), distinct=distinct) - if isinstance(args, list): - args = [arg.expr for arg in args] - elif isinstance(args, Expr): - args = [args.expr] - return Expr(f.count(*args, distinct=distinct)) + This aggregate function expects a numeric expression and will return a float. + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. -def covar(y: Expr, x: Expr) -> Expr: - """Computes the sample covariance. + Args: + expression: Values to combine into an array + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.avg(expression.expr, filter=filter_raw)) - This is an alias for :py:func:`covar_samp`. + +def corr(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr: + """Returns the correlation coefficient between ``value1`` and ``value2``. + + This aggregate function expects both values to be numeric and will return a float. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + value_y: The dependent variable for correlation + value_x: The independent variable for correlation + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.corr(value_y.expr, value_x.expr, filter=filter_raw)) + + +def count( + expressions: Expr | list[Expr] | None = None, + distinct: bool = False, + filter: Optional[Expr] = None, +) -> Expr: + """Returns the number of rows that match the given arguments. + + This aggregate function will count the non-null rows provided in the expression. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by`` and ``null_treatment``. + + Args: + expressions: Argument to perform bitwise calculation on + distinct: If True, a single entry for each distinct value will be in the result + filter: If provided, only compute against rows for which the filter is True """ - return covar_samp(y, x) + filter_raw = filter.expr if filter is not None else None + if expressions is None: + args = [Expr.literal(1).expr] + elif isinstance(expressions, list): + args = [arg.expr for arg in expressions] + else: + args = [expressions.expr] -def covar_pop(y: Expr, x: Expr) -> Expr: - """Computes the population covariance.""" - return Expr(f.covar_pop(y.expr, x.expr)) + return Expr(f.count(*args, distinct=distinct, filter=filter_raw)) -def covar_samp(y: Expr, x: Expr) -> Expr: - """Computes the sample covariance.""" - return Expr(f.covar_samp(y.expr, x.expr)) +def covar_pop(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the population covariance. + This aggregate function expects both values to be numeric and will return a float. -def grouping(arg: Expr, distinct: bool = False) -> Expr: - """Indicates if the expression is aggregated or not. + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. - Returns 1 if the value of the argument is aggregated, 0 if not. + Args: + value_y: The dependent variable for covariance + value_x: The independent variable for covariance + filter: If provided, only compute against rows for which the filter is True """ - return Expr(f.grouping(arg.expr, distinct=distinct)) + filter_raw = filter.expr if filter is not None else None + return Expr(f.covar_pop(value_y.expr, value_x.expr, filter=filter_raw)) -def max(arg: Expr, distinct: bool = False) -> Expr: - """Returns the maximum value of the argument.""" - return Expr(f.max(arg.expr, distinct=distinct)) +def covar_samp(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the sample covariance. + This aggregate function expects both values to be numeric and will return a float. -def mean(arg: Expr, distinct: bool = False) -> Expr: + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + value_y: The dependent variable for covariance + value_x: The independent variable for covariance + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.covar_samp(value_y.expr, value_x.expr, filter=filter_raw)) + + +def covar(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the sample covariance. + + This is an alias for :py:func:`covar_samp`. + """ + return covar_samp(value_y, value_x, filter) + + +def max(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Aggregate function that returns the maximum value of the argument. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + expression: The value to find the maximum of + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.max(expression.expr, filter=filter_raw)) + + +def mean(expression: Expr, filter: Optional[Expr] = None) -> Expr: """Returns the average (mean) value of the argument. This is an alias for :py:func:`avg`. """ - return avg(arg, distinct) + return avg(expression, filter) + + +def median( + expression: Expr, distinct: bool = False, filter: Optional[Expr] = None +) -> Expr: + """Computes the median of a set of numbers. + + This aggregate function returns the median value of the expression for the given + aggregate function. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by`` and ``null_treatment``. + + Args: + expression: The value to compute the median of + distinct: If True, a single entry for each distinct value will be in the result + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.median(expression.expr, distinct=distinct, filter=filter_raw)) + +def min(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Returns the minimum value of the argument. -def median(arg: Expr) -> Expr: - """Computes the median of a set of numbers.""" - return Expr(f.median(arg.expr)) + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + Args: + expression: The value to find the minimum of + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.min(expression.expr, filter=filter_raw)) -def min(arg: Expr, distinct: bool = False) -> Expr: - """Returns the minimum value of the argument.""" - return Expr(f.min(arg.expr, distinct=distinct)) +def sum( + expression: Expr, + filter: Optional[Expr] = None, +) -> Expr: + """Computes the sum of a set of numbers. -def sum(arg: Expr) -> Expr: - """Computes the sum of a set of numbers.""" - return Expr(f.sum(arg.expr)) + This aggregate function expects a numeric expression. + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. -def stddev(arg: Expr, distinct: bool = False) -> Expr: - """Computes the standard deviation of the argument.""" - return Expr(f.stddev(arg.expr, distinct=distinct)) + Args: + expression: Values to combine into an array + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.sum(expression.expr, filter=filter_raw)) -def stddev_pop(arg: Expr, distinct: bool = False) -> Expr: - """Computes the population standard deviation of the argument.""" - return Expr(f.stddev_pop(arg.expr, distinct=distinct)) +def stddev(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the standard deviation of the argument. + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. -def stddev_samp(arg: Expr, distinct: bool = False) -> Expr: + Args: + expression: The value to find the minimum of + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.stddev(expression.expr, filter=filter_raw)) + + +def stddev_pop(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the population standard deviation of the argument. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + expression: The value to find the minimum of + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.stddev_pop(expression.expr, filter=filter_raw)) + + +def stddev_samp(arg: Expr, filter: Optional[Expr] = None) -> Expr: """Computes the sample standard deviation of the argument. This is an alias for :py:func:`stddev`. """ - return stddev(arg, distinct) + return stddev(arg, filter=filter) -def var(arg: Expr) -> Expr: +def var(expression: Expr, filter: Optional[Expr] = None) -> Expr: """Computes the sample variance of the argument. This is an alias for :py:func:`var_samp`. """ - return var_samp(arg) + return var_samp(expression, filter) + + +def var_pop(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the population variance of the argument. + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. -def var_pop(arg: Expr, distinct: bool = False) -> Expr: - """Computes the population variance of the argument.""" - return Expr(f.var_pop(arg.expr, distinct=distinct)) + Args: + expression: The variable to compute the variance for + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.var_pop(expression.expr, filter=filter_raw)) -def var_samp(arg: Expr) -> Expr: - """Computes the sample variance of the argument.""" - return Expr(f.var_samp(arg.expr)) +def var_samp(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the sample variance of the argument. + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. -def regr_avgx(y: Expr, x: Expr, distinct: bool = False) -> Expr: + Args: + expression: The variable to compute the variance for + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.var_sample(expression.expr, filter=filter_raw)) + + +def var_sample(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the sample variance of the argument. + + This is an alias for :py:func:`var_samp`. + """ + return var_samp(expression, filter) + + +def regr_avgx( + y: Expr, + x: Expr, + filter: Optional[Expr] = None, +) -> Expr: """Computes the average of the independent variable ``x``. - Only non-null pairs of the inputs are evaluated. + This is a linear regression aggregate function. Only non-null pairs of the inputs + are evaluated. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + y: The linear regression dependent variable + x: The linear regression independent variable + filter: If provided, only compute against rows for which the filter is True """ - return Expr(f.regr_avgx(y.expr, x.expr, distinct)) + filter_raw = filter.expr if filter is not None else None + + return Expr(f.regr_avgx(y.expr, x.expr, filter=filter_raw)) -def regr_avgy(y: Expr, x: Expr, distinct: bool = False) -> Expr: +def regr_avgy( + y: Expr, + x: Expr, + filter: Optional[Expr] = None, +) -> Expr: """Computes the average of the dependent variable ``y``. - Only non-null pairs of the inputs are evaluated. + This is a linear regression aggregate function. Only non-null pairs of the inputs + are evaluated. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + y: The linear regression dependent variable + x: The linear regression independent variable + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + + return Expr(f.regr_avgy(y.expr, x.expr, filter=filter_raw)) + + +def regr_count( + y: Expr, + x: Expr, + filter: Optional[Expr] = None, +) -> Expr: + """Counts the number of rows in which both expressions are not null. + + This is a linear regression aggregate function. Only non-null pairs of the inputs + are evaluated. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + y: The linear regression dependent variable + x: The linear regression independent variable + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + + return Expr(f.regr_count(y.expr, x.expr, filter=filter_raw)) + + +def regr_intercept( + y: Expr, + x: Expr, + filter: Optional[Expr] = None, +) -> Expr: + """Computes the intercept from the linear regression. + + This is a linear regression aggregate function. Only non-null pairs of the inputs + are evaluated. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + y: The linear regression dependent variable + x: The linear regression independent variable + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + + return Expr(f.regr_intercept(y.expr, x.expr, filter=filter_raw)) + + +def regr_r2( + y: Expr, + x: Expr, + filter: Optional[Expr] = None, +) -> Expr: + """Computes the R-squared value from linear regression. + + This is a linear regression aggregate function. Only non-null pairs of the inputs + are evaluated. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + y: The linear regression dependent variable + x: The linear regression independent variable + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + + return Expr(f.regr_r2(y.expr, x.expr, filter=filter_raw)) + + +def regr_slope( + y: Expr, + x: Expr, + filter: Optional[Expr] = None, +) -> Expr: + """Computes the slope from linear regression. + + This is a linear regression aggregate function. Only non-null pairs of the inputs + are evaluated. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + y: The linear regression dependent variable + x: The linear regression independent variable + filter: If provided, only compute against rows for which the filter is True """ - return Expr(f.regr_avgy(y.expr, x.expr, distinct)) + filter_raw = filter.expr if filter is not None else None + + return Expr(f.regr_slope(y.expr, x.expr, filter=filter_raw)) -def regr_count(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """Counts the number of rows in which both expressions are not null.""" - return Expr(f.regr_count(y.expr, x.expr, distinct)) +def regr_sxx( + y: Expr, + x: Expr, + filter: Optional[Expr] = None, +) -> Expr: + """Computes the sum of squares of the independent variable ``x``. + + This is a linear regression aggregate function. Only non-null pairs of the inputs + are evaluated. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + y: The linear regression dependent variable + x: The linear regression independent variable + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.regr_sxx(y.expr, x.expr, filter=filter_raw)) -def regr_intercept(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """Computes the intercept from the linear regression.""" - return Expr(f.regr_intercept(y.expr, x.expr, distinct)) +def regr_sxy( + y: Expr, + x: Expr, + filter: Optional[Expr] = None, +) -> Expr: + """Computes the sum of products of pairs of numbers. + + This is a linear regression aggregate function. Only non-null pairs of the inputs + are evaluated. -def regr_r2(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """Computes the R-squared value from linear regression.""" - return Expr(f.regr_r2(y.expr, x.expr, distinct)) + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + Args: + y: The linear regression dependent variable + x: The linear regression independent variable + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None -def regr_slope(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """Computes the slope from linear regression.""" - return Expr(f.regr_slope(y.expr, x.expr, distinct)) + return Expr(f.regr_sxy(y.expr, x.expr, filter=filter_raw)) -def regr_sxx(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """Computes the sum of squares of the independent variable ``x``.""" - return Expr(f.regr_sxx(y.expr, x.expr, distinct)) +def regr_syy( + y: Expr, + x: Expr, + filter: Optional[Expr] = None, +) -> Expr: + """Computes the sum of squares of the dependent variable ``y``. + This is a linear regression aggregate function. Only non-null pairs of the inputs + are evaluated. -def regr_sxy(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """Computes the sum of products of pairs of numbers.""" - return Expr(f.regr_sxy(y.expr, x.expr, distinct)) + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + Args: + y: The linear regression dependent variable + x: The linear regression independent variable + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None -def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """Computes the sum of squares of the dependent variable ``y``.""" - return Expr(f.regr_syy(y.expr, x.expr, distinct)) + return Expr(f.regr_syy(y.expr, x.expr, filter=filter_raw)) def first_value( - arg: Expr, - distinct: bool = False, - filter: Optional[bool] = None, + expression: Expr, + filter: Optional[Expr] = None, order_by: Optional[list[Expr]] = None, - null_treatment: Optional[common.NullTreatment] = None, + null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: - """Returns the first value in a group of values.""" - order_by_cols = [e.expr for e in order_by] if order_by is not None else None + """Returns the first value in a group of values. + + This aggregate function will return the first value in the partition. + + If using the builder functions described in ref:`_aggregation` this function ignores + the option ``distinct``. + + Args: + expression: Argument to perform bitwise calculation on + filter: If provided, only compute against rows for which the filter is True + order_by: Set the ordering of the expression to evaluate + null_treatment: Assign whether to respect or ignull null values. + """ + order_by_raw = expr_list_to_raw_expr_list(order_by) + filter_raw = filter.expr if filter is not None else None return Expr( f.first_value( - arg.expr, - distinct=distinct, - filter=filter, - order_by=order_by_cols, - null_treatment=null_treatment, + expression.expr, + filter=filter_raw, + order_by=order_by_raw, + null_treatment=null_treatment.value, ) ) def last_value( - arg: Expr, - distinct: bool = False, - filter: Optional[bool] = None, + expression: Expr, + filter: Optional[Expr] = None, order_by: Optional[list[Expr]] = None, - null_treatment: Optional[common.NullTreatment] = None, + null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the last value in a group of values. - To set parameters on this expression, use ``.order_by()``, ``.distinct()``, - ``.filter()``, or ``.null_treatment()``. + This aggregate function will return the last value in the partition. + + If using the builder functions described in ref:`_aggregation` this function ignores + the option ``distinct``. + + Args: + expression: Argument to perform bitwise calculation on + filter: If provided, only compute against rows for which the filter is True + order_by: Set the ordering of the expression to evaluate + null_treatment: Assign whether to respect or ignull null values. """ - order_by_cols = [e.expr for e in order_by] if order_by is not None else None + order_by_raw = expr_list_to_raw_expr_list(order_by) + filter_raw = filter.expr if filter is not None else None return Expr( f.last_value( - arg.expr, - distinct=distinct, - filter=filter, - order_by=order_by_cols, - null_treatment=null_treatment, + expression.expr, + filter=filter_raw, + order_by=order_by_raw, + null_treatment=null_treatment.value, + ) + ) + + +def nth_value( + expression: Expr, + n: int, + filter: Optional[Expr] = None, + order_by: Optional[list[Expr]] = None, + null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, +) -> Expr: + """Returns the n-th value in a group of values. + + This aggregate function will return the n-th value in the partition. + + If using the builder functions described in ref:`_aggregation` this function ignores + the option ``distinct``. + + Args: + expression: Argument to perform bitwise calculation on + n: Index of value to return. Starts at 1. + filter: If provided, only compute against rows for which the filter is True + order_by: Set the ordering of the expression to evaluate + null_treatment: Assign whether to respect or ignull null values. + """ + order_by_raw = expr_list_to_raw_expr_list(order_by) + filter_raw = filter.expr if filter is not None else None + + return Expr( + f.nth_value( + expression.expr, + n, + filter=filter_raw, + order_by=order_by_raw, + null_treatment=null_treatment.value, ) ) -def bit_and(arg: Expr, distinct: bool = False) -> Expr: - """Computes the bitwise AND of the argument.""" - return Expr(f.bit_and(arg.expr, distinct=distinct)) +def bit_and(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the bitwise AND of the argument. + + This aggregate function will bitwise compare every value in the input partition. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + Args: + expression: Argument to perform bitwise calculation on + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.bit_and(expression.expr, filter=filter_raw)) -def bit_or(arg: Expr, distinct: bool = False) -> Expr: - """Computes the bitwise OR of the argument.""" - return Expr(f.bit_or(arg.expr, distinct=distinct)) +def bit_or(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the bitwise OR of the argument. -def bit_xor(arg: Expr, distinct: bool = False) -> Expr: - """Computes the bitwise XOR of the argument.""" - return Expr(f.bit_xor(arg.expr, distinct=distinct)) + This aggregate function will bitwise compare every value in the input partition. + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + expression: Argument to perform bitwise calculation on + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.bit_or(expression.expr, filter=filter_raw)) -def bool_and(arg: Expr, distinct: bool = False) -> Expr: - """Computes the boolean AND of the argument.""" - return Expr(f.bool_and(arg.expr, distinct=distinct)) +def bit_xor( + expression: Expr, distinct: bool = False, filter: Optional[Expr] = None +) -> Expr: + """Computes the bitwise XOR of the argument. -def bool_or(arg: Expr, distinct: bool = False) -> Expr: - """Computes the boolean OR of the argument.""" - return Expr(f.bool_or(arg.expr, distinct=distinct)) + This aggregate function will bitwise compare every value in the input partition. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by`` and ``null_treatment``. + + Args: + expression: Argument to perform bitwise calculation on + distinct: If True, evaluate each unique value of expression only once + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.bit_xor(expression.expr, distinct=distinct, filter=filter_raw)) + + +def bool_and(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the boolean AND of the argument. + + This aggregate function will compare every value in the input partition. These are + expected to be boolean values. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + expression: Argument to perform calculation on + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.bool_and(expression.expr, filter=filter_raw)) + + +def bool_or(expression: Expr, filter: Optional[Expr] = None) -> Expr: + """Computes the boolean OR of the argument. + + This aggregate function will compare every value in the input partition. These are + expected to be boolean values. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + expression: Argument to perform calculation on + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.bool_or(expression.expr, filter=filter_raw)) def lead( @@ -2107,3 +2622,37 @@ def ntile( order_by=order_cols, ) ) + + +def string_agg( + expression: Expr, + delimiter: str, + filter: Optional[Expr] = None, + order_by: Optional[list[Expr]] = None, +) -> Expr: + """Concatenates the input strings. + + This aggregate function will concatenate input strings, ignoring null values, and + seperating them with the specified delimiter. Non-string values will be converted to + their string equivalents. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``distinct`` and ``null_treatment``. + + Args: + expression: Argument to perform bitwise calculation on + delimiter: Text to place between each value of expression + filter: If provided, only compute against rows for which the filter is True + order_by: Set the ordering of the expression to evaluate + """ + order_by_raw = expr_list_to_raw_expr_list(order_by) + filter_raw = filter.expr if filter is not None else None + + return Expr( + f.string_agg( + expression.expr, + delimiter, + filter=filter_raw, + order_by=order_by_raw, + ) + ) diff --git a/python/datafusion/tests/test_aggregation.py b/python/datafusion/tests/test_aggregation.py index ab653c40..243a8c3c 100644 --- a/python/datafusion/tests/test_aggregation.py +++ b/python/datafusion/tests/test_aggregation.py @@ -21,6 +21,7 @@ from datafusion import SessionContext, column, lit from datafusion import functions as f +from datafusion.common import NullTreatment @pytest.fixture @@ -34,12 +35,30 @@ def df(): pa.array([4, 4, 6]), pa.array([9, 8, 5]), pa.array([True, True, False]), + pa.array([1, 2, None]), ], - names=["a", "b", "c", "d"], + names=["a", "b", "c", "d", "e"], ) return ctx.create_dataframe([[batch]]) +@pytest.fixture +def df_partitioned(): + ctx = SessionContext() + + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [ + pa.array([0, 1, 2, 3, 4, 5, 6]), + pa.array([7, None, 7, 8, 9, None, 9]), + pa.array(["A", "A", "A", "A", "B", "B", "B"]), + ], + names=["a", "b", "c"], + ) + + return ctx.create_dataframe([[batch]]) + + @pytest.fixture def df_aggregate_100(): ctx = SessionContext() @@ -87,6 +106,7 @@ def df_aggregate_100(): ], ) def test_aggregation_stats(df, agg_expr, calc_expected): + df = df.select("a", "b", "c", "d") agg_df = df.aggregate([], [agg_expr]) result = agg_df.collect()[0] values_a, values_b, values_c, values_d = df.collect()[0] @@ -95,68 +115,323 @@ def test_aggregation_stats(df, agg_expr, calc_expected): @pytest.mark.parametrize( - "agg_expr, expected", + "agg_expr, expected, array_sort", [ - (f.approx_distinct(column("b")), pa.array([2], type=pa.uint64())), - (f.approx_median(column("b")), pa.array([4])), - (f.approx_percentile_cont(column("b"), lit(0.5)), pa.array([4])), + (f.approx_distinct(column("b")), pa.array([2], type=pa.uint64()), False), + ( + f.approx_distinct( + column("b"), + filter=column("a") != lit(3), + ), + pa.array([1], type=pa.uint64()), + False, + ), + (f.approx_median(column("b")), pa.array([4]), False), + (f.median(column("b"), distinct=True), pa.array([5]), False), + (f.median(column("b"), filter=column("a") != 2), pa.array([5]), False), + (f.approx_median(column("b"), filter=column("a") != 2), pa.array([5]), False), + (f.approx_percentile_cont(column("b"), 0.5), pa.array([4]), False), ( - f.approx_percentile_cont_with_weight(column("b"), lit(0.6), lit(0.5)), + f.approx_percentile_cont_with_weight(column("b"), lit(0.6), 0.5), pa.array([6], type=pa.float64()), + False, + ), + ( + f.approx_percentile_cont_with_weight( + column("b"), lit(0.6), 0.5, filter=column("a") != lit(3) + ), + pa.array([4], type=pa.float64()), + False, + ), + (f.array_agg(column("b")), pa.array([[4, 4, 6]]), False), + (f.array_agg(column("b"), distinct=True), pa.array([[4, 6]]), True), + ( + f.array_agg(column("e"), filter=column("e").is_not_null()), + pa.array([[1, 2]]), + False, + ), + ( + f.array_agg(column("b"), order_by=[column("c")]), + pa.array([[6, 4, 4]]), + False, + ), + (f.avg(column("b"), filter=column("a") != lit(1)), pa.array([5.0]), False), + (f.sum(column("b"), filter=column("a") != lit(1)), pa.array([10]), False), + (f.count(column("b"), distinct=True), pa.array([2]), False), + (f.count(column("b"), filter=column("a") != 3), pa.array([2]), False), + (f.count(), pa.array([3]), False), + (f.count(column("e")), pa.array([2]), False), + (f.count_star(filter=column("a") != 3), pa.array([2]), False), + (f.max(column("a"), filter=column("a") != lit(3)), pa.array([2]), False), + (f.min(column("a"), filter=column("a") != lit(1)), pa.array([2]), False), + ( + f.stddev(column("a"), filter=column("a") != lit(2)), + pa.array([np.sqrt(2)]), + False, + ), + ( + f.stddev_pop(column("a"), filter=column("a") != lit(2)), + pa.array([1.0]), + False, ), - (f.array_agg(column("b")), pa.array([[4, 4, 6]])), ], ) -def test_aggregation(df, agg_expr, expected): - agg_df = df.aggregate([], [agg_expr]) +def test_aggregation(df, agg_expr, expected, array_sort): + agg_df = df.aggregate([], [agg_expr.alias("agg_expr")]) + if array_sort: + agg_df = agg_df.select(f.array_sort(column("agg_expr"))) + agg_df.show() result = agg_df.collect()[0] + + print(result) assert result.column(0) == expected -def test_aggregate_100(df_aggregate_100): +@pytest.mark.parametrize( + "name,expr,expected", + [ + ( + "approx_percentile_cont", + f.approx_percentile_cont(column("c3"), 0.95, num_centroids=200), + [73, 68, 122, 124, 115], + ), + ( + "approx_perc_cont_few_centroids", + f.approx_percentile_cont(column("c3"), 0.95, num_centroids=5), + [72, 68, 119, 124, 115], + ), + ( + "approx_perc_cont_filtered", + f.approx_percentile_cont( + column("c3"), 0.95, num_centroids=200, filter=column("c3") > lit(0) + ), + [83, 68, 122, 124, 117], + ), + ( + "corr", + f.corr(column("c3"), column("c2")), + [-0.1056, -0.2808, 0.0023, 0.0022, -0.2473], + ), + ( + "corr_w_filter", + f.corr(column("c3"), column("c2"), filter=column("c3") > lit(0)), + [-0.3298, 0.2925, 0.2467, -0.2269, 0.0358], + ), + ( + "covar_pop", + f.covar_pop(column("c3"), column("c2")), + [-7.2857, -25.6731, 0.2222, 0.2469, -20.2857], + ), + ( + "covar_pop_w_filter", + f.covar_pop(column("c3"), column("c2"), filter=column("c3") > lit(0)), + [-9.25, 9.0579, 13.7521, -9.9669, 1.1641], + ), + ( + "covar_samp", + f.covar_samp(column("c3"), column("c2")), + [-7.65, -27.0994, 0.2333, 0.2614, -21.3], + ), + ( + "covar_samp_w_filter", + f.covar_samp(column("c3"), column("c2"), filter=column("c3") > lit(0)), + [-10.5714, 9.9636, 15.1273, -10.9636, 1.2417], + ), + ( + "var_samp", + f.var_samp(column("c2")), + [1.9286, 2.2047, 1.6333, 2.1438, 1.6], + ), + ( + "var_samp_w_filter", + f.var_samp(column("c2"), filter=column("c3") > lit(0)), + [1.4286, 2.4182, 1.8545, 1.4727, 1.6292], + ), + ( + "var_pop", + f.var_pop(column("c2")), + [1.8367, 2.0886, 1.5556, 2.0247, 1.5238], + ), + ( + "var_pop_w_filter", + f.var_pop(column("c2"), filter=column("c3") > lit(0)), + [1.25, 2.1983, 1.686, 1.3388, 1.5273], + ), + ], +) +def test_aggregate_100(df_aggregate_100, name, expr, expected): # https://github.com/apache/datafusion/blob/bddb6415a50746d2803dd908d19c3758952d74f9/datafusion/sqllogictest/test_files/aggregate.slt#L1490-L1498 - result = ( + df = ( df_aggregate_100.aggregate( [column("c1")], - [f.approx_percentile_cont(column("c3"), lit(0.95), lit(200)).alias("c3")], + [expr.alias(name)], ) + .select("c1", f.round(column(name), lit(4)).alias(name)) .sort(column("c1").sort(ascending=True)) - .collect() ) + df.show() - assert len(result) == 1 - result = result[0] - assert result.column("c1") == pa.array(["a", "b", "c", "d", "e"]) - assert result.column("c3") == pa.array([73, 68, 122, 124, 115]) + expected_dict = { + "c1": ["a", "b", "c", "d", "e"], + name: expected, + } + assert df.collect()[0].to_pydict() == expected_dict -def test_bit_add_or_xor(df): - df = df.aggregate( - [], - [ - f.bit_and(column("a")), - f.bit_or(column("b")), - f.bit_xor(column("c")), - ], - ) - result = df.collect() - result = result[0] - assert result.column(0) == pa.array([0]) - assert result.column(1) == pa.array([6]) - assert result.column(2) == pa.array([4]) +data_test_bitwise_and_boolean_functions = [ + ("bit_and", f.bit_and(column("a")), [0]), + ("bit_and_filter", f.bit_and(column("a"), filter=column("a") != lit(2)), [1]), + ("bit_or", f.bit_or(column("b")), [6]), + ("bit_or_filter", f.bit_or(column("b"), filter=column("a") != lit(3)), [4]), + ("bit_xor", f.bit_xor(column("c")), [4]), + ("bit_xor_distinct", f.bit_xor(column("b"), distinct=True), [2]), + ("bit_xor_filter", f.bit_xor(column("b"), filter=column("a") != lit(3)), [0]), + ( + "bit_xor_filter_distinct", + f.bit_xor(column("b"), distinct=True, filter=column("a") != lit(3)), + [4], + ), + ("bool_and", f.bool_and(column("d")), [False]), + ("bool_and_filter", f.bool_and(column("d"), filter=column("a") != lit(3)), [True]), + ("bool_or", f.bool_or(column("d")), [True]), + ("bool_or_filter", f.bool_or(column("d"), filter=column("a") == lit(3)), [False]), +] -def test_bool_and_or(df): - df = df.aggregate( - [], - [ - f.bool_and(column("d")), - f.bool_or(column("d")), - ], +@pytest.mark.parametrize("name,expr,result", data_test_bitwise_and_boolean_functions) +def test_bit_and_bool_fns(df, name, expr, result): + df = df.aggregate([], [expr.alias(name)]) + + expected = { + name: result, + } + + assert df.collect()[0].to_pydict() == expected + + +@pytest.mark.parametrize( + "name,expr,result", + [ + ("first_value", f.first_value(column("a")), [0, 4]), + ( + "first_value_ordered", + f.first_value(column("a"), order_by=[column("a").sort(ascending=False)]), + [3, 6], + ), + ( + "first_value_with_null", + f.first_value( + column("b"), + order_by=[column("b").sort(ascending=True)], + null_treatment=NullTreatment.RESPECT_NULLS, + ), + [None, None], + ), + ( + "first_value_ignore_null", + f.first_value( + column("b"), + order_by=[column("b").sort(ascending=True)], + null_treatment=NullTreatment.IGNORE_NULLS, + ), + [7, 9], + ), + ("last_value", f.last_value(column("a")), [3, 6]), + ( + "last_value_ordered", + f.last_value(column("a"), order_by=[column("a").sort(ascending=False)]), + [0, 4], + ), + ( + "last_value_with_null", + f.last_value( + column("b"), + order_by=[column("b").sort(ascending=True, nulls_first=False)], + null_treatment=NullTreatment.RESPECT_NULLS, + ), + [None, None], + ), + ( + "last_value_ignore_null", + f.last_value( + column("b"), + order_by=[column("b").sort(ascending=True)], + null_treatment=NullTreatment.IGNORE_NULLS, + ), + [8, 9], + ), + ("first_value", f.first_value(column("a")), [0, 4]), + ( + "nth_value_ordered", + f.nth_value(column("a"), 2, order_by=[column("a").sort(ascending=False)]), + [2, 5], + ), + ( + "nth_value_with_null", + f.nth_value( + column("b"), + 3, + order_by=[column("b").sort(ascending=True, nulls_first=False)], + null_treatment=NullTreatment.RESPECT_NULLS, + ), + [8, None], + ), + ( + "nth_value_ignore_null", + f.nth_value( + column("b"), + 2, + order_by=[column("b").sort(ascending=True)], + null_treatment=NullTreatment.IGNORE_NULLS, + ), + [7, 9], + ), + ], +) +def test_first_last_value(df_partitioned, name, expr, result) -> None: + df = df_partitioned.aggregate([column("c")], [expr.alias(name)]).sort(column("c")) + + expected = { + "c": ["A", "B"], + name: result, + } + + assert df.collect()[0].to_pydict() == expected + + +@pytest.mark.parametrize( + "name,expr,result", + [ + ("string_agg", f.string_agg(column("a"), ","), "one,two,three,two"), + ("string_agg", f.string_agg(column("b"), ""), "03124"), + ( + "string_agg", + f.string_agg(column("a"), ",", filter=column("b") != lit(3)), + "one,three,two", + ), + ( + "string_agg", + f.string_agg(column("a"), ",", order_by=[column("b")]), + "one,three,two,two", + ), + ], +) +def test_string_agg(name, expr, result) -> None: + ctx = SessionContext() + + df = ctx.from_pydict( + { + "a": ["one", "two", None, "three", "two"], + "b": [0, 3, 1, 2, 4], + } ) - result = df.collect() - result = result[0] - assert result.column(0) == pa.array([False]) - assert result.column(1) == pa.array([True]) + + df = df.aggregate([], [expr.alias(name)]) + + expected = { + name: [result], + } + df.show() + assert df.collect()[0].to_pydict() == expected diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index e7e6d79e..8e3c5139 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -912,17 +912,64 @@ def test_regr_funcs_sql_2(): @pytest.mark.parametrize( "func, expected", [ - pytest.param(f.regr_slope, pa.array([2], type=pa.float64()), id="regr_slope"), + pytest.param(f.regr_slope(column("c2"), column("c1")), [4.6], id="regr_slope"), pytest.param( - f.regr_intercept, pa.array([0], type=pa.float64()), id="regr_intercept" + f.regr_slope(column("c2"), column("c1"), filter=column("c1") > literal(2)), + [8], + id="regr_slope_filter", + ), + pytest.param( + f.regr_intercept(column("c2"), column("c1")), [-4], id="regr_intercept" + ), + pytest.param( + f.regr_intercept( + column("c2"), column("c1"), filter=column("c1") > literal(2) + ), + [-16], + id="regr_intercept_filter", + ), + pytest.param(f.regr_count(column("c2"), column("c1")), [4], id="regr_count"), + pytest.param( + f.regr_count(column("c2"), column("c1"), filter=column("c1") > literal(2)), + [2], + id="regr_count_filter", + ), + pytest.param(f.regr_r2(column("c2"), column("c1")), [0.92], id="regr_r2"), + pytest.param( + f.regr_r2(column("c2"), column("c1"), filter=column("c1") > literal(2)), + [1.0], + id="regr_r2_filter", + ), + pytest.param(f.regr_avgx(column("c2"), column("c1")), [2.5], id="regr_avgx"), + pytest.param( + f.regr_avgx(column("c2"), column("c1"), filter=column("c1") > literal(2)), + [3.5], + id="regr_avgx_filter", + ), + pytest.param(f.regr_avgy(column("c2"), column("c1")), [7.5], id="regr_avgy"), + pytest.param( + f.regr_avgy(column("c2"), column("c1"), filter=column("c1") > literal(2)), + [12.0], + id="regr_avgy_filter", + ), + pytest.param(f.regr_sxx(column("c2"), column("c1")), [5.0], id="regr_sxx"), + pytest.param( + f.regr_sxx(column("c2"), column("c1"), filter=column("c1") > literal(2)), + [0.5], + id="regr_sxx_filter", + ), + pytest.param(f.regr_syy(column("c2"), column("c1")), [115.0], id="regr_syy"), + pytest.param( + f.regr_syy(column("c2"), column("c1"), filter=column("c1") > literal(2)), + [32.0], + id="regr_syy_filter", + ), + pytest.param(f.regr_sxy(column("c2"), column("c1")), [23.0], id="regr_sxy"), + pytest.param( + f.regr_sxy(column("c2"), column("c1"), filter=column("c1") > literal(2)), + [4.0], + id="regr_sxy_filter", ), - pytest.param(f.regr_count, pa.array([3], type=pa.uint64()), id="regr_count"), - pytest.param(f.regr_r2, pa.array([1], type=pa.float64()), id="regr_r2"), - pytest.param(f.regr_avgx, pa.array([2], type=pa.float64()), id="regr_avgx"), - pytest.param(f.regr_avgy, pa.array([4], type=pa.float64()), id="regr_avgy"), - pytest.param(f.regr_sxx, pa.array([2], type=pa.float64()), id="regr_sxx"), - pytest.param(f.regr_syy, pa.array([8], type=pa.float64()), id="regr_syy"), - pytest.param(f.regr_sxy, pa.array([4], type=pa.float64()), id="regr_sxy"), ], ) def test_regr_funcs_df(func, expected): @@ -932,38 +979,18 @@ def test_regr_funcs_df(func, expected): ctx = SessionContext() # Create a DataFrame - data = {"column1": [1, 2, 3], "column2": [2, 4, 6]} + data = {"c1": [1, 2, 3, 4, 5, None], "c2": [2, 4, 8, 16, None, 64]} df = ctx.from_pydict(data, name="test_table") # Perform the regression function using DataFrame API - result_df = df.aggregate([], [func(f.col("column2"), f.col("column1"))]).collect() - - # Assertion for DataFrame API result - assert result_df[0].column(0) == expected - + df = df.aggregate([], [func.alias("result")]) + df.show() -def test_first_last_value(df): - df = df.aggregate( - [], - [ - f.first_value(column("a")), - f.first_value(column("b")), - f.first_value(column("d")), - f.last_value(column("a")), - f.last_value(column("b")), - f.last_value(column("d")), - ], - ) + expected_dict = { + "result": expected, + } - result = df.collect() - result = result[0] - assert result.column(0) == pa.array(["Hello"]) - assert result.column(1) == pa.array([4]) - assert result.column(2) == pa.array([datetime(2022, 12, 31)]) - assert result.column(3) == pa.array(["!"]) - assert result.column(4) == pa.array([6]) - assert result.column(5) == pa.array([datetime(2020, 7, 2)]) - df.show() + assert df.collect()[0].to_pydict() == expected_dict def test_binary_string_functions(df): diff --git a/python/datafusion/tests/test_wrapper_coverage.py b/python/datafusion/tests/test_wrapper_coverage.py index 44b9ca83..4a47de2e 100644 --- a/python/datafusion/tests/test_wrapper_coverage.py +++ b/python/datafusion/tests/test_wrapper_coverage.py @@ -20,8 +20,19 @@ import datafusion.object_store import datafusion.substrait +# EnumType introduced in 3.11. 3.10 and prior it was called EnumMeta. +try: + from enum import EnumType +except ImportError: + from enum import EnumMeta as EnumType + def missing_exports(internal_obj, wrapped_obj) -> None: + # Special case enums - just make sure they exist since dir() + # and other functions get overridden. + if isinstance(wrapped_obj, EnumType): + return + for attr in dir(internal_obj): assert attr in dir(wrapped_obj) diff --git a/src/functions.rs b/src/functions.rs index b5b003df..b9ca6301 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -35,305 +35,27 @@ use datafusion::functions_aggregate; use datafusion::logical_expr::expr::Alias; use datafusion::logical_expr::sqlparser::ast::NullTreatment as DFNullTreatment; use datafusion::logical_expr::{ - expr::{find_df_window_func, AggregateFunction, Sort, WindowFunction}, + expr::{find_df_window_func, Sort, WindowFunction}, lit, Expr, WindowFunctionDefinition, }; -#[pyfunction] -pub fn approx_distinct(expression: PyExpr) -> PyExpr { - functions_aggregate::expr_fn::approx_distinct(expression.expr).into() -} - -#[pyfunction] -pub fn approx_median(expression: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::approx_median(expression.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn approx_percentile_cont( - expression: PyExpr, - percentile: PyExpr, - distinct: bool, - num_centroids: Option, // enforces optional arguments at the end, currently -) -> PyResult { - let args = if let Some(num_centroids) = num_centroids { - vec![expression.expr, percentile.expr, num_centroids.expr] - } else { - vec![expression.expr, percentile.expr] - }; - let udaf = functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf(); - let expr = udaf.call(args); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn approx_percentile_cont_with_weight( - expression: PyExpr, - weight: PyExpr, - percentile: PyExpr, - distinct: bool, -) -> PyResult { - let expr = functions_aggregate::expr_fn::approx_percentile_cont_with_weight( - expression.expr, - weight.expr, - percentile.expr, - ); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn avg(expression: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::avg(expression.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn bit_and(expr_x: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::bit_and(expr_x.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn bit_or(expression: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::bit_or(expression.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn bit_xor(expression: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::bit_xor(expression.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn bool_and(expression: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::bool_and(expression.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn bool_or(expression: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::bool_or(expression.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn corr(y: PyExpr, x: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::corr(y.expr, x.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn grouping(expression: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::grouping(expression.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn sum(args: PyExpr) -> PyExpr { - functions_aggregate::expr_fn::sum(args.expr).into() -} - -#[pyfunction] -pub fn covar_samp(y: PyExpr, x: PyExpr) -> PyExpr { - functions_aggregate::expr_fn::covar_samp(y.expr, x.expr).into() -} - -#[pyfunction] -pub fn covar_pop(y: PyExpr, x: PyExpr) -> PyExpr { - functions_aggregate::expr_fn::covar_pop(y.expr, x.expr).into() -} - -#[pyfunction] -pub fn median(arg: PyExpr) -> PyExpr { - functions_aggregate::expr_fn::median(arg.expr).into() -} - -#[pyfunction] -pub fn stddev(expression: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::stddev(expression.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn stddev_pop(expression: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::stddev_pop(expression.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn var_samp(expression: PyExpr) -> PyExpr { - functions_aggregate::expr_fn::var_sample(expression.expr).into() -} - -#[pyfunction] -pub fn var_pop(expression: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::var_pop(expression.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn regr_avgx(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::regr_avgx(expr_y.expr, expr_x.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn regr_avgy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::regr_avgy(expr_y.expr, expr_x.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn regr_count(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::regr_count(expr_y.expr, expr_x.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn regr_intercept(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::regr_intercept(expr_y.expr, expr_x.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn regr_r2(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::regr_r2(expr_y.expr, expr_x.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn regr_slope(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::regr_slope(expr_y.expr, expr_x.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn regr_sxx(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::regr_sxx(expr_y.expr, expr_x.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn regr_sxy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::regr_sxy(expr_y.expr, expr_x.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - -#[pyfunction] -pub fn regr_syy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::regr_syy(expr_y.expr, expr_x.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - fn add_builder_fns_to_aggregate( agg_fn: Expr, - distinct: bool, + distinct: Option, filter: Option, order_by: Option>, null_treatment: Option, ) -> PyResult { // Since ExprFuncBuilder::new() is private, we can guarantee initializing - // a builder with an `order_by` default of empty vec - let order_by = order_by - .map(|x| x.into_iter().map(|x| x.expr).collect::>()) - .unwrap_or_default(); - let mut builder = agg_fn.order_by(order_by); + // a builder with an `null_treatment` with option None + let mut builder = agg_fn.null_treatment(None); - if distinct { + if let Some(order_by_cols) = order_by { + let order_by_cols = to_sort_expressions(order_by_cols); + builder = builder.order_by(order_by_cols); + } + + if let Some(true) = distinct { builder = builder.distinct(); } @@ -341,39 +63,11 @@ fn add_builder_fns_to_aggregate( builder = builder.filter(filter.expr); } - // would be nice if all the options builder methods accepted Option ... builder = builder.null_treatment(null_treatment.map(DFNullTreatment::from)); Ok(builder.build()?.into()) } -#[pyfunction] -pub fn first_value( - expr: PyExpr, - distinct: bool, - filter: Option, - order_by: Option>, - null_treatment: Option, -) -> PyResult { - // If we initialize the UDAF with order_by directly, then it gets over-written by the builder - let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None); - - add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) -} - -#[pyfunction] -pub fn last_value( - expr: PyExpr, - distinct: bool, - filter: Option, - order_by: Option>, - null_treatment: Option, -) -> PyResult { - let agg_fn = functions_aggregate::expr_fn::last_value(vec![expr.expr]); - - add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) -} - #[pyfunction] fn in_list(expr: PyExpr, value: Vec, negated: bool) -> PyExpr { datafusion::logical_expr::in_list( @@ -505,25 +199,6 @@ fn col(name: &str) -> PyResult { }) } -// TODO: should we just expose this in python? -/// Create a COUNT(1) aggregate expression -#[pyfunction] -fn count_star() -> PyExpr { - functions_aggregate::expr_fn::count(lit(1)).into() -} - -/// Wrapper for [`functions_aggregate::expr_fn::count`] -/// Count the number of non-null values in the column -#[pyfunction] -fn count(expr: PyExpr, distinct: bool) -> PyResult { - let expr = functions_aggregate::expr_fn::count(expr.expr); - if distinct { - Ok(expr.distinct().build()?.into()) - } else { - Ok(expr.into()) - } -} - /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. #[pyfunction] fn case(expr: PyExpr) -> PyResult { @@ -646,24 +321,46 @@ fn window( }) } +// Generates a [pyo3] wrapper for associated aggregate functions. +// All of the builder options are exposed to the python internal +// function and we rely on the wrappers to only use those that +// are appropriate. macro_rules! aggregate_function { - ($NAME: ident, $FUNC: path) => { - aggregate_function!($NAME, $FUNC, stringify!($NAME)); + ($NAME: ident) => { + aggregate_function!($NAME, expr); }; - ($NAME: ident, $FUNC: path, $DOC: expr) => { - #[doc = $DOC] + ($NAME: ident, $($arg:ident)*) => { #[pyfunction] - #[pyo3(signature = (*args, distinct=false))] - fn $NAME(args: Vec, distinct: bool) -> PyExpr { - let expr = datafusion::logical_expr::Expr::AggregateFunction(AggregateFunction { - func: $FUNC(), - args: args.into_iter().map(|e| e.into()).collect(), - distinct, - filter: None, - order_by: None, - null_treatment: None, - }); - expr.into() + fn $NAME( + $($arg: PyExpr),*, + distinct: Option, + filter: Option, + order_by: Option>, + null_treatment: Option + ) -> PyResult { + let agg_fn = functions_aggregate::expr_fn::$NAME($($arg.into()),*); + + add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) + } + }; +} + +macro_rules! aggregate_function_vec_args { + ($NAME: ident) => { + aggregate_function_vec_args!($NAME, expr); + }; + ($NAME: ident, $($arg:ident)*) => { + #[pyfunction] + fn $NAME( + $($arg: PyExpr),*, + distinct: Option, + filter: Option, + order_by: Option>, + null_treatment: Option + ) -> PyResult { + let agg_fn = functions_aggregate::expr_fn::$NAME(vec![$($arg.into()),*]); + + add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) } }; } @@ -891,9 +588,120 @@ array_fn!(array_resize, array size value); array_fn!(flatten, array); array_fn!(range, start stop step); -aggregate_function!(array_agg, functions_aggregate::array_agg::array_agg_udaf); -aggregate_function!(max, functions_aggregate::min_max::max_udaf); -aggregate_function!(min, functions_aggregate::min_max::min_udaf); +aggregate_function!(array_agg); +aggregate_function!(max); +aggregate_function!(min); +aggregate_function!(avg); +aggregate_function!(sum); +aggregate_function!(bit_and); +aggregate_function!(bit_or); +aggregate_function!(bit_xor); +aggregate_function!(bool_and); +aggregate_function!(bool_or); +aggregate_function!(corr, y x); +aggregate_function!(count); +aggregate_function!(covar_samp, y x); +aggregate_function!(covar_pop, y x); +aggregate_function!(median); +aggregate_function!(regr_slope, y x); +aggregate_function!(regr_intercept, y x); +aggregate_function!(regr_count, y x); +aggregate_function!(regr_r2, y x); +aggregate_function!(regr_avgx, y x); +aggregate_function!(regr_avgy, y x); +aggregate_function!(regr_sxx, y x); +aggregate_function!(regr_syy, y x); +aggregate_function!(regr_sxy, y x); +aggregate_function!(stddev); +aggregate_function!(stddev_pop); +aggregate_function!(var_sample); +aggregate_function!(var_pop); +aggregate_function!(approx_distinct); +aggregate_function!(approx_median); + +// Code is commented out since grouping is not yet implemented +// https://github.com/apache/datafusion-python/issues/861 +// aggregate_function!(grouping); + +#[pyfunction] +pub fn approx_percentile_cont( + expression: PyExpr, + percentile: f64, + num_centroids: Option, // enforces optional arguments at the end, currently + filter: Option, +) -> PyResult { + let args = if let Some(num_centroids) = num_centroids { + vec![expression.expr, lit(percentile), lit(num_centroids)] + } else { + vec![expression.expr, lit(percentile)] + }; + let udaf = functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf(); + let agg_fn = udaf.call(args); + + add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) +} + +#[pyfunction] +pub fn approx_percentile_cont_with_weight( + expression: PyExpr, + weight: PyExpr, + percentile: f64, + filter: Option, +) -> PyResult { + let agg_fn = functions_aggregate::expr_fn::approx_percentile_cont_with_weight( + expression.expr, + weight.expr, + lit(percentile), + ); + + add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) +} + +aggregate_function_vec_args!(last_value); + +// We handle first_value explicitly because the signature expects an order_by +// https://github.com/apache/datafusion/issues/12376 +#[pyfunction] +pub fn first_value( + expr: PyExpr, + distinct: Option, + filter: Option, + order_by: Option>, + null_treatment: Option, +) -> PyResult { + // If we initialize the UDAF with order_by directly, then it gets over-written by the builder + let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None); + + add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) +} + +// nth_value requires a non-expr argument +#[pyfunction] +pub fn nth_value( + expr: PyExpr, + n: i64, + distinct: Option, + filter: Option, + order_by: Option>, + null_treatment: Option, +) -> PyResult { + let agg_fn = datafusion::functions_aggregate::nth_value::nth_value(vec![expr.expr, lit(n)]); + add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) +} + +// string_agg requires a non-expr argument +#[pyfunction] +pub fn string_agg( + expr: PyExpr, + delimiter: String, + distinct: Option, + filter: Option, + order_by: Option>, + null_treatment: Option, +) -> PyResult { + let agg_fn = datafusion::functions_aggregate::string_agg::string_agg(expr.expr, lit(delimiter)); + add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) +} fn add_builder_fns_to_window( window_fn: Expr, @@ -1042,7 +850,6 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(cosh))?; m.add_wrapped(wrap_pyfunction!(cot))?; m.add_wrapped(wrap_pyfunction!(count))?; - m.add_wrapped(wrap_pyfunction!(count_star))?; m.add_wrapped(wrap_pyfunction!(covar_pop))?; m.add_wrapped(wrap_pyfunction!(covar_samp))?; m.add_wrapped(wrap_pyfunction!(current_date))?; @@ -1059,7 +866,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(floor))?; m.add_wrapped(wrap_pyfunction!(from_unixtime))?; m.add_wrapped(wrap_pyfunction!(gcd))?; - m.add_wrapped(wrap_pyfunction!(grouping))?; + // m.add_wrapped(wrap_pyfunction!(grouping))?; m.add_wrapped(wrap_pyfunction!(in_list))?; m.add_wrapped(wrap_pyfunction!(initcap))?; m.add_wrapped(wrap_pyfunction!(isnan))?; @@ -1113,6 +920,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(starts_with))?; m.add_wrapped(wrap_pyfunction!(stddev))?; m.add_wrapped(wrap_pyfunction!(stddev_pop))?; + m.add_wrapped(wrap_pyfunction!(string_agg))?; m.add_wrapped(wrap_pyfunction!(strpos))?; m.add_wrapped(wrap_pyfunction!(r#struct))?; // Use raw identifier since struct is a keyword m.add_wrapped(wrap_pyfunction!(substr))?; @@ -1134,7 +942,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(upper))?; m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision m.add_wrapped(wrap_pyfunction!(var_pop))?; - m.add_wrapped(wrap_pyfunction!(var_samp))?; + m.add_wrapped(wrap_pyfunction!(var_sample))?; m.add_wrapped(wrap_pyfunction!(window))?; m.add_wrapped(wrap_pyfunction!(regr_avgx))?; m.add_wrapped(wrap_pyfunction!(regr_avgy))?; @@ -1147,6 +955,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(regr_syy))?; m.add_wrapped(wrap_pyfunction!(first_value))?; m.add_wrapped(wrap_pyfunction!(last_value))?; + m.add_wrapped(wrap_pyfunction!(nth_value))?; m.add_wrapped(wrap_pyfunction!(bit_and))?; m.add_wrapped(wrap_pyfunction!(bit_or))?; m.add_wrapped(wrap_pyfunction!(bit_xor))?;