
Add support for safely creating universal functions in Rust #400

Draft: wants to merge 1 commit into main


@adamreichold (Member) commented Oct 30, 2023

This does not support polymorphic universal functions (ufuncs with inner loops for several dtypes), but the number of inputs and outputs is arbitrary.
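As a rough illustration of what "an arbitrary number of inputs and outputs, but a single element type" can look like in Rust, here is a std-only sketch using const generics. The function `run_elementwise` is hypothetical and is not the API of this branch; it only mirrors the idea of passing fixed-size arrays of 1-dimensional inputs and outputs.

```rust
/// Hypothetical sketch (not this branch's API): apply `kernel` element-wise to
/// `NIN` input slices, writing into `NOUT` output slices of the same element type `T`.
/// The arity is arbitrary but fixed at compile time; each instantiation handles
/// exactly one element type, i.e. there is no dtype polymorphism.
pub fn run_elementwise<T: Copy, const NIN: usize, const NOUT: usize>(
    len: usize,
    inputs: [&[T]; NIN],
    mut outputs: [&mut [T]; NOUT],
    kernel: impl Fn([T; NIN]) -> [T; NOUT],
) {
    // All operands must share the same inner-loop length.
    assert!(inputs.iter().all(|s| s.len() == len));
    assert!(outputs.iter().all(|s| s.len() == len));

    for i in 0..len {
        // Gather the i-th element of every input.
        let args: [T; NIN] = std::array::from_fn(|k| inputs[k][i]);
        let results = kernel(args);
        // Scatter the results into the i-th element of every output.
        for (out, value) in outputs.iter_mut().zip(results) {
            out[i] = value;
        }
    }
}
```

A two-input, two-output call would look like `run_elementwise(3, [&x[..], &y[..]], [&mut sum[..], &mut diff[..]], |[a, b]| [a + b, b - a])`.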

Closes #399

@adamreichold (Member, Author) commented Oct 30, 2023

@mhostetter Would you be interested and able to test this branch? Would the functionality available here suffice to handle your use case?

EDIT: The test examples should give a rough idea of how this works.

@adamreichold force-pushed the ufunc branch 2 times, most recently from 5658032 to 2276b54 (October 30, 2023 13:29)

@kngwyu (Member) left a comment

Looks quite solid as an initial implementation! I'll review again after you complete the PR.

src/ufunc.rs (outdated, resolved)
src/ufunc.rs (resolved)

@mhostetter

> @mhostetter Would you be interested and able to test this branch? Would the functionality available here suffice to handle your use case?

I can try, but I fear my knowledge is too limited to be effective. I've only done an example or two with PyO3 and rust-numpy. Do I need to pip install this branch? I'm also not sure how to write the Rust code.

To summarize my goals: write a NumPy ufunc in Rust. This function takes `x` and `y`, which are single elements of two arrays. It may also take other configuration parameters, e.g. `modulo`. For example, a Rust ufunc for addition in a prime finite field would return `(x + y) % modulo`. I then expose this Rust-written ufunc to Python via PyO3, so that I can invoke it on arbitrarily sized NumPy arrays in Python (with normal NumPy broadcasting). I also want to change the modulus from Python at runtime, e.g. `rust_ufunc(x, y, 7)` or `rust_ufunc(x, y, 13)`.
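To make the per-element computation concrete, here is a minimal std-only sketch of the scalar kernel described above. The names `add_mod` and `add_mod_all` are hypothetical. Widening to `u128` is just one way to keep `x + y` from overflowing when the modulus is close to `u64::MAX`.

```rust
/// Hypothetical per-element kernel for addition in the prime field GF(p):
/// returns (x + y) % modulo without overflowing u64.
pub fn add_mod(x: u64, y: u64, modulo: u64) -> u64 {
    assert!(modulo != 0, "modulo must be non-zero");
    // Widen to u128 so that x + y cannot overflow even for moduli close to u64::MAX.
    ((x as u128 + y as u128) % modulo as u128) as u64
}

/// Apply the kernel element-wise to two equally long inputs with a modulus chosen
/// at runtime, analogous to calling rust_ufunc(x, y, 7) or rust_ufunc(x, y, 13).
pub fn add_mod_all(x: &[u64], y: &[u64], modulo: u64) -> Vec<u64> {
    assert_eq!(x.len(), y.len());
    x.iter().zip(y).map(|(&a, &b)| add_mod(a, b, modulo)).collect()
}
```

For example, `add_mod_all(&[1, 6, 12], &[2, 3, 5], 7)` yields `[3, 2, 3]`, while the same inputs with modulus 13 yield `[3, 9, 4]`.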

@adamreichold (Member, Author)

> I can try, but I fear my knowledge is too limited to be effective. I've only done an example or two with PyO3 and rust-numpy. Do I need to pip install this branch? I'm also probably unsure of how to write the Rust code.

I think to test this branch, you would only need to change your Cargo.toml to replace the version dependency on numpy with a Git one, e.g. replace

```toml
numpy = "0.20"
```

with

```toml
numpy = { git = "https://github.com/PyO3/rust-numpy.git", branch = "ufunc" }
```

> To summarize my goals in words... Write a NumPy ufunc in Rust. This function takes x and y which are single elements of two arrays. The function may also take other configuration parameters, e.g. modulo. An example Rust ufunc of addition in a prime finite field would return (x + y) % modulo. I then expose this Rust-written ufunc to Python via PyO3. Then I can invoke this ufunc on arbitrarily-sized NumPy arrays in Python (using normal NumPy broadcasting). I can also change the modulo in Python at runtime, e.g. rust_ufunc(x, y, 7) or rust_ufunc(x, y, 13).

To my understanding, universal functions always take 1-dimensional vectors as inputs (so they have a chance of vectorizing the innermost loop) and write into output arrays that are passed to them explicitly, instead of returning values.
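The following std-only sketch illustrates that shape: a 1-dimensional inner loop that writes into a caller-supplied output, and a simplified driver that calls it once per row of contiguous, row-major 2-D data. Both function names are hypothetical, and the driver is only a stand-in: NumPy's own iterator decides how to split the data and may merge contiguous dimensions into one longer inner-loop call.

```rust
/// Hypothetical 1-D inner loop: reads two input vectors and writes into an output
/// vector supplied by the caller; nothing is returned.
/// Assumes x + y does not overflow u64 (e.g. inputs already reduced modulo m).
pub fn add_mod_inner(x: &[u64], y: &[u64], z: &mut [u64], m: u64) {
    assert!(x.len() == z.len() && y.len() == z.len());
    for ((zi, &xi), &yi) in z.iter_mut().zip(x).zip(y) {
        *zi = (xi + yi) % m;
    }
}

/// Simplified stand-in for an outer iteration over contiguous, row-major 2-D
/// operands with `cols` columns: the 1-D inner loop is called once per row.
pub fn apply_rowwise(cols: usize, x: &[u64], y: &[u64], z: &mut [u64], m: u64) {
    assert!(cols > 0);
    assert!(x.len() == z.len() && y.len() == z.len());
    for ((z_row, x_row), y_row) in z.chunks_mut(cols).zip(x.chunks(cols)).zip(y.chunks(cols)) {
        add_mod_inner(x_row, y_row, z_row, m);
    }
}
```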

So if your modulus is basically fixed, you could inject it by capturing it in the closure that defines your ufunc, e.g.

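```rust
use std::ffi::CString;

use numpy::ndarray::{azip, ArrayView1, ArrayViewMut1};

let m = ...;

// The modulus `m` is captured by the closure, so it is fixed once the ufunc is created.
let add_mod_m = move |[x, y]: [ArrayView1<'_, u64>; 2], [z]: [ArrayViewMut1<'_, u64>; 1]| {
    // Element-wise over the 1-dimensional views: z = (x + y) % m.
    azip!((x in x, y in y, z in z) *z = (*x + *y) % m);
};

// Register the closure as a ufunc named "add_mod_m" (API from this draft branch).
let add_mod_m = numpy::ufunc::from_func(py, CString::new("add_mod_m").unwrap(), numpy::ufunc::Identity::Zero, add_mod_m);

module.add("add_mod_m", add_mod_m).unwrap();
```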
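To try the capture-and-destructure pattern without the numpy branch, here is a std-only analog of the closure above that uses plain slices in place of `ArrayView1`/`ArrayViewMut1`. The factory `make_add_mod` is hypothetical; it simply shows that the modulus is fixed when the closure is built.

```rust
/// Hypothetical std-only stand-in for the ufunc closure: the inputs arrive as an
/// array of two slices and the output as an array of one mutable slice.
/// The modulus `m` is moved into the closure when it is created.
/// Assumes x + y does not overflow u64.
pub fn make_add_mod(m: u64) -> impl Fn([&[u64]; 2], [&mut [u64]; 1]) {
    move |[x, y]: [&[u64]; 2], [z]: [&mut [u64]; 1]| {
        for ((zi, &xi), &yi) in z.iter_mut().zip(x).zip(y) {
            *zi = (xi + yi) % m;
        }
    }
}
```

Calling `make_add_mod(7)` once and reusing the result corresponds to the "basically fixed" modulus case; changing the modulus means building a new closure.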

If you want to vary `m`, you need to add it as a third input parameter, which NumPy will then broadcast to a 1-dimensional array (which is trivial using a zero stride).
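The zero-stride trick can be sketched without NumPy: in the hypothetical strided loop below, each operand is a slice plus a stride, and giving `m` a stride of 0 makes every iteration read the same modulus. Strides are counted in elements here for simplicity; NumPy's inner loops receive byte strides instead.

```rust
/// Hypothetical strided 1-D inner loop with three inputs (x, y, m) and one output z.
/// Each operand is (slice, stride in elements); a stride of 0 re-reads the same
/// element on every iteration, which broadcasts a single modulus without copying it.
/// Assumes x + y does not overflow u64.
pub fn add_mod_strided(
    n: usize,
    (x, sx): (&[u64], usize),
    (y, sy): (&[u64], usize),
    (m, sm): (&[u64], usize),
    (z, sz): (&mut [u64], usize),
) {
    for i in 0..n {
        z[i * sz] = (x[i * sx] + y[i * sy]) % m[i * sm];
    }
}
```

With `m = [4]` and stride 0, inputs `[1, 2, 3]` and `[5, 5, 5]` produce `[2, 3, 0]`.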

Successfully merging this pull request may close this issue: Support writing custom NumPy ufuncs in Rust
3 participants