[RFC] Implement the Python array API standard #48

Open
lucascolley opened this issue Dec 7, 2023 · 11 comments

@lucascolley

The Python array API standard standardises common functionality across Python array/tensor libraries. NumPy, PyTorch and CuPy are planning to have full implementations, and Dask and JAX also have implementations in progress. You could implement this in your main namespace or a separate namespace.

Why should you do this? As well as making it easier for users to convert existing NumPy/PyTorch/CuPy code to MLX, there is potential for interoperability with other libraries. For example, from the NumPy ecosystem, SciPy and scikit-learn have partial experimental support for arrays which comply with the standard.
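To make the interoperability point concrete, here is a minimal sketch of how a consuming library can write array-agnostic code. `__array_namespace__` is the entry point defined by the standard; everything else (`ToyArray`, `toy_xp`, `standardize`) is a hypothetical stand-in so the example runs without any array library installed, and is not how MLX, NumPy or SciPy implement this.

```python
import math
import operator
from types import SimpleNamespace


class ToyArray:
    """Hypothetical 1-D array; exists only to make this sketch runnable."""

    def __init__(self, values):
        self.values = [float(v) for v in values]

    def __array_namespace__(self, api_version=None):
        # Standard entry point: return the namespace of functions
        # that operate on this array type.
        return toy_xp

    def _binary(self, other, op):
        # Very small stand-in for broadcasting: stretch a length-1 operand.
        a, b = self.values, other.values
        if len(a) == 1:
            a = a * len(b)
        elif len(b) == 1:
            b = b * len(a)
        return ToyArray(op(p, q) for p, q in zip(a, b))

    def __sub__(self, other):
        return self._binary(other, operator.sub)

    def __truediv__(self, other):
        return self._binary(other, operator.truediv)


def _mean(x):
    return ToyArray([sum(x.values) / len(x.values)])


def _std(x):
    # Population standard deviation (no correction term).
    m = sum(x.values) / len(x.values)
    return ToyArray([math.sqrt(sum((v - m) ** 2 for v in x.values) / len(x.values))])


# Hypothetical namespace exposing two of the standard's statistical functions.
toy_xp = SimpleNamespace(mean=_mean, std=_std)


def standardize(x):
    """Library-agnostic code: never names a concrete array library."""
    xp = x.__array_namespace__()
    return (x - xp.mean(x)) / xp.std(x)


z = standardize(ToyArray([1.0, 2.0, 3.0]))
print(z.values)
```

The point of the design is that `standardize` only touches the array through its namespace and operators, so any array type exposing a compliant namespace could be passed in unchanged.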

If you are interested in this, the consortium would love to hear feedback over at https://github.com/data-apis/consortium-feedback/. Some potential pain points, such as missing float64 support, have already been discussed very briefly in data-apis/array-api#719.

@lucascolley
Author

quoting @awni in gh-12:

We had the luxury of picking the best from all the frameworks we've used and worked on in the past and combining them into something new.

In terms of API design (implementation aside), this is the same process the consortium followed in creating the standard. I imagine it would be fruitful to discuss where the differences are and why one might be preferable.

@asmeurer

asmeurer commented Dec 7, 2023

My talk at SciPy 2023 is useful if you want to know more about the array API.

@arpan-dhatt

arpan-dhatt commented Dec 10, 2023

+1
Implementing this standard would help libraries like einops "just work" when dealing with MLX arrays. That particular library has been cropping up a lot for array reshape/transpose/stack, etc. in the HyenaDNA model I'm trying to port.

It can also help clear up ambiguities in API design such as #113 and create a "checklist" of basic ops that should be implemented but are not yet, like linspace.

Edit: and just a thought, if there's any time to break backwards compatibility to make the Array API "first class" for MLX, it's now, at version 0.0.4 🙃

@awni
Member

awni commented Dec 10, 2023

Are there any differences between the Python array API standard and NumPy or is the standard basically a subset of NumPy? If it's the latter, then I would say we are already on track to implement the standard.

Either way though, we will definitely take it into consideration as we continue to update the API!

@lucascolley
Author

lucascolley commented Dec 10, 2023

Are there any differences between the Python array API standard and NumPy or is the standard basically a subset of NumPy?

Here is the tracking issue for support in the main NumPy namespace (making the NumPy API a superset of the standard). Some decisions were made to differ from NumPy where other libraries seemed to have improved upon NumPy, but it is quite similar. @mtsokol and @rgommers are working on a proposal to continue the work into making NumPy a superset.

Either way though, we will definitely take it into consideration as we continue to update the API!

Great, I'd definitely be keen to try to get MLX arrays working in SciPy! As @arpan-dhatt mentioned, if you think that compliance seems plausible and a good idea, now is probably the time to check for BC-breaking changes.

@rgommers

For context: in numpy 1.2x.y, there are differences. In numpy 2.0 (branching within a month, tentative release date end of Feb'24) there are a lot of API changes, additions and a number of backwards compatibility breaks to ensure that NumPy's main namespace and the fft and linalg modules will be compliant with the array API standard. The most important bc-breaking change (type promotion rules, NEP 50) was planned anyway independent of array API support - and makes NumPy more consistent and align better with JAX/PyTorch behavior.

Right now, for numpy 1.2x.y, the differences are bridged by the array-api-compat package. Longer-term, the need for that should go away. Having array API standard support in MLX would be quite nice - in particular because it'd then be possible to add support for MLX in SciPy, scikit-learn & co.

@EwoutH

EwoutH commented Feb 28, 2024

@awni thanks for looking into this! Do you have anything to share as of now?

This is currently the second highest upvoted open issue by the way!

@awni
Member

awni commented Feb 29, 2024

thanks for looking into this! Do you have anything to share as of now?

I'm sorry we haven't spent much time on this. I would say the following though: we intend to follow NumPy. So by transitivity if NumPy follows the Python API standard so will MLX. So we'd happily take PRs that move MLX to be more in line with NumPy, and will continue to work towards that ourselves.

@lucascolley
Author

if NumPy follows the Python API standard so will MLX

Watch this space: https://numpy.org/neps/nep-0056-array-api-main-namespace.html

@ogrisel

ogrisel commented Mar 8, 2024

we intend to follow NumPy. So by transitivity if NumPy follows the Python API standard so will MLX

That's good to know, but there are things in MLX that do not exist in NumPy, such as the stream parameter.

This is related to the device concept in the Array API.

Note however that the concept of stream/queue control was deemed out-of-scope for the Array API. Instead, array constructors typically accept a device= kwarg and arrays expose a .device attribute.
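As a rough illustration of that convention, the sketch below creates a new array on the same device as an existing one, using only the `device=` keyword and the `.device` attribute. The `asarray(obj, *, dtype=None, device=None, copy=None)` signature follows the standard; `StubArray`, `stub_xp` and `asarray_like` are hypothetical names used so the example runs without a real array library, and the device strings are placeholders.

```python
from types import SimpleNamespace


class StubArray:
    """Hypothetical array that just records its dtype and device."""

    def __init__(self, data, dtype, device):
        self.data = list(data)
        self.dtype = dtype
        self.device = device


def _asarray(obj, *, dtype=None, device=None, copy=None):
    # Stand-in for a compliant constructor; defaults are assumptions.
    return StubArray(obj, dtype or "float32", device or "cpu")


# Hypothetical namespace exposing a standard-shaped asarray().
stub_xp = SimpleNamespace(asarray=_asarray)


def asarray_like(xp, obj, like):
    """Create an array from obj with the dtype and device of `like`."""
    return xp.asarray(obj, dtype=like.dtype, device=like.device)


ref = StubArray([0.0], "float32", "gpu")
new = asarray_like(stub_xp, [1.0, 2.0], ref)
print(new.device, new.dtype)
```

Nothing here selects a stream or queue; in this model placement is expressed purely through the device object, which is why an MLX-style `stream` parameter has no direct counterpart.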

Also note that NumPy will remain a CPU-only library for the foreseeable future but it might expose device keywords/attributes for the sake of compatibility.

@ogrisel

ogrisel commented Mar 8, 2024

Similarly, NumPy is fundamentally an eager-evaluation library while the MLX library is lazy by default and exposes an explicit mx.eval function that can be used as a synchronization primitive.

On the other hand, the Array API supports laziness but does not (yet) specify a library-agnostic evaluation function as part of the standard. The only standard ways to explicitly trigger evaluation are via dunder methods such as __float__, __bool__ and the like.
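To illustrate the difference between the two triggers, here is a toy sketch of a lazy value. `LazyScalar`, `constant` and `evaluate` are hypothetical names, not MLX or Array API functions: `evaluate` only plays the role of an explicit, library-specific function in the spirit of mx.eval, while `__float__` and `__bool__` show the implicit path that the standard does cover.

```python
class LazyScalar:
    """Toy lazy value: records a computation and runs it only when forced."""

    def __init__(self, thunk):
        self._thunk = thunk
        self._value = None
        self.evaluated = False

    def __add__(self, other):
        # Build a new deferred node; nothing is computed here.
        return LazyScalar(lambda: self.force() + other.force())

    def force(self):
        # Run the recorded computation once and cache the result.
        if not self.evaluated:
            self._value = self._thunk()
            self.evaluated = True
        return self._value

    def __float__(self):
        # Converting to a Python float needs a concrete value,
        # so it implicitly triggers evaluation.
        return float(self.force())

    def __bool__(self):
        return bool(self.force())


def constant(value):
    """Wrap a plain number as a (still unevaluated) lazy value."""
    return LazyScalar(lambda: value)


def evaluate(*values):
    """Hypothetical explicit evaluation function for this toy class."""
    for v in values:
        v.force()


a, b = constant(1.5), constant(2.0)
c = a + b            # deferred: c.evaluated is still False
total = float(c)     # implicit trigger via __float__
d = a + a
evaluate(d)          # explicit trigger
print(total, c.evaluated, d.evaluated)
```

Library-agnostic code can only rely on the implicit path, since an explicit function like `evaluate` would have to be looked up per library.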
