Skip to content

Commit

Permalink
feat: expose PyWindowFrame (#509)
Browse files Browse the repository at this point in the history
* feat: expose PyWindowFrame

* fix: PyWindowFrame: return Err instead of panicking

* test: test PyWindowFrame creation
  • Loading branch information
dlovell authored Oct 17, 2023
1 parent 5ec45dd commit 399fa75
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 5 deletions.
2 changes: 2 additions & 0 deletions datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
SessionConfig,
RuntimeConfig,
ScalarUDF,
WindowFrame,
)

from .common import (
Expand Down Expand Up @@ -98,6 +99,7 @@
"Expr",
"AggregateUDF",
"ScalarUDF",
"WindowFrame",
"column",
"literal",
"TableScan",
Expand Down
41 changes: 40 additions & 1 deletion datafusion/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@
import pytest

from datafusion import functions as f
from datafusion import DataFrame, SessionContext, column, literal, udf
from datafusion import (
DataFrame,
SessionContext,
WindowFrame,
column,
literal,
udf,
)


@pytest.fixture
Expand Down Expand Up @@ -304,6 +311,38 @@ def test_window_functions(df):
assert table.sort_by("a").to_pydict() == expected


@pytest.mark.parametrize(
    ("units", "start_bound", "end_bound"),
    [
        # every combination of missing / zero / non-zero bounds for rows and range
        *(
            (frame_units, preceding, following)
            for frame_units in ("rows", "range")
            for preceding in (None, 0, 1)
            for following in (None, 0, 1)
        ),
        # groups is only exercised with both bounds given
        ("groups", 0, 0),
    ],
)
def test_valid_window_frame(units, start_bound, end_bound):
    """Constructing a WindowFrame from a supported combination must not raise."""
    WindowFrame(units, start_bound, end_bound)


@pytest.mark.parametrize(
    ("units", "start_bound", "end_bound"),
    [
        # an unknown units name, paired with each missing-bound combination
        *(
            ("invalid-units", preceding, following)
            for preceding, following in ((0, None), (None, 0), (None, None))
        ),
        # groups frames with at least one bound missing
        *(
            ("groups", preceding, following)
            for preceding, following in ((None, 0), (0, None), (None, None))
        ),
    ],
)
def test_invalid_window_frame(units, start_bound, end_bound):
    """Constructing a WindowFrame from an unsupported combination raises RuntimeError."""
    with pytest.raises(RuntimeError):
        WindowFrame(units, start_bound, end_bound)


def test_get_dataframe(tmp_path):
ctx = SessionContext()

Expand Down
20 changes: 17 additions & 3 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@

use pyo3::{prelude::*, wrap_pyfunction};

use crate::context::PySessionContext;
use crate::errors::DataFusionError;
use crate::expr::conditional_expr::PyCaseBuilder;
use crate::expr::PyExpr;
use crate::window_frame::PyWindowFrame;
use datafusion::execution::FunctionRegistry;
use datafusion_common::Column;
use datafusion_expr::expr::Alias;
use datafusion_expr::{
aggregate_function,
expr::{AggregateFunction, ScalarFunction, Sort, WindowFunction},
lit,
window_function::find_df_window_func,
BuiltinScalarFunction, Expr, WindowFrame,
BuiltinScalarFunction, Expr,
};

#[pyfunction]
Expand Down Expand Up @@ -130,13 +133,24 @@ fn window(
args: Vec<PyExpr>,
partition_by: Option<Vec<PyExpr>>,
order_by: Option<Vec<PyExpr>>,
window_frame: Option<PyWindowFrame>,
ctx: Option<PySessionContext>,
) -> PyResult<PyExpr> {
let fun = find_df_window_func(name);
let fun = find_df_window_func(name).or_else(|| {
ctx.and_then(|ctx| {
ctx.ctx
.udaf(name)
.map(|fun| datafusion_expr::WindowFunction::AggregateUDF(fun))
.ok()
})
});
if fun.is_none() {
return Err(DataFusionError::Common("window function not found".to_string()).into());
}
let fun = fun.unwrap();
let window_frame = WindowFrame::new(order_by.is_some());
let window_frame = window_frame
.unwrap_or_else(|| PyWindowFrame::new("rows", None, Some(0)).unwrap())
.into();
Ok(PyExpr {
expr: datafusion_expr::Expr::WindowFunction(WindowFunction {
fun,
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ mod udaf;
#[allow(clippy::borrow_deref_ref)]
mod udf;
pub mod utils;
mod window_frame;

#[cfg(feature = "mimalloc")]
#[global_allocator]
Expand Down Expand Up @@ -83,6 +84,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<context::PySessionContext>()?;
m.add_class::<dataframe::PyDataFrame>()?;
m.add_class::<udf::PyScalarUDF>()?;
m.add_class::<window_frame::PyWindowFrame>()?;
m.add_class::<udaf::PyAggregateUDF>()?;
m.add_class::<config::PyConfig>()?;
m.add_class::<sql::logical::PyLogicalPlan>()?;
Expand Down
30 changes: 29 additions & 1 deletion src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use std::sync::Arc;

use pyo3::{prelude::*, types::PyTuple};
use pyo3::{prelude::*, types::PyBool, types::PyTuple};

use datafusion::arrow::array::{Array, ArrayRef};
use datafusion::arrow::datatypes::DataType;
Expand Down Expand Up @@ -93,6 +93,34 @@ impl Accumulator for RustAccumulator {
// Reports `size_of_val(self)`, i.e. the shallow in-memory size of this wrapper value.
// NOTE(review): presumably this does not account for memory held by the Python
// accumulator object behind `self.accum` — confirm if accurate accounting matters.
fn size(&self) -> usize {
std::mem::size_of_val(self)
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
Python::with_gil(|py| {
// 1. cast args to Pyarrow array
let py_args = values
.iter()
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
let py_args = PyTuple::new(py, py_args);

// 2. call function
self.accum
.as_ref(py)
.call_method1("retract_batch", py_args)
.map_err(|e| DataFusionError::Execution(format!("{e}")))?;

Ok(())
})
}

/// Asks the Python accumulator whether it is able to retract batches.
///
/// Calls the accumulator's `supports_retract_batch` method with no arguments.
/// Any failure (the call raising, or a result that does not extract as a
/// `bool`) is treated as `false`.
fn supports_retract_batch(&self) -> bool {
    Python::with_gil(|py| {
        // Stand-in answer used when the Python call itself fails.
        let fallback: &PyAny = PyBool::new(py, false);
        self.accum
            .as_ref(py)
            .call_method0("supports_retract_batch")
            .unwrap_or(fallback)
            .extract()
            .unwrap_or(false)
    })
}
}

pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {
Expand Down
110 changes: 110 additions & 0 deletions src/window_frame.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use pyo3::prelude::*;
use std::fmt::{Display, Formatter};

use crate::errors::py_datafusion_err;

/// Python-facing wrapper around a DataFusion [`WindowFrame`].
///
/// Registered with Python under the name `WindowFrame` in the `datafusion`
/// module and declared subclassable.
#[pyclass(name = "WindowFrame", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyWindowFrame {
// The wrapped DataFusion frame (units plus start and end bounds).
frame: WindowFrame,
}

impl From<PyWindowFrame> for WindowFrame {
fn from(frame: PyWindowFrame) -> Self {
frame.frame
}
}

impl From<WindowFrame> for PyWindowFrame {
fn from(frame: WindowFrame) -> PyWindowFrame {
PyWindowFrame { frame }
}
}

impl Display for PyWindowFrame {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
write!(
f,
"OVER ({} BETWEEN {} AND {})",
self.frame.units, self.frame.start_bound, self.frame.end_bound
)
}
}

#[pymethods]
impl PyWindowFrame {
#[new(unit, start_bound, end_bound)]
pub fn new(units: &str, start_bound: Option<u64>, end_bound: Option<u64>) -> PyResult<Self> {
let units = units.to_ascii_lowercase();
let units = match units.as_str() {
"rows" => WindowFrameUnits::Rows,
"range" => WindowFrameUnits::Range,
"groups" => WindowFrameUnits::Groups,
_ => {
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
"{:?}",
units,
))));
}
};
let start_bound = match start_bound {
Some(start_bound) => {
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(start_bound)))
}
None => match units {
WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Groups => {
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
"{:?}",
units,
))));
}
},
};
let end_bound = match end_bound {
Some(end_bound) => WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound))),
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Groups => {
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
"{:?}",
units,
))));
}
},
};
Ok(PyWindowFrame {
frame: WindowFrame {
units,
start_bound,
end_bound,
},
})
}

/// Get a String representation of this window frame
fn __repr__(&self) -> String {
format!("{}", self)
}
}

0 comments on commit 399fa75

Please sign in to comment.