Skip to content

Commit

Permalink
Add permute and transpose methods for changing the order of axes …
Browse files Browse the repository at this point in the history
…of a `PyArray`
  • Loading branch information
adamreichold committed Apr 14, 2024
1 parent 0832b28 commit 053407d
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 31 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

- Unreleased
- Add `permute` and `transpose` methods for changing the order of axes of a `PyArray`. ([#428](https://github.com/PyO3/rust-numpy/pull/428))

- v0.21.0
- Migrate to the new `Bound` API introduced by PyO3 0.21. ([#410](https://github.com/PyO3/rust-numpy/pull/410)) ([#411](https://github.com/PyO3/rust-numpy/pull/411)) ([#412](https://github.com/PyO3/rust-numpy/pull/412)) ([#415](https://github.com/PyO3/rust-numpy/pull/415)) ([#416](https://github.com/PyO3/rust-numpy/pull/416)) ([#418](https://github.com/PyO3/rust-numpy/pull/418)) ([#419](https://github.com/PyO3/rust-numpy/pull/419)) ([#420](https://github.com/PyO3/rust-numpy/pull/420)) ([#421](https://github.com/PyO3/rust-numpy/pull/421)) ([#422](https://github.com/PyO3/rust-numpy/pull/422))
Expand Down
148 changes: 122 additions & 26 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1336,8 +1336,45 @@ impl<T: Element, D> PyArray<T, D> {
self.as_borrowed().cast(is_fortran).map(Bound::into_gil_ref)
}

/// A view of `self` with a different order of axes determined by `axes`.
///
/// If `axes` is `None`, the order of axes is reversed which corresponds to the standard matix transpose.
///
/// See also [`numpy.transpose`][numpy-transpose] and [`PyArray_Transpose`][PyArray_Transpose].
///
/// # Example
///
/// ```
/// use numpy::prelude::*;
/// use numpy::PyArray;
/// use pyo3::Python;
/// use ndarray::array;
///
/// Python::with_gil(|py| {
/// let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray(py);
///
/// let array = array.permute(Some([1, 0])).unwrap();
///
/// assert_eq!(array.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);
/// });
/// ```
///
/// [numpy-transpose]: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
/// [PyArray_Transpose]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Transpose
pub fn permute<'py, ID: IntoDimension>(
&'py self,
axes: Option<ID>,
) -> PyResult<&'py PyArray<T, D>> {
self.as_borrowed().permute(axes).map(Bound::into_gil_ref)
}

/// Special case of [`permute`][Self::permute] which reverses the order the axes.
pub fn transpose<'py>(&'py self) -> PyResult<&'py PyArray<T, D>> {
self.as_borrowed().transpose().map(Bound::into_gil_ref)
}

/// Construct a new array which has same values as self,
/// but has different dimensions specified by `dims`
/// but has different dimensions specified by `shape`
/// and a possibly different memory order specified by `order`.
///
/// See also [`numpy.reshape`][numpy-reshape] and [`PyArray_Newshape`][PyArray_Newshape].
Expand Down Expand Up @@ -1365,21 +1402,21 @@ impl<T: Element, D> PyArray<T, D> {
/// [PyArray_Newshape]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Newshape
pub fn reshape_with_order<'py, ID: IntoDimension>(
&'py self,
dims: ID,
shape: ID,
order: NPY_ORDER,
) -> PyResult<&'py PyArray<T, ID::Dim>> {
self.as_borrowed()
.reshape_with_order(dims, order)
.reshape_with_order(shape, order)
.map(Bound::into_gil_ref)
}

/// Special case of [`reshape_with_order`][Self::reshape_with_order] which keeps the memory order the same.
#[inline(always)]
pub fn reshape<'py, ID: IntoDimension>(
&'py self,
dims: ID,
shape: ID,
) -> PyResult<&'py PyArray<T, ID::Dim>> {
self.as_borrowed().reshape(dims).map(Bound::into_gil_ref)
self.as_borrowed().reshape(shape).map(Bound::into_gil_ref)
}

