Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add permute and transpose methods for changing the order of axes of a PyArray #428

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

- Unreleased
- Add `permute` and `transpose` methods for changing the order of axes of a `PyArray`. ([#428](https://github.com/PyO3/rust-numpy/pull/428))

- v0.21.0
- Migrate to the new `Bound` API introduced by PyO3 0.21. ([#410](https://github.com/PyO3/rust-numpy/pull/410)) ([#411](https://github.com/PyO3/rust-numpy/pull/411)) ([#412](https://github.com/PyO3/rust-numpy/pull/412)) ([#415](https://github.com/PyO3/rust-numpy/pull/415)) ([#416](https://github.com/PyO3/rust-numpy/pull/416)) ([#418](https://github.com/PyO3/rust-numpy/pull/418)) ([#419](https://github.com/PyO3/rust-numpy/pull/419)) ([#420](https://github.com/PyO3/rust-numpy/pull/420)) ([#421](https://github.com/PyO3/rust-numpy/pull/421)) ([#422](https://github.com/PyO3/rust-numpy/pull/422))
Expand Down
151 changes: 118 additions & 33 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1336,8 +1336,45 @@ impl<T: Element, D> PyArray<T, D> {
self.as_borrowed().cast(is_fortran).map(Bound::into_gil_ref)
}

/// A view of `self` with a different order of axes determined by `axes`.
///
/// If `axes` is `None`, the order of axes is reversed which corresponds to the standard matrix transpose.
///
/// See also [`numpy.transpose`][numpy-transpose] and [`PyArray_Transpose`][PyArray_Transpose].
///
/// # Example
///
/// ```
/// use numpy::prelude::*;
/// use numpy::PyArray;
/// use pyo3::Python;
/// use ndarray::array;
///
/// Python::with_gil(|py| {
/// let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray(py);
///
/// let array = array.permute(Some([1, 0])).unwrap();
///
/// assert_eq!(array.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);
/// });
/// ```
///
/// [numpy-transpose]: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
/// [PyArray_Transpose]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Transpose
pub fn permute<'py, ID: IntoDimension>(
&'py self,
axes: Option<ID>,
) -> PyResult<&'py PyArray<T, D>> {
self.as_borrowed().permute(axes).map(Bound::into_gil_ref)
}

/// Special case of [`permute`][Self::permute] which reverses the order the axes.
pub fn transpose<'py>(&'py self) -> PyResult<&'py PyArray<T, D>> {
self.as_borrowed().transpose().map(Bound::into_gil_ref)
}

/// Construct a new array which has same values as self,
/// but has different dimensions specified by `dims`
/// but has different dimensions specified by `shape`
/// and a possibly different memory order specified by `order`.
///
/// See also [`numpy.reshape`][numpy-reshape] and [`PyArray_Newshape`][PyArray_Newshape].
Expand Down Expand Up @@ -1365,21 +1402,21 @@ impl<T: Element, D> PyArray<T, D> {
/// [PyArray_Newshape]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Newshape
pub fn reshape_with_order<'py, ID: IntoDimension>(
&'py self,
dims: ID,
shape: ID,
order: NPY_ORDER,
) -> PyResult<&'py PyArray<T, ID::Dim>> {
self.as_borrowed()
.reshape_with_order(dims, order)
.reshape_with_order(shape, order)
.map(Bound::into_gil_ref)
}

/// Special case of [`reshape_with_order`][Self::reshape_with_order] which keeps the memory order the same.
#[inline(always)]
pub fn reshape<'py, ID: IntoDimension>(
&'py self,
dims: ID,
shape: ID,
) -> PyResult<&'py PyArray<T, ID::Dim>> {
self.as_borrowed().reshape(dims).map(Bound::into_gil_ref)
self.as_borrowed().reshape(shape).map(Bound::into_gil_ref)
}

