Unclear precedence between NumPy and JAX arrays in arithmetic #9952
Labels: array API standard (Support for the Python array API standard), bug, topic-arrays (related to flexible array support)
What happened?
Consider the following case of an `xarray.DataArray` wrapping a single-element JAX array:

This object wraps a `jax.Array`, with operations implemented via the Array API (yay!), as one can check by inspecting `da.data`. The `.data` of both `da * 1` and `1 * da` is a JAX array, and so is the `.data` of `da * np.array(1.0)`.

Unfortunately, `np.array(1.0) * da` is not: its `.data` is a base NumPy array. This feels quite inconsistent. Ideally JAX would take precedence in all these cases, even though the Python Array API rules technically do not prescribe an order of precedence between different array types.
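One plausible mechanism for the asymmetry, assumed here rather than confirmed against the xarray or JAX source, is NumPy's operator dispatch. `ndarray.__mul__` does not defer to a right operand that defines `__array_ufunc__`; it calls `np.multiply`, which hands control to the wrapper's `__array_ufunc__`. If the wrapper then applies the NumPy ufunc to its raw data, a foreign array with no `__array_ufunc__` of its own gets converted through `__array__`. The sketch below mimics this with two hypothetical classes, `ForeignArray` (standing in for `jax.Array`) and `Wrapper` (standing in for `xarray.DataArray`). It uses only NumPy and is not xarray's actual implementation.

```python
import numpy as np


class ForeignArray:
    """Hypothetical stand-in for a non-NumPy array type such as jax.Array.

    It implements its own operators and __array__, but no __array_ufunc__,
    so a NumPy ufunc that receives it converts it to a NumPy value.
    """

    # Higher priority than ndarray: ndarray.__mul__ defers to our __rmul__.
    __array_priority__ = 100

    def __init__(self, value):
        self.value = float(value)

    def __mul__(self, other):
        # Toy version: accepts Python scalars and 0-d NumPy arrays only.
        return ForeignArray(self.value * float(other))

    __rmul__ = __mul__

    def __array__(self, dtype=None, copy=None):
        # Conversion hook NumPy uses when it needs a real ndarray.
        return np.asarray(self.value, dtype=dtype)


class Wrapper:
    """Hypothetical stand-in for xarray.DataArray wrapping a duck array."""

    def __init__(self, data):
        self.data = data

    def __mul__(self, other):
        # Python operator protocol: the wrapped array handles the operation.
        return Wrapper(self.data * other)

    def __rmul__(self, other):
        return Wrapper(other * self.data)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # ndarray.__mul__ does not defer to objects defining __array_ufunc__;
        # it calls np.multiply, which lands here. Unwrapping and re-applying
        # the NumPy ufunc converts the ForeignArray data via __array__.
        raw = [x.data if isinstance(x, Wrapper) else x for x in inputs]
        return Wrapper(getattr(ufunc, method)(*raw, **kwargs))


w = Wrapper(ForeignArray(2.0))
for label, result in [
    ("w * 1", w * 1),
    ("1 * w", 1 * w),
    ("w * np.array(1.0)", w * np.array(1.0)),
    ("np.array(1.0) * w", np.array(1.0) * w),
]:
    print(label, type(result.data).__name__)
```

In this sketch the first three expressions go through Python's operator protocol and keep the foreign type, while `np.array(1.0) * w` is routed through the ufunc path and yields a plain NumPy value, mirroring the behaviour reported above.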
What did you expect to happen?
No response
Minimal Complete Verifiable Example
No response
MVCE confirmation
Relevant log output
No response
Anything else we need to know?
No response
Environment