Skip to content

Commit

Permalink
add regr_* functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangzhx committed Oct 7, 2023
1 parent 9ef0a57 commit c94bc6a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
22 changes: 22 additions & 0 deletions datafusion/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,28 @@ def test_case(df):
assert result.column(2) == pa.array(["Hola", "Mundo", None])


def test_regr_funcs(df):
# test case base on
# https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330
ctx = SessionContext()
result = ctx.sql(
"select regr_slope(1,1), regr_intercept(1,1), "
"regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), "
"regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), "
"regr_sxy(1,1);"
).collect()

assert result[0].column(0) == pa.array([None], type=pa.float64())
assert result[0].column(1) == pa.array([None], type=pa.float64())
assert result[0].column(2) == pa.array([1], type=pa.float64())
assert result[0].column(3) == pa.array([None], type=pa.float64())
assert result[0].column(4) == pa.array([1], type=pa.float64())
assert result[0].column(5) == pa.array([1], type=pa.float64())
assert result[0].column(6) == pa.array([0], type=pa.float64())
assert result[0].column(7) == pa.array([0], type=pa.float64())
assert result[0].column(8) == pa.array([0], type=pa.float64())


def test_binary_string_functions(df):
df = df.select(
f.encode(column("a"), literal("base64")),
Expand Down
18 changes: 18 additions & 0 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,15 @@ aggregate_function!(stddev_samp, Stddev);
aggregate_function!(var, Variance);
aggregate_function!(var_pop, VariancePop);
aggregate_function!(var_samp, Variance);
aggregate_function!(regr_avgx, RegrAvgx);
aggregate_function!(regr_avgy, RegrAvgy);
aggregate_function!(regr_count, RegrCount);
aggregate_function!(regr_intercept, RegrIntercept);
aggregate_function!(regr_r2, RegrR2);
aggregate_function!(regr_slope, RegrSlope);
aggregate_function!(regr_sxx, RegrSXX);
aggregate_function!(regr_sxy, RegrSXY);
aggregate_function!(regr_syy, RegrSYY);

pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(abs))?;
Expand Down Expand Up @@ -489,6 +498,15 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(var_pop))?;
m.add_wrapped(wrap_pyfunction!(var_samp))?;
m.add_wrapped(wrap_pyfunction!(window))?;
m.add_wrapped(wrap_pyfunction!(regr_avgx))?;
m.add_wrapped(wrap_pyfunction!(regr_avgy))?;
m.add_wrapped(wrap_pyfunction!(regr_count))?;
m.add_wrapped(wrap_pyfunction!(regr_intercept))?;
m.add_wrapped(wrap_pyfunction!(regr_r2))?;
m.add_wrapped(wrap_pyfunction!(regr_slope))?;
m.add_wrapped(wrap_pyfunction!(regr_sxx))?;
m.add_wrapped(wrap_pyfunction!(regr_sxy))?;
m.add_wrapped(wrap_pyfunction!(regr_syy))?;

//Binary String Functions
m.add_wrapped(wrap_pyfunction!(encode))?;
Expand Down

0 comments on commit c94bc6a

Please sign in to comment.