Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix upcasting with python builtin numbers and numpy 2 #8946

Merged
merged 52 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
f3c2c93
Fix upcasting with python builtin numbers and numpy 2
djhoese Apr 15, 2024
2c8a607
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2024
fcbd821
Remove old commented code
djhoese Apr 15, 2024
3f7670b
Try updating result_type for array API
dcherian Apr 18, 2024
6115dd8
xfail
dcherian Apr 18, 2024
e3493b0
Revert "Try updating result_type for array API"
dcherian Apr 18, 2024
e27f572
Merge branch 'main' into bugfix-scalar-arr-casting
dcherian Apr 18, 2024
5b4384f
Determine python scalar dtype from array arguments
djhoese Apr 28, 2024
33d48e4
Merge branch 'main' into bugfix-scalar-arr-casting
keewis May 22, 2024
712c244
extend `get_array_namespace` to get the same namespace for multiple v…
keewis May 23, 2024
c486f8e
reflow error message
keewis May 23, 2024
6a55378
allow weakly dtyped data as a separate argument to `result_type`
keewis May 23, 2024
c38c07f
allow passing a `dtype` argument to `duck_array_ops.asarray`
keewis May 23, 2024
7b39cbe
rewrite `as_shared_dtype`
keewis May 23, 2024
85d986f
fall back to the `astype` method if `xp.astype` doesn't exist
keewis May 25, 2024
a89cba8
determine the smallest possible dtype and call `xp.result_type` again
keewis May 25, 2024
62900d6
map python types to dtype names
keewis May 25, 2024
3e2855d
apply the custom rules after both dtype combinations
keewis May 25, 2024
66fe0c5
add tests for passing scalars to `result_type`
keewis May 29, 2024
476cb11
apply the custom casting rules to scalar types + explicit dtypes
keewis May 29, 2024
4da4771
go back to the simple dtypes
keewis May 29, 2024
e193e90
comment on the fallback: if we don't have explicit dtypes, use the de…
keewis May 29, 2024
cda1ad0
Merge branch 'main' into bugfix-scalar-arr-casting
keewis May 29, 2024
47becb8
ignore the typing
keewis May 29, 2024
03974b1
don't support `datetime` scalars since that would be a new feature
keewis Jun 5, 2024
92b35b3
make sure passing only scalars also works
keewis Jun 5, 2024
109468a
move the fallback version of `result_type` into a separate function
keewis Jun 5, 2024
efb0f84
ignore the custom `inf` / `ninf` / `na` objects
keewis Jun 5, 2024
7b0d478
more temporary fixes
keewis Jun 5, 2024
aa2d372
Merge branch 'main' into bugfix-scalar-arr-casting
keewis Jun 5, 2024
463d0d6
remove debug print statements
keewis Jun 5, 2024
d33c742
change the representation for inf / ninf for boolean and nonnumeric a…
keewis Jun 7, 2024
51f15fa
pass arrays instead of relying on implicit casting using `asarray`
keewis Jun 7, 2024
2391cfb
refactor the implementation of `result_type`
keewis Jun 7, 2024
aaced1d
Merge branch 'main' into bugfix-scalar-arr-casting
keewis Jun 7, 2024
4e86aa2
remove obsolete comment
keewis Jun 7, 2024
c28233c
proper placement of the `result_type` comment
keewis Jun 7, 2024
0977707
back to using a list instead of a tuple
keewis Jun 7, 2024
57b4eec
expect the platform-specific default dtype for integers
keewis Jun 8, 2024
96e75c6
move `_future_array_api_result_type` to `npcompat`
keewis Jun 10, 2024
afa495e
rename `is_scalar_type` to `is_weak_scalar_type`
keewis Jun 10, 2024
1dc9d46
move the dispatch to `_future_array_api_result_type` to `npcompat`
keewis Jun 10, 2024
b405916
support passing *only* weak dtypes to `_future_array_api_result_type`
keewis Jun 10, 2024
66233a1
don't `xfail` the Array API `where` test
keewis Jun 10, 2024
445195b
determine the array namespace if not passed and no `cupy` arrays present
keewis Jun 10, 2024
b2a6379
comment on what to do with `npcompat.result_type`
keewis Jun 10, 2024
79d4a26
Merge branch 'main' into bugfix-scalar-arr-casting
keewis Jun 10, 2024
ab284f7
move the `result_type` compat code to `array_api_compat`
keewis Jun 10, 2024
cf11228
also update the comment
keewis Jun 10, 2024
8dbacbc
mention the Array API issue on python scalars in a comment
keewis Jun 10, 2024
d0e08d8
Merge branch 'main' into bugfix-scalar-arr-casting
keewis Jun 10, 2024
4c094c1
Merge branch 'main' into bugfix-scalar-arr-casting
keewis Jun 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 7 additions & 32 deletions xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,31 +215,6 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
return xp.isdtype(dtype, kind)


