Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix upcasting with python builtin numbers and numpy 2 #8946

Merged
merged 52 commits into from
Jun 11, 2024

Conversation

djhoese
Copy link
Contributor

@djhoese djhoese commented Apr 15, 2024

See #8402 for more discussion. Bottom line is that numpy 2 changes the rules for casting between two inputs. Due to this and xarray's preference for promoting python scalars to 0d arrays (scalar arrays), xarray objects are being upcast to higher data types when they previously didn't.

I'm mainly opening this PR for further and more detailed discussion.

CC @dcherian

@dcherian dcherian added the run-upstream Run upstream CI label Apr 15, 2024
@djhoese
Copy link
Contributor Author

djhoese commented Apr 15, 2024

Ugh my local clone was so old it was pointing to master. One sec...

@djhoese djhoese force-pushed the bugfix-scalar-arr-casting branch from 88e778a to f3c2c93 Compare April 15, 2024 20:14
@djhoese
Copy link
Contributor Author

djhoese commented Apr 15, 2024

Ok so the failing test is the array-api version (https://github.com/data-apis/array-api-compat) where it expects both the x and y inputs of the where function to be .dtype. Since we're skipping scalar->array conversion in this PR those objects won't have a .dtype. I'm not sure what the rules are for the strict array API having scalar inputs.

@dcherian
Copy link
Contributor

Looks like the array api strictly wants arrays: https://data-apis.org/array-api/latest/API_specification/generated/array_api.where.html

@djhoese
Copy link
Contributor Author

djhoese commented Apr 15, 2024

Related but I don't fully understand it: data-apis/array-api-compat#85

@djhoese
Copy link
Contributor Author

djhoese commented Apr 16, 2024

I guess it depends how you interpret the array API standard then. I can file an issue if needed. To me, depending on how you read the standard, it means either:

  1. This test is flawed as it tests scalar inputs when the array API specifically defines Array inputs.
  2. The Array API package is flawed because it assumes and requires Array inputs when the standard allows for scalar inputs (I don't think this is true if I'm understanding the description).

The other point is that maybe numpy compatibility is more important until numpy more formally conforms to the array API standard (see the first note on https://data-apis.org/array-api/latest/API_specification/array_object.html#api-specification-array-object--page-root). But also type promotion seems wishy-washy and not super strict: https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars

I propose, because it works best for me and matches numpy compatibility, that I update the test to have a numpy case only but add a new test function with numpy and array api cases with array inputs to .where instead of scalars.

* main:
  (feat): Support for `pandas` `ExtensionArray` (pydata#8723)
  Migrate datatree mapping.py (pydata#8948)
  Add mypy to dev dependencies (pydata#8947)
  Convert 360_day calendars by choosing random dates to drop or add (pydata#8603)
@dcherian
Copy link
Contributor

I lean towards (1).

I looked at this for a while, and we'll need major changes around handling array API dtype objects to do this properly.

cc @keewis

@dcherian dcherian requested a review from keewis April 18, 2024 14:25
@keewis
Copy link
Collaborator

keewis commented Apr 22, 2024

we'll need major changes around handling array API dtype objects to do this properly.

I think the change could be limited to xarray.core.duck_array_ops.as_shared_dtype. According to the Array API section on mixing scalars and arrays, we should to use the dtype of the array (though it only looks at scalar + 1 array, so we'd need to extend that).

However, what we currently do is cast all scalars to arrays using asarray, which means python scalars use the OS default dtype (e.g. float64 on most 64-bit systems).

As a algorithm, maybe this could work:

  • separate the input into python scalars and arrays / scalars with dtype
  • determine result_type using just the arrays / scalars with dtype
  • check that all python scalars are compatible with the result (otherwise might have to return object?)
  • cast all input to arrays with the dtype

@djhoese
Copy link
Contributor Author

djhoese commented Apr 22, 2024

According to the Array API section on mixing scalars and arrays, we should to use the dtype of the array (though it only looks at scalar + 1 array, so we'd need to extend that).

Do you know if this is inline with numpy 2 dtype casting behavior?

@keewis
Copy link
Collaborator

keewis commented Apr 22, 2024

The main numpy namespace is supposed to be Array API compatible, so it should? I don't know for certain, though.

@djhoese
Copy link
Contributor Author

djhoese commented Apr 22, 2024

check that all python scalars are compatible with the result (otherwise might have to return object?)

How do we check this?

@djhoese
Copy link
Contributor Author

djhoese commented Apr 22, 2024

Here's what I have locally which seems to pass:

Subject: [PATCH] Cast scalars as arrays with result type of only arrays
---
Index: xarray/core/duck_array_ops.py
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
--- a/xarray/core/duck_array_ops.py	(revision e27f572585a6386729a5523c1f9082c72fa8d178)
+++ b/xarray/core/duck_array_ops.py	(date 1713816523554)
@@ -239,20 +239,30 @@
         import cupy as cp
 
         arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
+        # Pass arrays directly instead of dtypes to result_type so scalars
+        # get handled properly.
+        # Note that result_type() safely gets the dtype from dask arrays without
+        # evaluating them.
+        out_type = dtypes.result_type(*arrays)
     else:
-        arrays = [
-            # https://github.com/pydata/xarray/issues/8402
-            # https://github.com/pydata/xarray/issues/7721
-            x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp)
-            for x in scalars_or_arrays
-        ]
-    # Pass arrays directly instead of dtypes to result_type so scalars
-    # get handled properly.
-    # Note that result_type() safely gets the dtype from dask arrays without
-    # evaluating them.
-    out_type = dtypes.result_type(*arrays)
+        # arrays = [
+        #     # https://github.com/pydata/xarray/issues/8402
+        #     # https://github.com/pydata/xarray/issues/7721
+        #     x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp)
+        #     for x in scalars_or_arrays
+        # ]
+        objs_with_dtype = [obj for obj in scalars_or_arrays if hasattr(obj, "dtype")]
+        if objs_with_dtype:
+            # Pass arrays directly instead of dtypes to result_type so scalars
+            # get handled properly.
+            # Note that result_type() safely gets the dtype from dask arrays without
+            # evaluating them.
+            out_type = dtypes.result_type(*objs_with_dtype)
+        else:
+            out_type = dtypes.result_type(*scalars_or_arrays)
+        arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
     return [
-        astype(x, out_type, copy=False) if hasattr(x, "dtype") else x for x in arrays
+        astype(x, out_type, copy=False) for x in arrays
     ]
 
 

I just through it together to see if it would work. I'm not sure it is accurate, but the fact that it is almost exactly like the existing solution with the only difference being the out_type = changes makes me feel this is going in a good direction.

Note I had to do if objs_with_dtype: because the test passes two python scalars so there are no arrays to determine the result type.

@keewis
Copy link
Collaborator

keewis commented Apr 22, 2024

How do we check this?

Not sure... but there are only so many builtin types that can be involved without requiring object dtype, so we could just enumerate all of them? As far as I can tell, that would be: bool, int, float, str, datetime/date, and timedelta

@djhoese
Copy link
Contributor Author

djhoese commented Apr 26, 2024

check that all python scalars are compatible with the result (otherwise might have to return object?)

How do we check this?

@keewis Do you have a test that I can add to verify any fix I attempt for this? What do you mean by python scalar being compatible with the result?

@keewis
Copy link
Collaborator

keewis commented Apr 26, 2024

well, for example, what should happen for this:

a = xr.DataArray(np.array([1, 2, 3], dtype="int8"), dim="x")
xr.where(a % 2 == 1, a, 1.2)

according to the algorithm above, we have one array of dtype int8, so that means we'd have to check if 1.2 (a float) is compatible with int8. It is not, so we should promote everything to float (the default would be to use float64, which might be a bit weird).

Something similar:

a = xr.DataArray(np.array(["2019-01-01", "2020-01-01"], dtype="datetime64[ns]"), dim="x")
xr.where(a.x % 2 == 1, a, datetime.datetime(2019, 6, 30))

in that case, the check should succeed, because we can convert a builtin datetime object to datetime64[ns].

@djhoese
Copy link
Contributor Author

djhoese commented Apr 28, 2024

I committed my (what I consider ugly) implementation of your original approach @keewis. I'm still not sure I understand how to approach the scalar compatibility so if someone has some ideas then please make some suggestion comments or commits directly if you have the permissions.

@keewis
Copy link
Collaborator

keewis commented Apr 28, 2024

this might be cleaner:

def asarray(data, xp=np, dtype=None):
    return data if is_duck_array(data) else xp.asarray(data, dtype=dtype)


def as_shared_dtype(scalars_or_arrays, xp=np):
    """Cast a arrays to a shared dtype using xarray's type promotion rules."""
    if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
        # as soon as extension arrays are involved we only use this:
        extension_array_types = [
            x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
        ]
        if len(extension_array_types) == len(scalars_or_arrays) and all(
            isinstance(x, type(extension_array_types[0])) for x in extension_array_types
        ):
            return scalars_or_arrays
        raise ValueError(
            f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
        )

    if array_type_cupy := array_type("cupy") and any(  # noqa: F841
        isinstance(x, array_type_cupy) for x in scalars_or_arrays  # noqa: F821
    ):
        import cupy as cp

        xp_ = cp
    else:
        xp_ = xp

    # split into python scalars and arrays / numpy scalars (i.e. into weakly and strongly dtyped)
    with_dtype = {}
    python_scalars = {}
    for index, elem in enumerate(scalars_or_arrays):
        append_to = with_dtype if hasattr(elem, "dtype") else python_scalars
        append_to[index] = elem

    if with_dtype:
        to_convert = with_dtype
    else:
        # can't avoid using the default dtypes if we only get weak dtypes
        to_convert = python_scalars
        python_scalars = {}

    arrays = {index: asarray(x, xp=xp_) for index, x in to_convert.items()}

    common_dtype = dtypes.result_type(*arrays.values())
    # TODO(keewis): check that all python scalars are compatible. If not, change the dtype or raise.

    # cast arrays
    cast = {index: astype(x, dtype=common_dtype, copy=False) for index, x in arrays.items()}
    # convert python scalars to arrays with a specific dtype
    converted = {index: asarray(x, xp=xp_, dtype=common_dtype) for index, x in python_scalars.items()}

    # merge both
    combined = cast | converted
    return [x for _, x in sorted(combined.items(), key=lambda x: x[0])]

This is still missing the dtype fallbacks, though.

@keewis
Copy link
Collaborator

keewis commented Apr 28, 2024

I see now why the dtype fallbacks for scalars is tricky... we basically need to enumerate the casting rules, and decide when to return a different dtype (like object). numpy has can_cast with the option to choose the strictness (so we could use "same_kind") and it accepts python scalar types, while the Array API does not allow that choice, and we also can't pass in python scalar types.

To start, here's the rules from the Array API:

  • complex dtypes are compatible with int, float, or complex
  • float dtypes are compatible with any int or float
  • int dtypes are compatible with int (but beware: python uses BigInt, so the value might exceed the maximum of the dtype)
  • the bool dtype is only compatible with bool

From numpy, we also have these (numpy casting is even more relaxed than this, but that behavior may also cause some confusing issues):

  • bool can be cast to int, so it is compatible with anything int is compatible with
  • str dtypes are only compatible with str. Anything else, like formatting and casting to other types, has to be done explicitly before calling as_shared_dtype.
  • datetime dtypes (precisions) are compatible with datetime.datetime, datetime.date, and pd.Timestamp
  • timedelta dtypes (precisions) are compatible with datetime.timedelta and pd.Timedelta. Casting to int is possible, but has to be done explicitly (i.e. we can ignore it here)
  • anything else results in a object dtype

Edit: it appears NEP 50 describes the changes in detail. I didn't see that before writing both the list above and implementing the changes, so I might have to change both.

@keewis
Copy link
Collaborator

keewis commented Apr 28, 2024

here's my shot at the scalar dtype verification (the final implementation we settled on in the end is much better). I'm pretty sure it can be cleaned up further (and we need more tests), but it does fix all the casting issues. Edit: note that this depends on the Array API fixes for numpy>=2.

What I don't like is that we're essentially hard-coding the dtype casting hierarchy, but I couldn't figure out a way to make it work without that.

@djhoese
Copy link
Contributor Author

djhoese commented May 12, 2024

FYI to everyone watching this, I'm going to be switching to a heavier paternity leave than I was already starting this week. I think someone else should take this PR over as I don't think I'll have time to finish it in time for the numpy 2 final release.

Copy link
Member

@shoyer shoyer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. Thanks @keewis for your patience here and getting this to the finish line!

xarray/core/dtypes.py Outdated Show resolved Hide resolved
xarray/core/duck_array_ops.py Show resolved Hide resolved
xarray/tests/test_array_api.py Outdated Show resolved Hide resolved
@keewis keewis added the plan to merge Final call for comments label Jun 10, 2024
@keewis
Copy link
Collaborator

keewis commented Jun 10, 2024

if my most recent changes are fine, this should be ready for merging (the remaining upstream-dev test failures will be fixed by #9081).

Once that is done, I will cut a release to have at least one release that is compatible with numpy>=2 before that is released.

@dcherian dcherian mentioned this pull request Jun 10, 2024
xarray/core/npcompat.py Outdated Show resolved Hide resolved
xarray/core/array_api_compat.py Show resolved Hide resolved
@dcherian
Copy link
Contributor

image

Wow. Thanks @keewis 👏 👏

@flamingbear flamingbear merged commit 2013e7f into pydata:main Jun 11, 2024
27 of 28 checks passed
@keewis
Copy link
Collaborator

keewis commented Jun 11, 2024

thanks for all the help and advice, @shoyer! And thanks for kicking this off, @djhoese.

andersy005 pushed a commit that referenced this pull request Jun 14, 2024
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Justus Magin <[email protected]>
Co-authored-by: Justus Magin <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
plan to merge Final call for comments run-upstream Run upstream CI
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

⚠️ Nightly upstream-dev CI failed ⚠️ where dtype upcast with numpy 2
5 participants