Unclear precedence between NumPy and JAX arrays in arithmetic #9952

Open
5 tasks done
shoyer opened this issue Jan 16, 2025 · 2 comments
Labels
array API standard (Support for the Python array API standard), bug, topic-arrays (related to flexible array support)

Comments

@shoyer
Member

shoyer commented Jan 16, 2025

What happened?

Consider the following case of an xarray.DataArray wrapping a single element JAX array:

import jax
import xarray
import numpy as np

da = xarray.DataArray(jax.numpy.ones(1))

This object wraps a jax.Array, with operations implemented via the Array API (yay!), as one can check by inspecting da.data.

da * 1 and 1 * da are both JAX arrays. So is da * np.array(1.0).

Unfortunately, np.array(1.0) * da is not -- it's a base NumPy array.

This feels quite inconsistent. Ideally JAX would take precedence in all these cases, even though the Python Array API rules technically do not prescribe an order of precedence between different array types.

What did you expect to happen?

No response

Minimal Complete Verifiable Example

No response

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

No response

Anything else we need to know?

No response

Environment

xarray = 2025.1.1
jax = 0.4.38
@shoyer shoyer added the bug and needs triage (Issue that has not been reviewed by xarray team member) labels Jan 16, 2025
@keewis
Collaborator

keewis commented Jan 16, 2025

There's a recent issue that's quite similar: #9934 (for xr.dot instead of __mul__). As far as I can tell, for this to work NumPy would need to return NotImplemented from ndarray.__mul__ when the other operand is a DataArray (Python would then call DataArray.__rmul__).

If I read the Array API spec correctly, it explicitly lists the interaction of different array types as a non-goal and recommends that calling code cast the input arrays first.
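For readers less familiar with the reflected-operator protocol mentioned above, here is a minimal pure-Python sketch. The classes Left, GreedyLeft, and Right are hypothetical toys and not part of NumPy, JAX, or xarray; they only illustrate that Python tries the right operand's __rmul__ only when the left operand's __mul__ returns NotImplemented.

```python
class Left:
    """Left operand that gives up on types it does not recognize."""

    def __mul__(self, other):
        if isinstance(other, Left):
            return "Left.__mul__"
        # Returning NotImplemented tells Python to try other.__rmul__(self).
        return NotImplemented


class GreedyLeft:
    """Left operand that always handles the operation itself."""

    def __mul__(self, other):
        return "GreedyLeft.__mul__"


class Right:
    """Right operand that can only take part through the reflected method."""

    def __rmul__(self, other):
        return "Right.__rmul__"


result_deferred = Left() * Right()      # Left gives up, so Right.__rmul__ runs
result_greedy = GreedyLeft() * Right()  # Right.__rmul__ is never consulted
```

In this analogy, the behavior reported in the issue corresponds to the GreedyLeft case: the left operand handles the multiplication itself, so the right operand's reflected method never gets a chance.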

@TomNicholas TomNicholas added the topic-arrays (related to flexible array support) and array API standard (Support for the Python array API standard) labels and removed the needs triage (Issue that has not been reviewed by xarray team member) label Jan 16, 2025
@shoyer
Member Author

shoyer commented Jan 16, 2025

OK, I think I've figured out why this is happening.

Xarray implements __array_ufunc__, which means that NumPy never returns NotImplemented from arithmetic special methods (e.g., ndarray.__mul__ in this case). Instead, it calls the NumPy ufunc:

numpy_array * xarray_wrapping_jax
-> numpy_array.__mul__(xarray_wrapping_jax)
-> np.multiply(numpy_array, xarray_wrapping_jax)
-> xarray_obj.__array_ufunc__(np.multiply, numpy_array, xarray_wrapping_jax)
-> xarray.apply_ufunc(np.multiply, numpy_array, xarray_wrapping_jax)
-> np.multiply(numpy_array, jax_array)
-> np.multiply(numpy_array, np.asarray(jax_array)) (because JAX does not define __array_ufunc__)
-> numpy array
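The first steps of this chain can be reproduced without xarray or JAX. The sketch below uses two hypothetical toy classes (OptsIn and OptsOut, not part of any library) to show how ndarray.__mul__ treats an operand that defines __array_ufunc__ versus one that sets it to None. It relies only on NumPy's documented __array_ufunc__ protocol.

```python
import numpy as np


class OptsIn:
    """Toy operand that defines __array_ufunc__, as xarray objects do."""

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NumPy hands the whole operation to this method; report which ufunc arrived.
        return f"__array_ufunc__({ufunc.__name__})"

    def __rmul__(self, other):
        return "__rmul__"


class OptsOut:
    """Toy operand that opts out of ufunc handling entirely."""

    __array_ufunc__ = None

    def __rmul__(self, other):
        return "__rmul__"


x = np.array(1.0)
via_ufunc = x * OptsIn()       # ndarray.__mul__ calls np.multiply, which dispatches to __array_ufunc__
via_reflected = x * OptsOut()  # ndarray.__mul__ returns NotImplemented, so __rmul__ runs
```

Setting __array_ufunc__ = None is shown only to contrast the two paths; it is not being suggested as the fix for xarray, which relies on ufunc support.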

To fix this in Xarray, we will need to update Xarray's __array_ufunc__ to use built-in Python arithmetic (i.e., operator.mul) rather than the NumPy ufunc (np.multiply).

The cleanest way to do this is to replace the NumPy ufunc with the corresponding special operator (if possible) before calling xarray.apply_ufunc, somewhere around this line. The list of the necessary NumPy ufuncs to reverse can be found here.
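A rough sketch of that idea, under stated assumptions: UFUNC_TO_OPERATOR, DuckArray, and Wrapper below are hypothetical names invented for illustration. DuckArray imitates the relevant traits of a jax.Array (no __array_ufunc__, a high __array_priority__), and Wrapper imitates a DataArray. This is not xarray's actual implementation, and the real list of ufuncs to map would likely be longer.

```python
import operator

import numpy as np

# Hypothetical lookup table from binary ufuncs to Python operators.
UFUNC_TO_OPERATOR = {
    np.add: operator.add,
    np.subtract: operator.sub,
    np.multiply: operator.mul,
    np.true_divide: operator.truediv,
    np.floor_divide: operator.floordiv,
    np.power: operator.pow,
}


class DuckArray:
    """Toy stand-in for a jax.Array: no __array_ufunc__, high __array_priority__."""

    __array_priority__ = 100

    def __init__(self, values):
        self.values = np.asarray(values)

    def __array__(self, dtype=None, copy=None):
        return self.values if dtype is None else self.values.astype(dtype)

    def __mul__(self, other):
        return DuckArray(self.values * np.asarray(other))

    __rmul__ = __mul__


class Wrapper:
    """Toy stand-in for a DataArray wrapping a duck array."""

    def __init__(self, data):
        self.data = data

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        unwrapped = [x.data if isinstance(x, Wrapper) else x for x in inputs]
        op = UFUNC_TO_OPERATOR.get(ufunc)
        if method == "__call__" and not kwargs and op is not None:
            # Python arithmetic lets the duck array's reflected method take over.
            return Wrapper(op(*unwrapped))
        # Fall back to the ufunc for anything without an operator equivalent.
        return Wrapper(getattr(ufunc, method)(*unwrapped, **kwargs))


# Calling the ufunc directly coerces the duck array to a NumPy array.
baseline = np.multiply(np.ones(1), DuckArray([2.0]))
# With the operator substitution, the wrapped data stays a DuckArray.
result = np.ones(1) * Wrapper(DuckArray([2.0]))
```

The substitution only applies to plain calls without keyword arguments; ufunc methods such as reduce, and calls passing out= or where=, have no operator equivalent and keep the ufunc path in this sketch.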