/// Extends or truncates the dimensions of an array.
Expand Down Expand Up @@ -1414,8 +1451,8 @@ impl<T: Element, D> PyArray<T, D> {
///
/// [ndarray-resize]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.resize.html
/// [PyArray_Resize]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Resize
pub unsafe fn resize<ID: IntoDimension>(&self, dims: ID) -> PyResult<()> {
self.as_borrowed().resize(dims)
pub unsafe fn resize<ID: IntoDimension>(&self, newshape: ID) -> PyResult<()> {
self.as_borrowed().resize(newshape)
}
}

Expand Down Expand Up @@ -1879,8 +1916,45 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
where
T: Element;

/// Construct a new array which has same values as self,
/// but has different dimensions specified by `dims`
/// A view of `self` with a different order of axes determined by `axes`.
///
/// If `axes` is `None`, the order of axes is reversed which corresponds to the standard matix transpose.
///
/// See also [`numpy.transpose`][numpy-transpose] and [`PyArray_Transpose`][PyArray_Transpose].
///
/// # Example
///
/// ```
/// use numpy::prelude::*;
/// use numpy::PyArray;
/// use pyo3::Python;
/// use ndarray::array;
///
/// Python::with_gil(|py| {
/// let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray_bound(py);
///
/// let array = array.permute(Some([1, 0])).unwrap();
///
/// assert_eq!(array.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);
/// });
/// ```
///
/// [numpy-transpose]: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
/// [PyArray_Transpose]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Transpose
fn permute<ID: IntoDimension>(&self, axes: Option<ID>) -> PyResult<Bound<'py, PyArray<T, D>>>
where
T: Element;

/// Special case of [`permute`][Self::permute] which reverses the order the axes.
fn transpose(&self) -> PyResult<Bound<'py, PyArray<T, D>>>
where
T: Element,
{
self.permute::<()>(None)
}

/// Construct a new array which has same values as `self`,
/// but has different dimensions specified by `shape`
/// and a possibly different memory order specified by `order`.
///
/// See also [`numpy.reshape`][numpy-reshape] and [`PyArray_Newshape`][PyArray_Newshape].
Expand Down Expand Up @@ -1908,19 +1982,19 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
/// [PyArray_Newshape]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Newshape
fn reshape_with_order<ID: IntoDimension>(
&self,
dims: ID,
shape: ID,
order: NPY_ORDER,
) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
where
T: Element;

/// Special case of [`reshape_with_order`][Self::reshape_with_order] which keeps the memory order the same.
#[inline(always)]
fn reshape<ID: IntoDimension>(&self, dims: ID) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
fn reshape<ID: IntoDimension>(&self, shape: ID) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
where
T: Element,
{
self.reshape_with_order(dims, NPY_ORDER::NPY_ANYORDER)
self.reshape_with_order(shape, NPY_ORDER::NPY_ANYORDER)
}

/// Extends or truncates the dimensions of an array.
Expand Down Expand Up @@ -1955,7 +2029,7 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
///
/// [ndarray-resize]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.resize.html
/// [PyArray_Resize]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Resize
unsafe fn resize<ID: IntoDimension>(&self, dims: ID) -> PyResult<()>
unsafe fn resize<ID: IntoDimension>(&self, newshape: ID) -> PyResult<()>
where
T: Element;

Expand Down Expand Up @@ -2256,48 +2330,70 @@ impl<'py, T, D> PyArrayMethods<'py, T, D> for Bound<'py, PyArray<T, D>> {
}
}

fn permute<ID: IntoDimension>(&self, axes: Option<ID>) -> PyResult<Bound<'py, PyArray<T, D>>> {
let mut axes = axes.map(|axes| axes.into_dimension());
let mut axes = axes.as_mut().map(|axes| axes.to_npy_dims());
let axes = axes
.as_mut()
.map_or_else(ptr::null_mut, |axes| axes as *mut npyffi::PyArray_Dims);

let py = self.py();
let ptr = unsafe { PY_ARRAY_API.PyArray_Transpose(py, self.as_array_ptr(), axes) };
if !ptr.is_null() {
Ok(unsafe { Bound::from_owned_ptr(py, ptr).downcast_into_unchecked() })
} else {
Err(PyErr::fetch(py))
}
}

