Skip to content

Commit

Permalink
Add PyArrayLike wrapper around PyReadonlyArray
Browse files Browse the repository at this point in the history
Extracts a read-only reference if the correct NumPy array type is given.
Tries to convert the input into the correct type using `numpy.asarray`
otherwise.
  • Loading branch information
124C41p authored and adamreichold committed Oct 7, 2023
1 parent b5af0ed commit 1671b00
Show file tree
Hide file tree
Showing 4 changed files with 340 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- Increase MSRV to 1.56 released in October 2021 and available in Debain 12, RHEL 9 and Alpine 3.17 following the same change for PyO3. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
- Add support for ASCII (`PyFixedString<N>`) and Unicode (`PyFixedUnicode<N>`) string arrays, i.e. dtypes `SN` and `UN` where `N` is the number of characters. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
- Add support for the `bfloat16` dtype by extending the optional integration with the `half` crate. Note that the `bfloat16` dtype is not part of NumPy itself so that usage requires third-party packages like Tensorflow. ([#381](https://github.com/PyO3/rust-numpy/pull/381))
- Add `PyArrayLike` type which extracts `PyReadonlyArray` if a NumPy array of the correct type is given and attempts a conversion using `numpy.asarray` otherwise. ([#383](https://github.com/PyO3/rust-numpy/pull/383))

- v0.19.0
- Add `PyUntypedArray` as an untyped base type for `PyArray` which can be used to inspect arguments before more targeted downcasts. This is accompanied by some methods like `dtype` and `shape` moving from `PyArray` to `PyUntypedArray`. They are still accessible though, as `PyArray` dereferences to `PyUntypedArray` via the `Deref` trait. ([#369](https://github.com/PyO3/rust-numpy/pull/369))
Expand Down
195 changes: 195 additions & 0 deletions src/array_like.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
use std::marker::PhantomData;
use std::ops::Deref;

use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
use pyo3::{intern, sync::GILOnceCell, types::PyDict, FromPyObject, Py, PyAny, PyResult};

use crate::sealed::Sealed;
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};

pub trait Coerce: Sealed {
const VAL: bool;
}

/// Marker type to indicate that the element type received via [`PyArrayLike`] must match the specified type exactly.
#[derive(Debug)]
pub struct TypeMustMatch;

impl Sealed for TypeMustMatch {}

impl Coerce for TypeMustMatch {
const VAL: bool = false;
}

/// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
#[derive(Debug)]
pub struct AllowTypeChange;

impl Sealed for AllowTypeChange {}

impl Coerce for AllowTypeChange {
const VAL: bool = true;
}

/// Receiver for arrays or array-like types.
///
/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments.
/// `PyArrayLike` enables the same pattern in Rust extensions, i.e. by taking this type as the argument of a `#[pyfunction]`,
/// one will always get access to a [`PyReadonlyArray`] that will either reference to the NumPy array originally passed into the function
/// or a temporary one created by converting the input type into a NumPy array.
///
/// Depending on whether [`TypeMustMatch`] or [`AllowTypeChange`] is used for the `C` type parameter,
/// the element type must either match the specific type `T` exactly or will be cast to it by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
///
/// # Example
///
/// `PyArrayLike1<'py, T, TypeMustMatch>` will enable you to receive both NumPy arrays and sequences
///
/// ```rust
/// # use pyo3::prelude::*;
/// use pyo3::py_run;
/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
///
/// #[pyfunction]
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, f64, TypeMustMatch>) -> f64 {
/// array.as_array().sum()
/// }
///
/// Python::with_gil(|py| {
/// let np = get_array_module(py).unwrap();
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
///
/// py_run!(py, np sum_up, r"assert sum_up(np.array([1., 2., 3.])) == 6.");
/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6.");
/// });
/// ```
///
/// but it will not cast the element type if that is required
///
/// ```rust,should_panic
/// use pyo3::prelude::*;
/// use pyo3::py_run;
/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
///
/// #[pyfunction]
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, TypeMustMatch>) -> i32 {
/// array.as_array().sum()
/// }
///
/// Python::with_gil(|py| {
/// let np = get_array_module(py).unwrap();
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
///
/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6");
/// });
/// ```
///
/// whereas `PyArrayLike1<'py, T, AllowTypeChange>` will do even at the cost loosing precision
///
/// ```rust
/// use pyo3::prelude::*;
/// use pyo3::py_run;
/// use numpy::{get_array_module, AllowTypeChange, PyArrayLike1};
///
/// #[pyfunction]
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, AllowTypeChange>) -> i32 {
/// array.as_array().sum()
/// }
///
/// Python::with_gil(|py| {
/// let np = get_array_module(py).unwrap();
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
///
/// py_run!(py, np sum_up, r"assert sum_up((1.5, 2.5)) == 3");
/// });
/// ```
#[derive(Debug)]
#[repr(transparent)]
pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
where
T: Element,
D: Dimension,
C: Coerce;

impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
where
T: Element,
D: Dimension,
C: Coerce,
{
type Target = PyReadonlyArray<'py, T, D>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<'py, T, D, C> FromPyObject<'py> for PyArrayLike<'py, T, D, C>
where
T: Element,
D: Dimension,
C: Coerce,
Vec<T>: FromPyObject<'py>,
{
fn extract(ob: &'py PyAny) -> PyResult<Self> {
if let Ok(array) = ob.downcast::<PyArray<T, D>>() {
return Ok(Self(array.readonly(), PhantomData));
}

let py = ob.py();

if matches!(D::NDIM, None | Some(1)) {
if let Ok(vec) = ob.extract::<Vec<T>>() {
let array = Array1::from(vec)
.into_dimensionality()
.expect("D being compatible to Ix1")
.into_pyarray(py)
.readonly();
return Ok(Self(array, PhantomData));
}
}

static AS_ARRAY: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

let as_array = AS_ARRAY
.get_or_try_init(py, || {
get_array_module(py)?.getattr("asarray").map(Into::into)
})?
.as_ref(py);

let kwargs = if C::VAL {
let kwargs = PyDict::new(py);
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
Some(kwargs)
} else {
None
};

let array = as_array.call((ob,), kwargs)?.extract()?;
Ok(Self(array, PhantomData))
}
}

/// Receiver for zero-dimensional arrays or array-like types.
pub type PyArrayLike0<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix0, C>;

/// Receiver for one-dimensional arrays or array-like types.
pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;

/// Receiver for two-dimensional arrays or array-like types.
pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;

/// Receiver for three-dimensional arrays or array-like types.
pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;

/// Receiver for four-dimensional arrays or array-like types.
pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;

/// Receiver for five-dimensional arrays or array-like types.
pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;

/// Receiver for six-dimensional arrays or array-like types.
pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;

/// Receiver for arrays or array-like types whose dimensionality is determined at runtime.
pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ as well as the [`PyReadonlyArray::try_as_matrix`] and [`PyReadwriteArray::try_as
#![deny(missing_docs, missing_debug_implementations)]

pub mod array;
mod array_like;
pub mod borrow;
pub mod convert;
pub mod datetime;
Expand All @@ -94,6 +95,10 @@ pub use crate::array::{
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
PyArray6, PyArrayDyn,
};
pub use crate::array_like::{
AllowTypeChange, PyArrayLike, PyArrayLike0, PyArrayLike1, PyArrayLike2, PyArrayLike3,
PyArrayLike4, PyArrayLike5, PyArrayLike6, PyArrayLikeDyn, TypeMustMatch,
};
pub use crate::borrow::{
PyReadonlyArray, PyReadonlyArray0, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3,
PyReadonlyArray4, PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn, PyReadwriteArray,
Expand Down
139 changes: 139 additions & 0 deletions tests/array_like.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use ndarray::array;
use numpy::{get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn};
use pyo3::{
types::{IntoPyDict, PyDict},
Python,
};

fn get_np_locals<'py>(py: Python<'py>) -> &'py PyDict {
[("np", get_array_module(py).unwrap())].into_py_dict(py)
}

#[test]
fn extract_reference() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_array = py
.eval(
"np.array([[1,2],[3,4]], dtype='float64')",
Some(locals),
None,
)
.unwrap();
let extracted_array = py_array.extract::<PyArrayLike2<'_, f64>>().unwrap();

assert_eq!(
array![[1_f64, 2_f64], [3_f64, 4_f64]],
extracted_array.as_array()
);
});
}

#[test]
fn convert_array_on_extract() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_array = py
.eval("np.array([[1,2],[3,4]], dtype='int32')", Some(locals), None)
.unwrap();
let extracted_array = py_array
.extract::<PyArrayLike2<'_, f64, AllowTypeChange>>()
.unwrap();

assert_eq!(
array![[1_f64, 2_f64], [3_f64, 4_f64]],
extracted_array.as_array()
);
});
}

#[test]
fn convert_list_on_extract() {
Python::with_gil(|py| {
let py_list = py.eval("[[1.0,2.0],[3.0,4.0]]", None, None).unwrap();
let extracted_array = py_list.extract::<PyArrayLike2<'_, f64>>().unwrap();

assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array());
});
}

#[test]
fn convert_array_in_list_on_extract() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_array = py
.eval("[np.array([1.0, 2.0]), [3.0, 4.0]]", Some(locals), None)
.unwrap();
let extracted_array = py_array.extract::<PyArrayLike2<'_, f64>>().unwrap();

assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array());
});
}

#[test]
fn convert_list_on_extract_dyn() {
Python::with_gil(|py| {
let py_list = py
.eval("[[[1,2],[3,4]],[[5,6],[7,8]]]", None, None)
.unwrap();
let extracted_array = py_list
.extract::<PyArrayLikeDyn<'_, i64, AllowTypeChange>>()
.unwrap();

assert_eq!(
array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
extracted_array.as_array()
);
});
}

#[test]
fn convert_1d_list_on_extract() {
Python::with_gil(|py| {
let py_list = py.eval("[1,2,3,4]", None, None).unwrap();
let extracted_array_1d = py_list.extract::<PyArrayLike1<'_, u32>>().unwrap();
let extracted_array_dyn = py_list.extract::<PyArrayLikeDyn<'_, f64>>().unwrap();

assert_eq!(array![1, 2, 3, 4], extracted_array_1d.as_array());
assert_eq!(
array![1_f64, 2_f64, 3_f64, 4_f64].into_dyn(),
extracted_array_dyn.as_array()
);
});
}

#[test]
fn unsafe_cast_shall_fail() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_list = py
.eval(
"np.array([1.1,2.2,3.3,4.4], dtype='float64')",
Some(locals),
None,
)
.unwrap();
let extracted_array = py_list.extract::<PyArrayLike1<'_, i32>>();

assert!(extracted_array.is_err());
});
}

#[test]
fn unsafe_cast_with_coerce_works() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_list = py
.eval(
"np.array([1.1,2.2,3.3,4.4], dtype='float64')",
Some(locals),
None,
)
.unwrap();
let extracted_array = py_list
.extract::<PyArrayLike1<'_, i32, AllowTypeChange>>()
.unwrap();

assert_eq!(array![1, 2, 3, 4], extracted_array.as_array());
});
}

0 comments on commit 1671b00

Please sign in to comment.