From 34b63b2b3560b6d90c35d1500e0ac60e3cd0fedf Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Sun, 14 Apr 2024 12:45:31 +0200 Subject: [PATCH] Add `permute` and `transpose` methods for changing the order of axes of a `PyArray` --- CHANGELOG.md | 1 + src/array.rs | 151 ++++++++++++++++++++++++++++++++++++++----------- tests/array.rs | 41 ++++++++++++-- 3 files changed, 155 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c05da0498..94ecd9080 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Changelog - Unreleased + - Add `permute` and `transpose` methods for changing the order of axes of a `PyArray`. ([#428](https://github.com/PyO3/rust-numpy/pull/428)) - v0.21.0 - Migrate to the new `Bound` API introduced by PyO3 0.21. ([#410](https://github.com/PyO3/rust-numpy/pull/410)) ([#411](https://github.com/PyO3/rust-numpy/pull/411)) ([#412](https://github.com/PyO3/rust-numpy/pull/412)) ([#415](https://github.com/PyO3/rust-numpy/pull/415)) ([#416](https://github.com/PyO3/rust-numpy/pull/416)) ([#418](https://github.com/PyO3/rust-numpy/pull/418)) ([#419](https://github.com/PyO3/rust-numpy/pull/419)) ([#420](https://github.com/PyO3/rust-numpy/pull/420)) ([#421](https://github.com/PyO3/rust-numpy/pull/421)) ([#422](https://github.com/PyO3/rust-numpy/pull/422)) diff --git a/src/array.rs b/src/array.rs index 9392c1554..dfc630727 100644 --- a/src/array.rs +++ b/src/array.rs @@ -1336,8 +1336,45 @@ impl PyArray { self.as_borrowed().cast(is_fortran).map(Bound::into_gil_ref) } + /// A view of `self` with a different order of axes determined by `axes`. + /// + /// If `axes` is `None`, the order of axes is reversed which corresponds to the standard matrix transpose. + /// + /// See also [`numpy.transpose`][numpy-transpose] and [`PyArray_Transpose`][PyArray_Transpose]. + /// + /// # Example + /// + /// ``` + /// use numpy::prelude::*; + /// use numpy::PyArray; + /// use pyo3::Python; + /// use ndarray::array; + /// + /// Python::with_gil(|py| { + /// let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray(py); + /// + /// let array = array.permute(Some([1, 0])).unwrap(); + /// + /// assert_eq!(array.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]); + /// }); + /// ``` + /// + /// [numpy-transpose]: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html + /// [PyArray_Transpose]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Transpose + pub fn permute<'py, ID: IntoDimension>( + &'py self, + axes: Option, + ) -> PyResult<&'py PyArray> { + self.as_borrowed().permute(axes).map(Bound::into_gil_ref) + } + + /// Special case of [`permute`][Self::permute] which reverses the order the axes. + pub fn transpose<'py>(&'py self) -> PyResult<&'py PyArray> { + self.as_borrowed().transpose().map(Bound::into_gil_ref) + } + /// Construct a new array which has same values as self, - /// but has different dimensions specified by `dims` + /// but has different dimensions specified by `shape` /// and a possibly different memory order specified by `order`. /// /// See also [`numpy.reshape`][numpy-reshape] and [`PyArray_Newshape`][PyArray_Newshape]. @@ -1365,11 +1402,11 @@ impl PyArray { /// [PyArray_Newshape]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Newshape pub fn reshape_with_order<'py, ID: IntoDimension>( &'py self, - dims: ID, + shape: ID, order: NPY_ORDER, ) -> PyResult<&'py PyArray> { self.as_borrowed() - .reshape_with_order(dims, order) + .reshape_with_order(shape, order) .map(Bound::into_gil_ref) } @@ -1377,9 +1414,9 @@ impl PyArray { #[inline(always)] pub fn reshape<'py, ID: IntoDimension>( &'py self, - dims: ID, + shape: ID, ) -> PyResult<&'py PyArray> { - self.as_borrowed().reshape(dims).map(Bound::into_gil_ref) + self.as_borrowed().reshape(shape).map(Bound::into_gil_ref) } /// Extends or truncates the dimensions of an array. @@ -1414,8 +1451,8 @@ impl PyArray { /// /// [ndarray-resize]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.resize.html /// [PyArray_Resize]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Resize - pub unsafe fn resize(&self, dims: ID) -> PyResult<()> { - self.as_borrowed().resize(dims) + pub unsafe fn resize(&self, newshape: ID) -> PyResult<()> { + self.as_borrowed().resize(newshape) } } @@ -1879,8 +1916,45 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> { where T: Element; - /// Construct a new array which has same values as self, - /// but has different dimensions specified by `dims` + /// A view of `self` with a different order of axes determined by `axes`. + /// + /// If `axes` is `None`, the order of axes is reversed which corresponds to the standard matrix transpose. + /// + /// See also [`numpy.transpose`][numpy-transpose] and [`PyArray_Transpose`][PyArray_Transpose]. + /// + /// # Example + /// + /// ``` + /// use numpy::prelude::*; + /// use numpy::PyArray; + /// use pyo3::Python; + /// use ndarray::array; + /// + /// Python::with_gil(|py| { + /// let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray_bound(py); + /// + /// let array = array.permute(Some([1, 0])).unwrap(); + /// + /// assert_eq!(array.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]); + /// }); + /// ``` + /// + /// [numpy-transpose]: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html + /// [PyArray_Transpose]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Transpose + fn permute(&self, axes: Option) -> PyResult>> + where + T: Element; + + /// Special case of [`permute`][Self::permute] which reverses the order the axes. + fn transpose(&self) -> PyResult>> + where + T: Element, + { + self.permute::<()>(None) + } + + /// Construct a new array which has same values as `self`, + /// but has different dimensions specified by `shape` /// and a possibly different memory order specified by `order`. /// /// See also [`numpy.reshape`][numpy-reshape] and [`PyArray_Newshape`][PyArray_Newshape]. @@ -1908,7 +1982,7 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> { /// [PyArray_Newshape]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Newshape fn reshape_with_order( &self, - dims: ID, + shape: ID, order: NPY_ORDER, ) -> PyResult>> where @@ -1916,11 +1990,11 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> { /// Special case of [`reshape_with_order`][Self::reshape_with_order] which keeps the memory order the same. #[inline(always)] - fn reshape(&self, dims: ID) -> PyResult>> + fn reshape(&self, shape: ID) -> PyResult>> where T: Element, { - self.reshape_with_order(dims, NPY_ORDER::NPY_ANYORDER) + self.reshape_with_order(shape, NPY_ORDER::NPY_ANYORDER) } /// Extends or truncates the dimensions of an array. @@ -1955,7 +2029,7 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> { /// /// [ndarray-resize]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.resize.html /// [PyArray_Resize]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Resize - unsafe fn resize(&self, dims: ID) -> PyResult<()> + unsafe fn resize(&self, newshape: ID) -> PyResult<()> where T: Element; @@ -2249,55 +2323,66 @@ impl<'py, T, D> PyArrayMethods<'py, T, D> for Bound<'py, PyArray> { if is_fortran { -1 } else { 0 }, ) }; - if !ptr.is_null() { - Ok(unsafe { Bound::from_owned_ptr(self.py(), ptr).downcast_into_unchecked() }) - } else { - Err(PyErr::fetch(self.py())) + unsafe { + Bound::from_owned_ptr_or_err(self.py(), ptr).map(|ob| ob.downcast_into_unchecked()) } } + fn permute(&self, axes: Option) -> PyResult>> { + let mut axes = axes.map(|axes| axes.into_dimension()); + let mut axes = axes.as_mut().map(|axes| axes.to_npy_dims()); + let axes = axes + .as_mut() + .map_or_else(ptr::null_mut, |axes| axes as *mut npyffi::PyArray_Dims); + + let py = self.py(); + let ptr = unsafe { PY_ARRAY_API.PyArray_Transpose(py, self.as_array_ptr(), axes) }; + unsafe { Bound::from_owned_ptr_or_err(py, ptr).map(|ob| ob.downcast_into_unchecked()) } + } + fn reshape_with_order( &self, - dims: ID, + shape: ID, order: NPY_ORDER, ) -> PyResult>> where T: Element, { - let mut dims = dims.into_dimension(); - let mut dims = dims.to_npy_dims(); + let mut shape = shape.into_dimension(); + let mut shape = shape.to_npy_dims(); + + let py = self.py(); let ptr = unsafe { PY_ARRAY_API.PyArray_Newshape( - self.py(), + py, self.as_array_ptr(), - &mut dims as *mut npyffi::PyArray_Dims, + &mut shape as *mut npyffi::PyArray_Dims, order, ) }; - if !ptr.is_null() { - Ok(unsafe { Bound::from_owned_ptr(self.py(), ptr).downcast_into_unchecked() }) - } else { - Err(PyErr::fetch(self.py())) - } + unsafe { Bound::from_owned_ptr_or_err(py, ptr).map(|ob| ob.downcast_into_unchecked()) } } - unsafe fn resize(&self, dims: ID) -> PyResult<()> + unsafe fn resize(&self, newshape: ID) -> PyResult<()> where T: Element, { - let mut dims = dims.into_dimension(); - let mut dims = dims.to_npy_dims(); + let mut newshape = newshape.into_dimension(); + let mut newshape = newshape.to_npy_dims(); + + let py = self.py(); let res = PY_ARRAY_API.PyArray_Resize( - self.py(), + py, self.as_array_ptr(), - &mut dims as *mut npyffi::PyArray_Dims, + &mut newshape as *mut npyffi::PyArray_Dims, 1, NPY_ORDER::NPY_ANYORDER, ); + if !res.is_null() { Ok(()) } else { - Err(PyErr::fetch(self.py())) + Err(PyErr::fetch(py)) } } diff --git a/tests/array.rs b/tests/array.rs index 5a0c4a872..273173c79 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -3,12 +3,10 @@ use std::mem::size_of; #[cfg(feature = "half")] use half::{bf16, f16}; use ndarray::{array, s, Array1, Dim}; +use numpy::prelude::*; use numpy::{ - array::{PyArray0Methods, PyArrayMethods}, - dtype_bound, get_array_module, - npyffi::NPY_ORDER, - pyarray_bound, PyArray, PyArray1, PyArray2, PyArrayDescr, PyArrayDescrMethods, PyArrayDyn, - PyFixedString, PyFixedUnicode, PyUntypedArrayMethods, ToPyArray, + dtype_bound, get_array_module, npyffi::NPY_ORDER, pyarray_bound, PyArray, PyArray1, PyArray2, + PyArrayDescr, PyArrayDyn, PyFixedString, PyFixedUnicode, }; use pyo3::{ py_run, pyclass, pymethods, @@ -522,6 +520,39 @@ fn get_works() { }); } +#[test] +fn permute_and_transpose() { + Python::with_gil(|py| { + let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray_bound(py); + + let permuted = array.permute(Some([1, 0])).unwrap(); + assert_eq!( + permuted.readonly().as_array(), + array![[0, 3], [1, 4], [2, 5]] + ); + + let permuted = array.permute::<()>(None).unwrap(); + assert_eq!( + permuted.readonly().as_array(), + array![[0, 3], [1, 4], [2, 5]] + ); + + let transposed = array.transpose().unwrap(); + assert_eq!( + transposed.readonly().as_array(), + array![[0, 3], [1, 4], [2, 5]] + ); + + let array = pyarray_bound![py, [[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]; + + let permuted = array.permute(Some([0, 2, 1])).unwrap(); + assert_eq!( + permuted.readonly().as_array(), + array![[[1, 3], [2, 4]], [[5, 7], [6, 8]], [[9, 11], [10, 12]]] + ); + }); +} + #[test] fn reshape() { Python::with_gil(|py| {