/// Extends or truncates the dimensions of an array.
Expand Down Expand Up @@ -1414,8 +1451,8 @@ impl<T: Element, D> PyArray<T, D> {
///
/// [ndarray-resize]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.resize.html
/// [PyArray_Resize]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Resize
pub unsafe fn resize<ID: IntoDimension>(&self, dims: ID) -> PyResult<()> {
self.as_borrowed().resize(dims)
pub unsafe fn resize<ID: IntoDimension>(&self, newshape: ID) -> PyResult<()> {
self.as_borrowed().resize(newshape)
}
}

Expand Down Expand Up @@ -1879,8 +1916,45 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
where
T: Element;

/// Construct a new array which has same values as self,
/// but has different dimensions specified by `dims`
/// A view of `self` with a different order of axes determined by `axes`.
///
/// If `axes` is `None`, the order of axes is reversed which corresponds to the standard matrix transpose.
///
/// See also [`numpy.transpose`][numpy-transpose] and [`PyArray_Transpose`][PyArray_Transpose].
///
/// # Example
///
/// ```
/// use numpy::prelude::*;
/// use numpy::PyArray;
/// use pyo3::Python;
/// use ndarray::array;
///
/// Python::with_gil(|py| {
/// let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray_bound(py);
///
/// let array = array.permute(Some([1, 0])).unwrap();
///
/// assert_eq!(array.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);
/// });
/// ```
///
/// [numpy-transpose]: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
/// [PyArray_Transpose]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Transpose
fn permute<ID: IntoDimension>(&self, axes: Option<ID>) -> PyResult<Bound<'py, PyArray<T, D>>>
where
T: Element;

/// Special case of [`permute`][Self::permute] which reverses the order the axes.
fn transpose(&self) -> PyResult<Bound<'py, PyArray<T, D>>>
where
T: Element,
{
self.permute::<()>(None)
}

/// Construct a new array which has same values as `self`,
/// but has different dimensions specified by `shape`
/// and a possibly different memory order specified by `order`.
///
/// See also [`numpy.reshape`][numpy-reshape] and [`PyArray_Newshape`][PyArray_Newshape].
Expand Down Expand Up @@ -1908,19 +1982,19 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
/// [PyArray_Newshape]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Newshape
fn reshape_with_order<ID: IntoDimension>(
&self,
dims: ID,
shape: ID,
order: NPY_ORDER,
) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
where
T: Element;

/// Special case of [`reshape_with_order`][Self::reshape_with_order] which keeps the memory order the same.
#[inline(always)]
fn reshape<ID: IntoDimension>(&self, dims: ID) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
fn reshape<ID: IntoDimension>(&self, shape: ID) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
where
T: Element,
{
self.reshape_with_order(dims, NPY_ORDER::NPY_ANYORDER)
self.reshape_with_order(shape, NPY_ORDER::NPY_ANYORDER)
}

/// Extends or truncates the dimensions of an array.
Expand Down Expand Up @@ -1955,7 +2029,7 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
///
/// [ndarray-resize]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.resize.html
/// [PyArray_Resize]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Resize
unsafe fn resize<ID: IntoDimension>(&self, dims: ID) -> PyResult<()>
unsafe fn resize<ID: IntoDimension>(&self, newshape: ID) -> PyResult<()>
where
T: Element;

Expand Down Expand Up @@ -2249,55 +2323,66 @@ impl<'py, T, D> PyArrayMethods<'py, T, D> for Bound<'py, PyArray<T, D>> {
if is_fortran { -1 } else { 0 },
)
};
if !ptr.is_null() {
Ok(unsafe { Bound::from_owned_ptr(self.py(), ptr).downcast_into_unchecked() })
} else {
Err(PyErr::fetch(self.py()))
unsafe {
Bound::from_owned_ptr_or_err(self.py(), ptr).map(|ob| ob.downcast_into_unchecked())
}
}

fn permute<ID: IntoDimension>(&self, axes: Option<ID>) -> PyResult<Bound<'py, PyArray<T, D>>> {
let mut axes = axes.map(|axes| axes.into_dimension());
let mut axes = axes.as_mut().map(|axes| axes.to_npy_dims());
let axes = axes
.as_mut()
.map_or_else(ptr::null_mut, |axes| axes as *mut npyffi::PyArray_Dims);