def is_scalar_type(t):
return isinstance(t, (bool, int, float, complex, str, bytes))


def _future_array_api_result_type(*arrays_and_dtypes, xp):
strongly_dtyped = [t for t in arrays_and_dtypes if not is_scalar_type(t)]
weakly_dtyped = [t for t in arrays_and_dtypes if is_scalar_type(t)]

dtype = xp.result_type(*strongly_dtyped)
if not weakly_dtyped:
return dtype

possible_dtypes = {
complex: "complex64",
float: "float32",
int: "int8",
bool: "bool",
str: "str",
bytes: "bytes",
}
dtypes = [possible_dtypes.get(type(x), "object") for x in weakly_dtyped]

return xp.result_type(dtype, *dtypes)


def preprocess_scalar_types(t):
if isinstance(t, (str, bytes)):
return type(t)
Expand All @@ -266,12 +241,17 @@ def result_type(
-------
numpy.dtype for the result.
"""
# TODO (keewis): replace `npcompat.result_type` with `xp.result_type` once we can
# require a version of the Array API that supports passing scalars to it.
from xarray.core.duck_array_ops import get_array_namespace

if xp is None:
xp = get_array_namespace(arrays_and_dtypes)

types = {xp.result_type(preprocess_scalar_types(t)) for t in arrays_and_dtypes}
types = {
npcompat.result_type(preprocess_scalar_types(t), xp=xp)
for t in arrays_and_dtypes
}
if any(isinstance(t, np.dtype) for t in types):
# only check if there's numpy dtypes – the array API does not
# define the types we're checking for
Expand All @@ -281,9 +261,4 @@ def result_type(
):
return np.dtype(object)

if xp is np or any(
isinstance(getattr(t, "dtype", t), np.dtype) for t in arrays_and_dtypes
):
return xp.result_type(*map(preprocess_scalar_types, arrays_and_dtypes))
else:
return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)
return npcompat.result_type(*map(preprocess_scalar_types, arrays_and_dtypes), xp=xp)
4 changes: 3 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def asarray(data, xp=np, dtype=None):
return xp.astype(converted, dtype)


def as_shared_dtype(scalars_or_arrays, xp=np):
def as_shared_dtype(scalars_or_arrays, xp=None):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
extension_array_types = [
Expand All @@ -272,6 +272,8 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
import cupy as cp

xp = cp
elif xp is None:
xp = get_array_namespace(scalars_or_arrays)

# Pass arrays directly instead of dtypes to result_type so scalars
# get handled properly.
Expand Down
43 changes: 42 additions & 1 deletion xarray/core/npcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np

try:
# requires numpy>=2.0
from numpy import isdtype # type: ignore[attr-defined,unused-ignore]
except ImportError:
import numpy as np

dtype_kinds = {
"bool": np.bool_,
Expand All @@ -58,3 +59,43 @@ def isdtype(dtype, kind):
return any(isinstance(dtype, kind) for kind in translated_kinds)
else:
return any(np.issubdtype(dtype, kind) for kind in translated_kinds)


keewis marked this conversation as resolved.
Show resolved Hide resolved
def is_weak_scalar_type(t):
return isinstance(t, (bool, int, float, complex, str, bytes))


def _future_array_api_result_type(*arrays_and_dtypes, xp):
strongly_dtyped = [t for t in arrays_and_dtypes if not is_weak_scalar_type(t)]
weakly_dtyped = [t for t in arrays_and_dtypes if is_weak_scalar_type(t)]

if not strongly_dtyped:
strongly_dtyped = [
xp.asarray(x) if not isinstance(x, type) else x for x in weakly_dtyped
]
weakly_dtyped = []

dtype = xp.result_type(*strongly_dtyped)
if not weakly_dtyped:
return dtype

possible_dtypes = {
complex: "complex64",
float: "float32",
int: "int8",
bool: "bool",
str: "str",
bytes: "bytes",
}
dtypes = [possible_dtypes.get(type(x), "object") for x in weakly_dtyped]

return xp.result_type(dtype, *dtypes)


def result_type(*arrays_and_dtypes, xp) -> np.dtype:
if xp is np or any(
isinstance(getattr(t, "dtype", t), np.dtype) for t in arrays_and_dtypes
):
return xp.result_type(*arrays_and_dtypes)
else:
return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)
3 changes: 0 additions & 3 deletions xarray/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ def test_unstack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
assert_equal(actual, expected)


@pytest.mark.xfail(
reason="We don't know how to handle python scalars with array API arrays. This is strictly prohibited by the array API"
)
def test_where() -> None:
np_arr = xr.DataArray(np.array([1, 0]), dims="x")
xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x")
Expand Down
Loading