-
Notifications
You must be signed in to change notification settings - Fork 113
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add PyArrayLike wrapper around PyReadonlyArray
Extracts a read-only reference if the correct NumPy array type is given. Tries to convert the input into the correct type using `numpy.asarray` otherwise.
- Loading branch information
1 parent
b5af0ed
commit 1671b00
Showing
4 changed files
with
340 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
use std::marker::PhantomData; | ||
use std::ops::Deref; | ||
|
||
use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; | ||
use pyo3::{intern, sync::GILOnceCell, types::PyDict, FromPyObject, Py, PyAny, PyResult}; | ||
|
||
use crate::sealed::Sealed; | ||
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray}; | ||
|
||
pub trait Coerce: Sealed { | ||
const VAL: bool; | ||
} | ||
|
||
/// Marker type to indicate that the element type received via [`PyArrayLike`] must match the specified type exactly. | ||
#[derive(Debug)] | ||
pub struct TypeMustMatch; | ||
|
||
impl Sealed for TypeMustMatch {} | ||
|
||
impl Coerce for TypeMustMatch { | ||
const VAL: bool = false; | ||
} | ||
|
||
/// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html). | ||
#[derive(Debug)] | ||
pub struct AllowTypeChange; | ||
|
||
impl Sealed for AllowTypeChange {} | ||
|
||
impl Coerce for AllowTypeChange { | ||
const VAL: bool = true; | ||
} | ||
|
||
/// Receiver for arrays or array-like types. | ||
/// | ||
/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments. | ||
/// `PyArrayLike` enables the same pattern in Rust extensions, i.e. by taking this type as the argument of a `#[pyfunction]`, | ||
/// one will always get access to a [`PyReadonlyArray`] that will either reference to the NumPy array originally passed into the function | ||
/// or a temporary one created by converting the input type into a NumPy array. | ||
/// | ||
/// Depending on whether [`TypeMustMatch`] or [`AllowTypeChange`] is used for the `C` type parameter, | ||
/// the element type must either match the specific type `T` exactly or will be cast to it by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html). | ||
/// | ||
/// # Example | ||
/// | ||
/// `PyArrayLike1<'py, T, TypeMustMatch>` will enable you to receive both NumPy arrays and sequences | ||
/// | ||
/// ```rust | ||
/// # use pyo3::prelude::*; | ||
/// use pyo3::py_run; | ||
/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch}; | ||
/// | ||
/// #[pyfunction] | ||
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, f64, TypeMustMatch>) -> f64 { | ||
/// array.as_array().sum() | ||
/// } | ||
/// | ||
/// Python::with_gil(|py| { | ||
/// let np = get_array_module(py).unwrap(); | ||
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap(); | ||
/// | ||
/// py_run!(py, np sum_up, r"assert sum_up(np.array([1., 2., 3.])) == 6."); | ||
/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6."); | ||
/// }); | ||
/// ``` | ||
/// | ||
/// but it will not cast the element type if that is required | ||
/// | ||
/// ```rust,should_panic | ||
/// use pyo3::prelude::*; | ||
/// use pyo3::py_run; | ||
/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch}; | ||
/// | ||
/// #[pyfunction] | ||
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, TypeMustMatch>) -> i32 { | ||
/// array.as_array().sum() | ||
/// } | ||
/// | ||
/// Python::with_gil(|py| { | ||
/// let np = get_array_module(py).unwrap(); | ||
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap(); | ||
/// | ||
/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6"); | ||
/// }); | ||
/// ``` | ||
/// | ||
/// whereas `PyArrayLike1<'py, T, AllowTypeChange>` will do even at the cost loosing precision | ||
/// | ||
/// ```rust | ||
/// use pyo3::prelude::*; | ||
/// use pyo3::py_run; | ||
/// use numpy::{get_array_module, AllowTypeChange, PyArrayLike1}; | ||
/// | ||
/// #[pyfunction] | ||
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, AllowTypeChange>) -> i32 { | ||
/// array.as_array().sum() | ||
/// } | ||
/// | ||
/// Python::with_gil(|py| { | ||
/// let np = get_array_module(py).unwrap(); | ||
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap(); | ||
/// | ||
/// py_run!(py, np sum_up, r"assert sum_up((1.5, 2.5)) == 3"); | ||
/// }); | ||
/// ``` | ||
#[derive(Debug)] | ||
#[repr(transparent)] | ||
pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>) | ||
where | ||
T: Element, | ||
D: Dimension, | ||
C: Coerce; | ||
|
||
impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C> | ||
where | ||
T: Element, | ||
D: Dimension, | ||
C: Coerce, | ||
{ | ||
type Target = PyReadonlyArray<'py, T, D>; | ||
|
||
fn deref(&self) -> &Self::Target { | ||
&self.0 | ||
} | ||
} | ||
|
||
impl<'py, T, D, C> FromPyObject<'py> for PyArrayLike<'py, T, D, C> | ||
where | ||
T: Element, | ||
D: Dimension, | ||
C: Coerce, | ||
Vec<T>: FromPyObject<'py>, | ||
{ | ||
fn extract(ob: &'py PyAny) -> PyResult<Self> { | ||
if let Ok(array) = ob.downcast::<PyArray<T, D>>() { | ||
return Ok(Self(array.readonly(), PhantomData)); | ||
} | ||
|
||
let py = ob.py(); | ||
|
||
if matches!(D::NDIM, None | Some(1)) { | ||
if let Ok(vec) = ob.extract::<Vec<T>>() { | ||
let array = Array1::from(vec) | ||
.into_dimensionality() | ||
.expect("D being compatible to Ix1") | ||
.into_pyarray(py) | ||
.readonly(); | ||
return Ok(Self(array, PhantomData)); | ||
} | ||
} | ||
|
||
static AS_ARRAY: GILOnceCell<Py<PyAny>> = GILOnceCell::new(); | ||
|
||
let as_array = AS_ARRAY | ||
.get_or_try_init(py, || { | ||
get_array_module(py)?.getattr("asarray").map(Into::into) | ||
})? | ||
.as_ref(py); | ||
|
||
let kwargs = if C::VAL { | ||
let kwargs = PyDict::new(py); | ||
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?; | ||
Some(kwargs) | ||
} else { | ||
None | ||
}; | ||
|
||
let array = as_array.call((ob,), kwargs)?.extract()?; | ||
Ok(Self(array, PhantomData)) | ||
} | ||
} | ||
|
||
/// Receiver for zero-dimensional arrays or array-like types. | ||
pub type PyArrayLike0<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix0, C>; | ||
|
||
/// Receiver for one-dimensional arrays or array-like types. | ||
pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>; | ||
|
||
/// Receiver for two-dimensional arrays or array-like types. | ||
pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>; | ||
|
||
/// Receiver for three-dimensional arrays or array-like types. | ||
pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>; | ||
|
||
/// Receiver for four-dimensional arrays or array-like types. | ||
pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>; | ||
|
||
/// Receiver for five-dimensional arrays or array-like types. | ||
pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>; | ||
|
||
/// Receiver for six-dimensional arrays or array-like types. | ||
pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>; | ||
|
||
/// Receiver for arrays or array-like types whose dimensionality is determined at runtime. | ||
pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
use ndarray::array; | ||
use numpy::{get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn}; | ||
use pyo3::{ | ||
types::{IntoPyDict, PyDict}, | ||
Python, | ||
}; | ||
|
||
fn get_np_locals<'py>(py: Python<'py>) -> &'py PyDict { | ||
[("np", get_array_module(py).unwrap())].into_py_dict(py) | ||
} | ||
|
||
#[test] | ||
fn extract_reference() { | ||
Python::with_gil(|py| { | ||
let locals = get_np_locals(py); | ||
let py_array = py | ||
.eval( | ||
"np.array([[1,2],[3,4]], dtype='float64')", | ||
Some(locals), | ||
None, | ||
) | ||
.unwrap(); | ||
let extracted_array = py_array.extract::<PyArrayLike2<'_, f64>>().unwrap(); | ||
|
||
assert_eq!( | ||
array![[1_f64, 2_f64], [3_f64, 4_f64]], | ||
extracted_array.as_array() | ||
); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn convert_array_on_extract() { | ||
Python::with_gil(|py| { | ||
let locals = get_np_locals(py); | ||
let py_array = py | ||
.eval("np.array([[1,2],[3,4]], dtype='int32')", Some(locals), None) | ||
.unwrap(); | ||
let extracted_array = py_array | ||
.extract::<PyArrayLike2<'_, f64, AllowTypeChange>>() | ||
.unwrap(); | ||
|
||
assert_eq!( | ||
array![[1_f64, 2_f64], [3_f64, 4_f64]], | ||
extracted_array.as_array() | ||
); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn convert_list_on_extract() { | ||
Python::with_gil(|py| { | ||
let py_list = py.eval("[[1.0,2.0],[3.0,4.0]]", None, None).unwrap(); | ||
let extracted_array = py_list.extract::<PyArrayLike2<'_, f64>>().unwrap(); | ||
|
||
assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array()); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn convert_array_in_list_on_extract() { | ||
Python::with_gil(|py| { | ||
let locals = get_np_locals(py); | ||
let py_array = py | ||
.eval("[np.array([1.0, 2.0]), [3.0, 4.0]]", Some(locals), None) | ||
.unwrap(); | ||
let extracted_array = py_array.extract::<PyArrayLike2<'_, f64>>().unwrap(); | ||
|
||
assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array()); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn convert_list_on_extract_dyn() { | ||
Python::with_gil(|py| { | ||
let py_list = py | ||
.eval("[[[1,2],[3,4]],[[5,6],[7,8]]]", None, None) | ||
.unwrap(); | ||
let extracted_array = py_list | ||
.extract::<PyArrayLikeDyn<'_, i64, AllowTypeChange>>() | ||
.unwrap(); | ||
|
||
assert_eq!( | ||
array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(), | ||
extracted_array.as_array() | ||
); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn convert_1d_list_on_extract() { | ||
Python::with_gil(|py| { | ||
let py_list = py.eval("[1,2,3,4]", None, None).unwrap(); | ||
let extracted_array_1d = py_list.extract::<PyArrayLike1<'_, u32>>().unwrap(); | ||
let extracted_array_dyn = py_list.extract::<PyArrayLikeDyn<'_, f64>>().unwrap(); | ||
|
||
assert_eq!(array![1, 2, 3, 4], extracted_array_1d.as_array()); | ||
assert_eq!( | ||
array![1_f64, 2_f64, 3_f64, 4_f64].into_dyn(), | ||
extracted_array_dyn.as_array() | ||
); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn unsafe_cast_shall_fail() { | ||
Python::with_gil(|py| { | ||
let locals = get_np_locals(py); | ||
let py_list = py | ||
.eval( | ||
"np.array([1.1,2.2,3.3,4.4], dtype='float64')", | ||
Some(locals), | ||
None, | ||
) | ||
.unwrap(); | ||
let extracted_array = py_list.extract::<PyArrayLike1<'_, i32>>(); | ||
|
||
assert!(extracted_array.is_err()); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn unsafe_cast_with_coerce_works() { | ||
Python::with_gil(|py| { | ||
let locals = get_np_locals(py); | ||
let py_list = py | ||
.eval( | ||
"np.array([1.1,2.2,3.3,4.4], dtype='float64')", | ||
Some(locals), | ||
None, | ||
) | ||
.unwrap(); | ||
let extracted_array = py_list | ||
.extract::<PyArrayLike1<'_, i32, AllowTypeChange>>() | ||
.unwrap(); | ||
|
||
assert_eq!(array![1, 2, 3, 4], extracted_array.as_array()); | ||
}); | ||
} |