let py = self.py();
let ptr = unsafe { PY_ARRAY_API.PyArray_Transpose(py, self.as_array_ptr(), axes) };
unsafe { Bound::from_owned_ptr_or_err(py, ptr).map(|ob| ob.downcast_into_unchecked()) }
}

fn reshape_with_order<ID: IntoDimension>(
&self,
dims: ID,
shape: ID,
order: NPY_ORDER,
) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
where
T: Element,
{
let mut dims = dims.into_dimension();
let mut dims = dims.to_npy_dims();
let mut shape = shape.into_dimension();
let mut shape = shape.to_npy_dims();

let py = self.py();
let ptr = unsafe {
PY_ARRAY_API.PyArray_Newshape(
self.py(),
py,
self.as_array_ptr(),
&mut dims as *mut npyffi::PyArray_Dims,
&mut shape as *mut npyffi::PyArray_Dims,
order,
)
};
if !ptr.is_null() {
Ok(unsafe { Bound::from_owned_ptr(self.py(), ptr).downcast_into_unchecked() })
} else {
Err(PyErr::fetch(self.py()))
}
unsafe { Bound::from_owned_ptr_or_err(py, ptr).map(|ob| ob.downcast_into_unchecked()) }
}

unsafe fn resize<ID: IntoDimension>(&self, dims: ID) -> PyResult<()>
unsafe fn resize<ID: IntoDimension>(&self, newshape: ID) -> PyResult<()>
where
T: Element,
{
let mut dims = dims.into_dimension();
let mut dims = dims.to_npy_dims();
let mut newshape = newshape.into_dimension();
let mut newshape = newshape.to_npy_dims();

let py = self.py();
let res = PY_ARRAY_API.PyArray_Resize(
self.py(),
py,
self.as_array_ptr(),
&mut dims as *mut npyffi::PyArray_Dims,
&mut newshape as *mut npyffi::PyArray_Dims,
1,
NPY_ORDER::NPY_ANYORDER,
);

if !res.is_null() {
Ok(())
} else {
Err(PyErr::fetch(self.py()))
Err(PyErr::fetch(py))
}
}

Expand Down
41 changes: 36 additions & 5 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ use std::mem::size_of;
#[cfg(feature = "half")]
use half::{bf16, f16};
use ndarray::{array, s, Array1, Dim};
use numpy::prelude::*;
use numpy::{
array::{PyArray0Methods, PyArrayMethods},
dtype_bound, get_array_module,
npyffi::NPY_ORDER,
pyarray_bound, PyArray, PyArray1, PyArray2, PyArrayDescr, PyArrayDescrMethods, PyArrayDyn,
PyFixedString, PyFixedUnicode, PyUntypedArrayMethods, ToPyArray,
dtype_bound, get_array_module, npyffi::NPY_ORDER, pyarray_bound, PyArray, PyArray1, PyArray2,
PyArrayDescr, PyArrayDyn, PyFixedString, PyFixedUnicode,
};
use pyo3::{
py_run, pyclass, pymethods,
Expand Down Expand Up @@ -522,6 +520,39 @@ fn get_works() {
});
}

#[test]
fn permute_and_transpose() {
Python::with_gil(|py| {
let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray_bound(py);

let permuted = array.permute(Some([1, 0])).unwrap();
assert_eq!(
permuted.readonly().as_array(),
array![[0, 3], [1, 4], [2, 5]]
);

let permuted = array.permute::<()>(None).unwrap();
assert_eq!(
permuted.readonly().as_array(),
array![[0, 3], [1, 4], [2, 5]]
);

let transposed = array.transpose().unwrap();
assert_eq!(
transposed.readonly().as_array(),
array![[0, 3], [1, 4], [2, 5]]
);

let array = pyarray_bound![py, [[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]];

let permuted = array.permute(Some([0, 2, 1])).unwrap();
assert_eq!(
permuted.readonly().as_array(),
array![[[1, 3], [2, 4]], [[5, 7], [6, 8]], [[9, 11], [10, 12]]]
);
});
}

#[test]
fn reshape() {
Python::with_gil(|py| {
Expand Down