Skip to content

Commit

Permalink
Add simple example for PyArray<PyObject>. (#339)
Browse files Browse the repository at this point in the history
Add simple example for PyArray<PyObject>.
  • Loading branch information
Chuxiaof authored Aug 22, 2022
1 parent a445411 commit 251bfde
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
17 changes: 16 additions & 1 deletion examples/simple/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,21 @@ use numpy::{
PyReadwriteArray1, PyReadwriteArrayDyn,
};
use pyo3::{
exceptions::PyIndexError,
pymodule,
types::{PyDict, PyModule},
FromPyObject, PyAny, PyResult, Python,
FromPyObject, PyAny, PyObject, PyResult, Python,
};

#[pymodule]
fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// example using generic PyObject
fn head(x: ArrayViewD<'_, PyObject>) -> PyResult<PyObject> {
x.get(0)
.cloned()
.ok_or_else(|| PyIndexError::new_err("array index out of range"))
}

// example using immutable borrows producing a new array
fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD<f64> {
a * &x + &y
Expand All @@ -34,6 +42,13 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
&x + &y
}

// wrapper of `head`
#[pyfn(m)]
#[pyo3(name = "head")]
fn head_py(_py: Python<'_>, x: PyReadonlyArrayDyn<'_, PyObject>) -> PyResult<PyObject> {
head(x.as_array())
}

// wrapper of `axpy`
#[pyfn(m)]
#[pyo3(name = "axpy")]
Expand Down
8 changes: 7 additions & 1 deletion examples/simple/tests/test_ext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import numpy as np
from rust_ext import axpy, conj, mult, extract, add_minutes_to_seconds, polymorphic_add
from rust_ext import head, axpy, conj, mult, extract, add_minutes_to_seconds, polymorphic_add


def test_head():
x = np.array(['first', None, 42])
first = head(x)
assert first == 'first'


def test_axpy():
Expand Down

0 comments on commit 251bfde

Please sign in to comment.