fn reshape_with_order<ID: IntoDimension>(
&self,
dims: ID,
shape: ID,
order: NPY_ORDER,
) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
where
T: Element,
{
let mut dims = dims.into_dimension();
let mut dims = dims.to_npy_dims();
let mut shape = shape.into_dimension();
let mut shape = shape.to_npy_dims();

let py = self.py();
let ptr = unsafe {
PY_ARRAY_API.PyArray_Newshape(
self.py(),
py,
self.as_array_ptr(),
&mut dims as *mut npyffi::PyArray_Dims,
&mut shape as *mut npyffi::PyArray_Dims,
order,
)
};

if !ptr.is_null() {
Ok(unsafe { Bound::from_owned_ptr(self.py(), ptr).downcast_into_unchecked() })
Ok(unsafe { Bound::from_owned_ptr(py, ptr).downcast_into_unchecked() })
} else {
Err(PyErr::fetch(self.py()))
Err(PyErr::fetch(py))
}
}

unsafe fn resize<ID: IntoDimension>(&self, dims: ID) -> PyResult<()>
unsafe fn resize<ID: IntoDimension>(&self, newshape: ID) -> PyResult<()>
where
T: Element,
{
let mut dims = dims.into_dimension();
let mut dims = dims.to_npy_dims();
let mut newshape = newshape.into_dimension();
let mut newshape = newshape.to_npy_dims();

let py = self.py();
let res = PY_ARRAY_API.PyArray_Resize(
self.py(),
py,
self.as_array_ptr(),
&mut dims as *mut npyffi::PyArray_Dims,
&mut newshape as *mut npyffi::PyArray_Dims,
1,
NPY_ORDER::NPY_ANYORDER,
);

if !res.is_null() {
Ok(())
} else {
Err(PyErr::fetch(self.py()))
Err(PyErr::fetch(py))
}
}

Expand Down
41 changes: 36 additions & 5 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ use std::mem::size_of;
#[cfg(feature = "half")]
use half::{bf16, f16};
use ndarray::{array, s, Array1, Dim};
use numpy::prelude::*;
use numpy::{
array::{PyArray0Methods, PyArrayMethods},
dtype_bound, get_array_module,
npyffi::NPY_ORDER,
pyarray_bound, PyArray, PyArray1, PyArray2, PyArrayDescr, PyArrayDescrMethods, PyArrayDyn,
PyFixedString, PyFixedUnicode, PyUntypedArrayMethods, ToPyArray,
dtype_bound, get_array_module, npyffi::NPY_ORDER, pyarray_bound, PyArray, PyArray1, PyArray2,
PyArrayDescr, PyArrayDyn, PyFixedString, PyFixedUnicode,
};
use pyo3::{
py_run, pyclass, pymethods,
Expand Down Expand Up @@ -522,6 +520,39 @@ fn get_works() {
});
}

#[test]
fn permute_and_transpose() {
Python::with_gil(|py| {
let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray_bound(py);

let permuted = array.permute(Some([1, 0])).unwrap();
assert_eq!(
permuted.readonly().as_array(),
array![[0, 3], [1, 4], [2, 5]]
);

let permuted = array.permute::<()>(None).unwrap();
assert_eq!(
permuted.readonly().as_array(),
array![[0, 3], [1, 4], [2, 5]]
);

let transposed = array.transpose().unwrap();
assert_eq!(
transposed.readonly().as_array(),
array![[0, 3], [1, 4], [2, 5]]
);

let array = pyarray_bound![py, [[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]];

let permuted = array.permute(Some([0, 2, 1])).unwrap();
assert_eq!(
permuted.readonly().as_array(),
array![[[1, 3], [2, 4]], [[5, 7], [6, 8]], [[9, 11], [10, 12]]]
);
});
}

#[test]
fn reshape() {
Python::with_gil(|py| {
Expand Down

0 comments on commit 053407d

Please sign in to comment.