diff --git a/datafusion/tests/test_functions.py b/datafusion/tests/test_functions.py index e504cc498..be2a2f1f5 100644 --- a/datafusion/tests/test_functions.py +++ b/datafusion/tests/test_functions.py @@ -479,6 +479,28 @@ def test_case(df): assert result.column(2) == pa.array(["Hola", "Mundo", None]) +def test_regr_funcs(df): + # test case base on + # https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330 + ctx = SessionContext() + result = ctx.sql( + "select regr_slope(1,1), regr_intercept(1,1), " + "regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), " + "regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), " + "regr_sxy(1,1);" + ).collect() + + assert result[0].column(0) == pa.array([None], type=pa.float64()) + assert result[0].column(1) == pa.array([None], type=pa.float64()) + assert result[0].column(2) == pa.array([1], type=pa.float64()) + assert result[0].column(3) == pa.array([None], type=pa.float64()) + assert result[0].column(4) == pa.array([1], type=pa.float64()) + assert result[0].column(5) == pa.array([1], type=pa.float64()) + assert result[0].column(6) == pa.array([0], type=pa.float64()) + assert result[0].column(7) == pa.array([0], type=pa.float64()) + assert result[0].column(8) == pa.array([0], type=pa.float64()) + + def test_first_last_value(df): df = df.aggregate( [], diff --git a/src/functions.rs b/src/functions.rs index 2f2f34ee0..e509aff71 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -362,6 +362,15 @@ aggregate_function!(stddev_samp, Stddev); aggregate_function!(var, Variance); aggregate_function!(var_pop, VariancePop); aggregate_function!(var_samp, Variance); +aggregate_function!(regr_avgx, RegrAvgx); +aggregate_function!(regr_avgy, RegrAvgy); +aggregate_function!(regr_count, RegrCount); +aggregate_function!(regr_intercept, RegrIntercept); +aggregate_function!(regr_r2, RegrR2); +aggregate_function!(regr_slope, RegrSlope); +aggregate_function!(regr_sxx, RegrSXX); +aggregate_function!(regr_sxy, RegrSXY); +aggregate_function!(regr_syy, RegrSYY); aggregate_function!(first_value, FirstValue); aggregate_function!(last_value, LastValue); aggregate_function!(bit_and, BitAnd); @@ -496,6 +505,15 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(var_pop))?; m.add_wrapped(wrap_pyfunction!(var_samp))?; m.add_wrapped(wrap_pyfunction!(window))?; + m.add_wrapped(wrap_pyfunction!(regr_avgx))?; + m.add_wrapped(wrap_pyfunction!(regr_avgy))?; + m.add_wrapped(wrap_pyfunction!(regr_count))?; + m.add_wrapped(wrap_pyfunction!(regr_intercept))?; + m.add_wrapped(wrap_pyfunction!(regr_r2))?; + m.add_wrapped(wrap_pyfunction!(regr_slope))?; + m.add_wrapped(wrap_pyfunction!(regr_sxx))?; + m.add_wrapped(wrap_pyfunction!(regr_sxy))?; + m.add_wrapped(wrap_pyfunction!(regr_syy))?; m.add_wrapped(wrap_pyfunction!(first_value))?; m.add_wrapped(wrap_pyfunction!(last_value))?; m.add_wrapped(wrap_pyfunction!(bit_and))?;