diff --git a/.codecov.yml b/.codecov.yml index f3a055c09d4..d0bec9539f8 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,16 +1,38 @@ codecov: - ci: - # by default, codecov doesn't recognize azure as a CI provider - - dev.azure.com - require_ci_to_pass: yes + require_ci_to_pass: true coverage: status: project: default: # Require 1% coverage, i.e., always succeed - target: 1 + target: 1% + flags: + - unittests + paths: + - "!xarray/tests/" + unittests: + target: 90% + flags: + - unittests + paths: + - "!xarray/tests/" + mypy: + target: 20% + flags: + - mypy patch: false changes: false -comment: off +comment: false + +flags: + unittests: + paths: + - "xarray" + - "!xarray/tests" + carryforward: false + mypy: + paths: + - "xarray" + carryforward: false diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml index 59e5889f5ec..cc1a2e12be3 100644 --- a/.github/ISSUE_TEMPLATE/bugreport.yml +++ b/.github/ISSUE_TEMPLATE/bugreport.yml @@ -44,6 +44,7 @@ body: - label: Complete example — the example is self-contained, including all data and the text of any traceback. - label: Verifiable example — the example copy & pastes into an IPython prompt or [Binder notebook](https://mybinder.org/v2/gh/pydata/xarray/main?urlpath=lab/tree/doc/examples/blank_template.ipynb), returning the result. - label: New issue — a search of GitHub Issues suggests this is not a duplicate. + - label: Recent environment — the issue occurs with the latest version of xarray and its dependencies. - type: textarea id: log-output diff --git a/.github/labeler.yml b/.github/labeler.yml index d866a7342fe..b5d55100d9b 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -54,6 +54,9 @@ topic-indexing: - xarray/core/indexes.py - xarray/core/indexing.py +topic-NamedArray: + - xarray/namedarray/* + topic-performance: - asv_bench/benchmarks/* - asv_bench/benchmarks/**/* diff --git a/.github/workflows/benchmarks-last-release.yml b/.github/workflows/benchmarks-last-release.yml index e1ae9b1b62e..40f06c82107 100644 --- a/.github/workflows/benchmarks-last-release.yml +++ b/.github/workflows/benchmarks-last-release.yml @@ -17,7 +17,7 @@ jobs: steps: # We need the full repo to avoid this issue # https://github.com/actions/checkout/issues/23 - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index ade00b942e7..08f39e36762 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -17,7 +17,7 @@ jobs: steps: # We need the full repo to avoid this issue # https://github.com/actions/checkout/issues/23 - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index 99ebefd9338..dc9cc2cd2fe 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -22,7 +22,7 @@ jobs: outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 2 - uses: xarray-contrib/ci-trigger@v1 @@ -44,7 +44,7 @@ jobs: PYTHON_VERSION: "3.10" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. @@ -82,8 +82,6 @@ jobs: name: Mypy runs-on: "ubuntu-latest" needs: detect-ci-trigger - # temporarily skipping due to https://github.com/pydata/xarray/issues/6551 - if: needs.detect-ci-trigger.outputs.triggered == 'false' defaults: run: shell: bash -l {0} @@ -92,7 +90,7 @@ jobs: PYTHON_VERSION: "3.10" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. @@ -123,7 +121,7 @@ jobs: - name: Run mypy run: | - python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report + python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ - name: Upload mypy coverage to Codecov uses: codecov/codecov-action@v3.1.4 @@ -146,7 +144,7 @@ jobs: PYTHON_VERSION: "3.9" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. @@ -177,7 +175,7 @@ jobs: - name: Run mypy run: | - python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report + python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ - name: Upload mypy coverage to Codecov uses: codecov/codecov-action@v3.1.4 @@ -190,6 +188,126 @@ jobs: + pyright: + name: Pyright + runs-on: "ubuntu-latest" + needs: detect-ci-trigger + if: | + always() + && ( + contains( github.event.pull_request.labels.*.name, 'run-pyright') + ) + defaults: + run: + shell: bash -l {0} + env: + CONDA_ENV_FILE: ci/requirements/environment.yml + PYTHON_VERSION: "3.10" + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. + + - name: set environment variables + run: | + echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + - name: Setup micromamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: ${{env.CONDA_ENV_FILE}} + environment-name: xarray-tests + create-args: >- + python=${{env.PYTHON_VERSION}} + conda + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + - name: Install xarray + run: | + python -m pip install --no-deps -e . + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Install pyright + run: | + python -m pip install pyright --force-reinstall + + - name: Run pyright + run: | + python -m pyright xarray/ + + - name: Upload pyright coverage to Codecov + uses: codecov/codecov-action@v3.1.4 + with: + file: pyright_report/cobertura.xml + flags: pyright + env_vars: PYTHON_VERSION + name: codecov-umbrella + fail_ci_if_error: false + + pyright39: + name: Pyright 3.9 + runs-on: "ubuntu-latest" + needs: detect-ci-trigger + if: | + always() + && ( + contains( github.event.pull_request.labels.*.name, 'run-pyright') + ) + defaults: + run: + shell: bash -l {0} + env: + CONDA_ENV_FILE: ci/requirements/environment.yml + PYTHON_VERSION: "3.9" + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. + + - name: set environment variables + run: | + echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + - name: Setup micromamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: ${{env.CONDA_ENV_FILE}} + environment-name: xarray-tests + create-args: >- + python=${{env.PYTHON_VERSION}} + conda + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + - name: Install xarray + run: | + python -m pip install --no-deps -e . + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Install pyright + run: | + python -m pip install pyright --force-reinstall + + - name: Run pyright + run: | + python -m pyright xarray/ + + - name: Upload pyright coverage to Codecov + uses: codecov/codecov-action@v3.1.4 + with: + file: pyright_report/cobertura.xml + flags: pyright39 + env_vars: PYTHON_VERSION + name: codecov-umbrella + fail_ci_if_error: false + + + min-version-policy: name: Minimum Version Policy runs-on: "ubuntu-latest" @@ -205,7 +323,7 @@ jobs: fail-fast: false steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index da4ad32b1f5..7ee197aeda3 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -22,7 +22,7 @@ jobs: outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 2 - uses: xarray-contrib/ci-trigger@v1 @@ -42,7 +42,7 @@ jobs: matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] # Bookend python versions - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.9", "3.11"] env: [""] include: # Minimum python version: @@ -60,7 +60,7 @@ jobs: python-version: "3.10" os: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. - name: Set environment variables diff --git a/.github/workflows/nightly-wheels.yml b/.github/workflows/nightly-wheels.yml index 562e442683e..ca3499386c9 100644 --- a/.github/workflows/nightly-wheels.yml +++ b/.github/workflows/nightly-wheels.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest if: github.repository == 'pydata/xarray' steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 - uses: actions/setup-python@v4 diff --git a/.github/workflows/pypi-release.yaml b/.github/workflows/pypi-release.yaml index 1afafd4550d..916bd33528a 100644 --- a/.github/workflows/pypi-release.yaml +++ b/.github/workflows/pypi-release.yaml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest if: github.repository == 'pydata/xarray' steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 - uses: actions/setup-python@v4 @@ -88,7 +88,7 @@ jobs: path: dist - name: Publish package to TestPyPI if: github.event_name == 'push' - uses: pypa/gh-action-pypi-publish@v1.8.8 + uses: pypa/gh-action-pypi-publish@v1.8.10 with: repository_url: https://test.pypi.org/legacy/ verbose: true @@ -111,6 +111,6 @@ jobs: name: releases path: dist - name: Publish package to PyPI - uses: pypa/gh-action-pypi-publish@v1.8.8 + uses: pypa/gh-action-pypi-publish@v1.8.10 with: verbose: true diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml index 7c60f20125e..d01fc5cdffc 100644 --- a/.github/workflows/upstream-dev-ci.yaml +++ b/.github/workflows/upstream-dev-ci.yaml @@ -25,7 +25,7 @@ jobs: outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 2 - uses: xarray-contrib/ci-trigger@v1 @@ -52,7 +52,7 @@ jobs: matrix: python-version: ["3.10"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. - name: Set up conda environment @@ -112,7 +112,7 @@ jobs: matrix: python-version: ["3.10"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. - name: Set up conda environment diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eca14fe0631..5626f450ec0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,6 @@ # https://pre-commit.com/ +ci: + autoupdate_schedule: monthly repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 @@ -16,13 +18,13 @@ repos: files: ^xarray/ - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.0.278' + rev: 'v0.0.292' hooks: - id: ruff args: ["--fix"] # https://github.com/python/black#version-control-integration - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.9.1 hooks: - id: black-jupyter - repo: https://github.com/keewis/blackdoc @@ -30,10 +32,10 @@ repos: hooks: - id: blackdoc exclude: "generate_aggregations.py" - additional_dependencies: ["black==23.7.0"] + additional_dependencies: ["black==23.9.1"] - id: blackdoc-autoupdate-black - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.4.1 + rev: v1.5.1 hooks: - id: mypy # Copied from setup.cfg diff --git a/CORE_TEAM_GUIDE.md b/CORE_TEAM_GUIDE.md new file mode 100644 index 00000000000..9eb91f4e586 --- /dev/null +++ b/CORE_TEAM_GUIDE.md @@ -0,0 +1,322 @@ +> **_Note:_** This Core Team Member Guide was adapted from the [napari project's Core Developer Guide](https://napari.org/stable/developers/core_dev_guide.html) and the [Pandas maintainers guide](https://pandas.pydata.org/docs/development/maintaining.html). + +# Core Team Member Guide + +Welcome, new core team member! We appreciate the quality of your work, and enjoy working with you! +Thank you for your numerous contributions to the project so far. + +By accepting the invitation to become a core team member you are **not required to commit to doing any more work** - +xarray is a volunteer project, and we value the contributions you have made already. + +You can see a list of all the current core team members on our +[@pydata/xarray](https://github.com/orgs/pydata/teams/xarray) +GitHub team. Once accepted, you should now be on that list too. +This document offers guidelines for your new role. + +## Tasks + +Xarray values a wide range of contributions, only some of which involve writing code. +As such, we do not currently make a distinction between a "core team member", "core developer", "maintainer", +or "triage team member" as some projects do (e.g. [pandas](https://pandas.pydata.org/docs/development/maintaining.html)). +That said, if you prefer to refer to your role as one of the other titles above then that is fine by us! + +Xarray is mostly a volunteer project, so these tasks shouldn’t be read as “expectations”. +**There are no strict expectations**, other than to adhere to our [Code of Conduct](https://github.com/pydata/xarray/tree/main/CODE_OF_CONDUCT.md). +Rather, the tasks that follow are general descriptions of what it might mean to be a core team member: + +- Facilitate a welcoming environment for those who file issues, make pull requests, and open discussion topics, +- Triage newly filed issues, +- Review newly opened pull requests, +- Respond to updates on existing issues and pull requests, +- Drive discussion and decisions on stalled issues and pull requests, +- Provide experience / wisdom on API design questions to ensure consistency and maintainability, +- Project organization (run developer meetings, coordinate with sponsors), +- Project evangelism (advertise xarray to new users), +- Community contact (represent xarray in user communities such as [Pangeo](https://pangeo.io/)), +- Key project contact (represent xarray's perspective within key related projects like NumPy, Zarr or Dask), +- Project fundraising (help write and administrate grants that will support xarray), +- Improve documentation or tutorials (especially on [`tutorial.xarray.dev`](https://tutorial.xarray.dev/)), +- Presenting or running tutorials (such as those we have given at the SciPy conference), +- Help maintain the [`xarray.dev`](https://xarray.dev/) landing page and website, the [code for which is here](https://github.com/xarray-contrib/xarray.dev), +- Write blog posts on the [xarray blog](https://xarray.dev/blog), +- Help maintain xarray's various Continuous Integration Workflows, +- Help maintain a regular release schedule (we aim for one or more releases per month), +- Attend the bi-weekly community meeting ([issue](https://github.com/pydata/xarray/issues/4001)), +- Contribute to the xarray codebase. + +(Matt Rocklin's post on [the role of a maintainer](https://matthewrocklin.com/blog/2019/05/18/maintainer) may be +interesting background reading, but should not be taken to strictly apply to the Xarray project.) + +Obviously you are not expected to contribute in all (or even more than one) of these ways! +They are listed so as to indicate the many types of work that go into maintaining xarray. + +It is natural that your available time and enthusiasm for the project will wax and wane - this is fine and expected! +It is also common for core team members to have a "niche" - a particular part of the codebase they have specific expertise +with, or certain types of task above which they primarily perform. + +If however you feel that is unlikely you will be able to be actively contribute in the foreseeable future +(or especially if you won't be available to answer questions about pieces of code that you wrote previously) +then you may want to consider letting us know you would rather be listed as an "Emeritus Core Team Member", +as this would help us in evaluating the overall health of the project. + +## Issue triage + +One of the main ways you might spend your contribution time is by responding to or triaging new issues. +Here’s a typical workflow for triaging a newly opened issue or discussion: + +1. **Thank the reporter for opening an issue.** + + The issue tracker is many people’s first interaction with the xarray project itself, beyond just using the library. + It may also be their first open-source contribution of any kind. As such, we want it to be a welcoming, pleasant experience. + +2. **Is the necessary information provided?** + + Ideally reporters would fill out the issue template, but many don’t. If crucial information (like the version of xarray they used), + is missing feel free to ask for that and label the issue with “needs info”. + The report should follow the [guidelines for xarray discussions](https://github.com/pydata/xarray/discussions/5404). + You may want to link to that if they didn’t follow the template. + + Make sure that the title accurately reflects the issue. Edit it yourself if it’s not clear. + Remember also that issues can be converted to discussions and vice versa if appropriate. + +3. **Is this a duplicate issue?** + + We have many open issues. If a new issue is clearly a duplicate, label the new issue as “duplicate”, and close the issue with a link to the original issue. + Make sure to still thank the reporter, and encourage them to chime in on the original issue, and perhaps try to fix it. + + If the new issue provides relevant information, such as a better or slightly different example, add it to the original issue as a comment or an edit to the original post. + +4. **Is the issue minimal and reproducible?** + + For bug reports, we ask that the reporter provide a minimal reproducible example. + See [minimal-bug-reports](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) for a good explanation. + If the example is not reproducible, or if it’s clearly not minimal, feel free to ask the reporter if they can provide and example or simplify the provided one. + Do acknowledge that writing minimal reproducible examples is hard work. If the reporter is struggling, you can try to write one yourself and we’ll edit the original post to include it. + + If a nice reproducible example has been provided, thank the reporter for that. + If a reproducible example can’t be provided, add the “needs mcve” label. + + If a reproducible example is provided, but you see a simplification, edit the original post with your simpler reproducible example. + +5. **Is this a clearly defined feature request?** + + Generally, xarray prefers to discuss and design new features in issues, before a pull request is made. + Encourage the submitter to include a proposed API for the new feature. Having them write a full docstring is a good way to pin down specifics. + + We may need a discussion from several xarray maintainers before deciding whether the proposal is in scope for xarray. + +6. **Is this a usage question?** + + We prefer that usage questions are asked on StackOverflow with the [`python-xarray` tag](https://stackoverflow.com/questions/tagged/python-xarray +) or as a [GitHub discussion topic](https://github.com/pydata/xarray/discussions). + + If it’s easy to answer, feel free to link to the relevant documentation section, let them know that in the future this kind of question should be on StackOverflow, and close the issue. + +7. **What labels and milestones should I add?** + + Apply the relevant labels. This is a bit of an art, and comes with experience. Look at similar issues to get a feel for how things are labeled. + Labels used for labelling issues that relate to particular features or parts of the codebase normally have the form `topic-`. + + If the issue is clearly defined and the fix seems relatively straightforward, label the issue as `contrib-good-first-issue`. + You can also remove the `needs triage` label that is automatically applied to all newly-opened issues. + +8. **Where should the poster look to fix the issue?** + + If you can, it is very helpful to point to the approximate location in the codebase where a contributor might begin to fix the issue. + This helps ease the way in for new contributors to the repository. + +## Code review and contributions + +As a core team member, you are a representative of the project, +and trusted to make decisions that will serve the long term interests +of all users. You also gain the responsibility of shepherding +other contributors through the review process; here are some +guidelines for how to do that. + +### All contributors are treated the same + +You should now have gained the ability to merge or approve +other contributors' pull requests. Merging contributions is a shared power: +only merge contributions you yourself have carefully reviewed, and that are +clear improvements for the project. When in doubt, and especially for more +complex changes, wait until at least one other core team member has approved. +(See [Reviewing](#reviewing) and especially +[Merge Only Changes You Understand](#merge-only-changes-you-understand) below.) + +It should also be considered best practice to leave a reasonable (24hr) time window +after approval before merge to ensure that other core team members have a reasonable +chance to weigh in. +Adding the `plan-to-merge` label notifies developers of the imminent merge. + +We are also an international community, with contributors from many different time zones, +some of whom will only contribute during their working hours, others who might only be able +to contribute during nights and weekends. It is important to be respectful of other peoples +schedules and working habits, even if it slows the project down slightly - we are in this +for the long run. In the same vein you also shouldn't feel pressured to be constantly +available or online, and users or contributors who are overly demanding and unreasonable +to the point of harassment will be directed to our [Code of Conduct](https://github.com/pydata/xarray/tree/main/CODE_OF_CONDUCT.md). +We value sustainable development practices over mad rushes. + +When merging, we automatically use GitHub's +[Squash and Merge](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/merging-a-pull-request#merging-a-pull-request) +to ensure a clean git history. + +You should also continue to make your own pull requests as before and in accordance +with the [general contributing guide](https://docs.xarray.dev/en/stable/contributing.html). These pull requests still +require the approval of another core team member before they can be merged. + +### How to conduct a good review + +*Always* be kind to contributors. Contributors are often doing +volunteer work, for which we are tremendously grateful. Provide +constructive criticism on ideas and implementations, and remind +yourself of how it felt when your own work was being evaluated as a +novice. + +``xarray`` strongly values mentorship in code review. New users +often need more handholding, having little to no git +experience. Repeat yourself liberally, and, if you don’t recognize a +contributor, point them to our development guide, or other GitHub +workflow tutorials around the web. Do not assume that they know how +GitHub works (many don't realize that adding a commit +automatically updates a pull request, for example). Gentle, polite, kind +encouragement can make the difference between a new core team member and +an abandoned pull request. + +When reviewing, focus on the following: + +1. **Usability and generality:** `xarray` is a user-facing package that strives to be accessible +to both novice and advanced users, and new features should ultimately be +accessible to everyone using the package. `xarray` targets the scientific user +community broadly, and core features should be domain-agnostic and general purpose. +Custom functionality is meant to be provided through our various types of interoperability. + +2. **Performance and benchmarks:** As `xarray` targets scientific applications that often involve +large multidimensional datasets, high performance is a key value of `xarray`. While +every new feature won't scale equally to all sizes of data, keeping in mind performance +and our [benchmarks](https://github.com/pydata/xarray/tree/main/asv_bench) during a review may be important, and you may +need to ask for benchmarks to be run and reported or new benchmarks to be added. +You can run the CI benchmarking suite on any PR by tagging it with the ``run-benchmark`` label. + +3. **APIs and stability:** Coding users and developers will make +extensive use of our APIs. The foundation of a healthy ecosystem will be +a fully capable and stable set of APIs, so as `xarray` matures it will +very important to ensure our APIs are stable. Spending the extra time to consider names of public facing +variables and methods, alongside function signatures, could save us considerable +trouble in the future. We do our best to provide [deprecation cycles](https://docs.xarray.dev/en/stable/contributing.html#backwards-compatibility) +when making backwards-incompatible changes. + +4. **Documentation and tutorials:** All new methods should have appropriate doc +strings following [PEP257](https://peps.python.org/pep-0257/) and the +[NumPy documentation guide](https://numpy.org/devdocs/dev/howto-docs.html#documentation-style). +For any major new features, accompanying changes should be made to our +[tutorials](https://tutorial.xarray.dev). These should not only +illustrates the new feature, but explains it. + +5. **Implementations and algorithms:** You should understand the code being modified +or added before approving it. (See [Merge Only Changes You Understand](#merge-only-changes-you-understand) +below.) Implementations should do what they claim and be simple, readable, and efficient +in that order. + +6. **Tests:** All contributions *must* be tested, and each added line of code +should be covered by at least one test. Good tests not only execute the code, +but explore corner cases. It can be tempting not to review tests, but please +do so. + +Other changes may be *nitpicky*: spelling mistakes, formatting, +etc. Do not insist contributors make these changes, but instead you should offer +to make these changes by [pushing to their branch](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/committing-changes-to-a-pull-request-branch-created-from-a-fork), +or using GitHub’s [suggestion](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/reviewing-changes-in-pull-requests/commenting-on-a-pull-request) +[feature](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/reviewing-changes-in-pull-requests/incorporating-feedback-in-your-pull-request), and +be prepared to make them yourself if needed. Using the suggestion feature is preferred because +it gives the contributor a choice in whether to accept the changes. + +Unless you know that a contributor is experienced with git, don’t +ask for a rebase when merge conflicts arise. Instead, rebase the +branch yourself, force-push to their branch, and advise the contributor to force-pull. If the contributor is +no longer active, you may take over their branch by submitting a new pull +request and closing the original, including a reference to the original pull +request. In doing so, ensure you communicate that you are not throwing the +contributor's work away! If appropriate it is a good idea to acknowledge other contributions +to the pull request using the `Co-authored-by` +[syntax](https://docs.github.com/en/pull-requests/committing-changes-to-your-project/creating-and-editing-commits/creating-a-commit-with-multiple-authors) in the commit message. + +### Merge only changes you understand + +*Long-term maintainability* is an important concern. Code doesn't +merely have to *work*, but should be *understood* by multiple core +developers. Changes will have to be made in the future, and the +original contributor may have moved on. + +Therefore, *do not merge a code change unless you understand it*. Ask +for help freely: we can consult community members, or even external developers, +for added insight where needed, and see this as a great learning opportunity. + +While we collectively "own" any patches (and bugs!) that become part +of the code base, you are vouching for changes you merge. Please take +that responsibility seriously. + +Feel free to ping other active maintainers with any questions you may have. + +## Further resources + +As a core member, you should be familiar with community and developer +resources such as: + +- Our [contributor guide](https://docs.xarray.dev/en/stable/contributing.html). +- Our [code of conduct](https://github.com/pydata/xarray/tree/main/CODE_OF_CONDUCT.md). +- Our [philosophy and development roadmap](https://docs.xarray.dev/en/stable/roadmap.html). +- [PEP8](https://peps.python.org/pep-0008/) for Python style. +- [PEP257](https://peps.python.org/pep-0257/) and the + [NumPy documentation guide](https://numpy.org/devdocs/dev/howto-docs.html#documentation-style) + for docstring conventions. +- [`pre-commit`](https://pre-commit.com) hooks for autoformatting. +- [`black`](https://github.com/psf/black) autoformatting. +- [`flake8`](https://github.com/PyCQA/flake8) linting. +- [python-xarray](https://stackoverflow.com/questions/tagged/python-xarray) on Stack Overflow. +- [@xarray_dev](https://twitter.com/xarray_dev) on Twitter. +- [xarray-dev](https://discord.gg/bsSGdwBn) discord community (normally only used for remote synchronous chat during sprints). + +You are not required to monitor any of the social resources. + +Where possible we prefer to point people towards asynchronous forms of communication +like github issues instead of realtime chat options as they are far easier +for a global community to consume and refer back to. + +We hold a [bi-weekly developers meeting](https://docs.xarray.dev/en/stable/developers-meeting.html) via video call. +This is a great place to bring up any questions you have, raise visibility of an issue and/or gather more perspectives. +Attendance is absolutely optional, and we keep the meeting to 30 minutes in respect of your valuable time. +This meeting is public, so we occasionally have non-core team members join us. + +We also have a private mailing list for core team members +`xarray-core-team@googlegroups.com` which is sparingly used for discussions +that are required to be private, such as nominating new core members and discussing financial issues. + +## Inviting new core members + +Any core member may nominate other contributors to join the core team. +While there is no hard-and-fast rule about who can be nominated, ideally, +they should have: been part of the project for at least two months, contributed +significant changes of their own, contributed to the discussion and +review of others' work, and collaborated in a way befitting our +community values. **We strongly encourage nominating anyone who has made significant non-code contributions +to the Xarray community in any way**. After nomination voting will happen on a private mailing list. +While it is expected that most votes will be unanimous, a two-thirds majority of +the cast votes is enough. + +Core team members can choose to become emeritus core team members and suspend +their approval and voting rights until they become active again. + +## Contribute to this guide (!) + +This guide reflects the experience of the current core team members. We +may well have missed things that, by now, have become second +nature—things that you, as a new team member, will spot more easily. +Please ask the other core team members if you have any questions, and +submit a pull request with insights gained. + +## Conclusion + +We are excited to have you on board! We look forward to your +contributions to the code base and the community. Thank you in +advance! diff --git a/README.md b/README.md index 41db66fd395..432d535d1b1 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # xarray: N-D labeled arrays and datasets [![CI](https://github.com/pydata/xarray/workflows/CI/badge.svg?branch=main)](https://github.com/pydata/xarray/actions?query=workflow%3ACI) -[![Code coverage](https://codecov.io/gh/pydata/xarray/branch/main/graph/badge.svg)](https://codecov.io/gh/pydata/xarray) +[![Code coverage](https://codecov.io/gh/pydata/xarray/branch/main/graph/badge.svg?flag=unittests)](https://codecov.io/gh/pydata/xarray) [![Docs](https://readthedocs.org/projects/xray/badge/?version=latest)](https://docs.xarray.dev/) [![Benchmarked with asv](https://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat)](https://pandas.pydata.org/speed/xarray/) [![Available on pypi](https://img.shields.io/pypi/v/xarray.svg)](https://pypi.python.org/pypi/xarray/) @@ -108,7 +108,7 @@ Thanks to our many contributors! ## License -Copyright 2014-2019, xarray Developers +Copyright 2014-2023, xarray Developers Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may @@ -125,12 +125,12 @@ limitations under the License. Xarray bundles portions of pandas, NumPy and Seaborn, all of which are available under a "3-clause BSD" license: -- pandas: setup.py, xarray/util/print_versions.py -- NumPy: xarray/core/npcompat.py -- Seaborn: _determine_cmap_params in xarray/core/plot/utils.py +- pandas: `setup.py`, `xarray/util/print_versions.py` +- NumPy: `xarray/core/npcompat.py` +- Seaborn: `_determine_cmap_params` in `xarray/core/plot/utils.py` Xarray also bundles portions of CPython, which is available under the -"Python Software Foundation License" in xarray/core/pycompat.py. +"Python Software Foundation License" in `xarray/core/pycompat.py`. Xarray uses icons from the icomoon package (free version), which is available under the "CC BY 4.0" license. diff --git a/asv_bench/benchmarks/dataset.py b/asv_bench/benchmarks/dataset.py new file mode 100644 index 00000000000..d8a6d6df9d8 --- /dev/null +++ b/asv_bench/benchmarks/dataset.py @@ -0,0 +1,32 @@ +import numpy as np + +from xarray import Dataset + +from . import requires_dask + + +class DatasetBinaryOp: + def setup(self): + self.ds = Dataset( + { + "a": (("x", "y"), np.ones((300, 400))), + "b": (("x", "y"), np.ones((300, 400))), + } + ) + self.mean = self.ds.mean() + self.std = self.ds.std() + + def time_normalize(self): + (self.ds - self.mean) / self.std + + +class DatasetChunk: + def setup(self): + requires_dask() + self.ds = Dataset() + array = np.ones(1000) + for i in range(250): + self.ds[f"var{i}"] = ("x", array) + + def time_chunk(self): + self.ds.chunk(x=(1,) * 1000) diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index 0af8084dd21..dcc2de0473b 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -527,7 +527,7 @@ def time_read_dataset(self, engine, chunks): class IOReadCustomEngine: def setup(self, *args, **kwargs): """ - The custom backend does the bare mininum to be considered a lazy backend. But + The custom backend does the bare minimum to be considered a lazy backend. But the data in it is still in memory so slow file reading shouldn't affect the results. """ @@ -593,7 +593,7 @@ def load(self) -> tuple: n_variables = 2000 # Important to have a shape and dtype for lazy loading. - shape = (1,) + shape = (1000,) dtype = np.dtype(int) variables = { f"long_variable_name_{v}": xr.Variable( @@ -643,7 +643,7 @@ def open_dataset( self.engine = PerformanceBackend - @parameterized(["chunks"], ([None, {}])) + @parameterized(["chunks"], ([None, {}, {"time": 10}])) def time_open_dataset(self, chunks): """ Time how fast xr.open_dataset is without the slow data reading part. diff --git a/asv_bench/benchmarks/merge.py b/asv_bench/benchmarks/merge.py index 043de35bdf7..6c8c1e9da90 100644 --- a/asv_bench/benchmarks/merge.py +++ b/asv_bench/benchmarks/merge.py @@ -41,7 +41,7 @@ def setup(self, strategy, count): data = np.array(["0", "b"], dtype=str) self.dataset_coords = dict(time=np.array([0, 1])) self.dataset_attrs = dict(description="Test data") - attrs = dict(units="Celcius") + attrs = dict(units="Celsius") if strategy == "dict_of_DataArrays": def create_data_vars(): diff --git a/asv_bench/benchmarks/pandas.py b/asv_bench/benchmarks/pandas.py index 9bda5970a4c..ebe61081916 100644 --- a/asv_bench/benchmarks/pandas.py +++ b/asv_bench/benchmarks/pandas.py @@ -13,7 +13,7 @@ def setup(self, dtype, subset): [ list("abcdefhijk"), list("abcdefhijk"), - pd.date_range(start="2000-01-01", periods=1000, freq="B"), + pd.date_range(start="2000-01-01", periods=1000, freq="D"), ] ) series = pd.Series(data, index) diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index 1d3713f19bf..579f4f00fbc 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -5,10 +5,10 @@ from . import parameterized, randn, requires_dask -nx = 300 +nx = 3000 long_nx = 30000 ny = 200 -nt = 100 +nt = 1000 window = 20 randn_xy = randn((nx, ny), frac_nan=0.1) @@ -115,6 +115,11 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck): roll = self.ds.var3.rolling(t=100) getattr(roll, func)() + @parameterized(["stride"], ([None, 5, 50])) + def peakmem_1drolling_construct(self, stride): + self.ds.var2.rolling(t=100).construct("w", stride=stride) + self.ds.var3.rolling(t=100).construct("w", stride=stride) + class DatasetRollingMemory(RollingMemory): @parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False])) @@ -128,3 +133,7 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck): with xr.set_options(use_bottleneck=use_bottleneck): roll = self.ds.rolling(t=100) getattr(roll, func)() + + @parameterized(["stride"], ([None, 5, 50])) + def peakmem_1drolling_construct(self, stride): + self.ds.rolling(t=100).construct("w", stride=stride) diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh index 41507fce13e..97ae4c2bbca 100755 --- a/ci/install-upstream-wheels.sh +++ b/ci/install-upstream-wheels.sh @@ -45,4 +45,5 @@ python -m pip install \ git+https://github.com/intake/filesystem_spec \ git+https://github.com/SciTools/nc-time-axis \ git+https://github.com/xarray-contrib/flox \ - git+https://github.com/h5netcdf/h5netcdf + git+https://github.com/h5netcdf/h5netcdf \ + git+https://github.com/dgasmith/opt_einsum diff --git a/ci/min_deps_check.py b/ci/min_deps_check.py index 0f7002ea513..bbaf440a9a0 100755 --- a/ci/min_deps_check.py +++ b/ci/min_deps_check.py @@ -1,6 +1,6 @@ #!/usr/bin/env python """Fetch from conda database all available versions of the xarray dependencies and their -publication date. Compare it against requirements/py37-min-all-deps.yml to verify the +publication date. Compare it against requirements/min-all-deps.yml to verify the policy on obsolete dependencies is being followed. Print a pretty report :) """ import itertools @@ -46,7 +46,7 @@ def warning(msg: str) -> None: def parse_requirements(fname) -> Iterator[tuple[str, int, int, int | None]]: - """Load requirements/py37-min-all-deps.yml + """Load requirements/min-all-deps.yml Yield (package name, major version, minor version, [patch version]) """ diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index 74c0b72bd0d..4645be08b83 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -23,7 +23,7 @@ dependencies: - netcdf4 - numba - numbagg - - numpy<1.24 + - numpy - packaging - pandas - pint<0.21 diff --git a/ci/requirements/bare-minimum.yml b/ci/requirements/bare-minimum.yml index 0a36493fa07..e8a80fdba99 100644 --- a/ci/requirements/bare-minimum.yml +++ b/ci/requirements/bare-minimum.yml @@ -11,6 +11,6 @@ dependencies: - pytest-env - pytest-xdist - pytest-timeout - - numpy=1.21 + - numpy=1.22 - packaging=21.3 - pandas=1.4 diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index 8850ef4f6a9..c4f590fdb18 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -20,9 +20,9 @@ dependencies: - nbsphinx - netcdf4>=1.5 - numba - - numpy>=1.21,<1.24 + - numpy>=1.21 - packaging>=21.3 - - pandas>=1.4 + - pandas>=1.4,!=2.1.0 - pooch - pip - pre-commit @@ -35,6 +35,7 @@ dependencies: - sphinx-book-theme >= 0.3.0 - sphinx-copybutton - sphinx-design + - sphinx-inline-tabs - sphinx>=5.0 - zarr>=2.10 - pip: diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 6abee0b18c3..efa9ccb5a9a 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -22,7 +22,7 @@ dependencies: - netcdf4 - numba - numbagg - - numpy<1.24 + - numpy - packaging - pandas - pint<0.21 diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 7a478c35a58..6e93ab7a946 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -25,7 +25,8 @@ dependencies: - numba - numbagg - numexpr - - numpy<1.24 + - numpy + - opt_einsum - packaging - pandas - pint<0.21 diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index e50d08264b8..8400270ce1b 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -8,48 +8,48 @@ dependencies: # When upgrading python, numpy, or pandas, must also change # doc/user-guide/installing.rst, doc/user-guide/plotting.rst and setup.py. - python=3.9 - - boto3=1.20 + - boto3=1.24 - bottleneck=1.3 - cartopy=0.20 - cdms2=3.1 - - cftime=1.5 + - cftime=1.6 - coveralls - - dask-core=2022.1 - - distributed=2022.1 + - dask-core=2022.7 + - distributed=2022.7 - flox=0.5 - - h5netcdf=0.13 + - h5netcdf=1.0 # h5py and hdf5 tend to cause conflicts # for e.g. hdf5 1.12 conflicts with h5py=3.1 # prioritize bumping other packages instead - h5py=3.6 - hdf5=1.12 - hypothesis - - iris=3.1 - - lxml=4.7 # Optional dep of pydap + - iris=3.2 + - lxml=4.9 # Optional dep of pydap - matplotlib-base=3.5 - nc-time-axis=1.4 # netcdf follows a 1.major.minor[.patch] convention # (see https://github.com/Unidata/netcdf4-python/issues/1090) - - netcdf4=1.5.7 + - netcdf4=1.6.0 - numba=0.55 - - numpy=1.21 + - numpy=1.22 - packaging=21.3 - pandas=1.4 - - pint=0.18 + - pint=0.19 - pip - pseudonetcdf=3.2 - - pydap=3.2 + - pydap=3.3 - pytest - pytest-cov - pytest-env - pytest-xdist - pytest-timeout - - rasterio=1.2 - - scipy=1.7 + - rasterio=1.3 + - scipy=1.8 - seaborn=0.11 - sparse=0.13 - - toolz=0.11 - - typing_extensions=4.0 - - zarr=2.10 + - toolz=0.12 + - typing_extensions=4.3 + - zarr=2.12 - pip: - - numbagg==0.1 + - numbagg==0.2.1 diff --git a/design_notes/named_array_design_doc.md b/design_notes/named_array_design_doc.md new file mode 100644 index 00000000000..fc8129af6b0 --- /dev/null +++ b/design_notes/named_array_design_doc.md @@ -0,0 +1,338 @@ +# named-array Design Document + +## Abstract + +Despite the wealth of scientific libraries in the Python ecosystem, there is a gap for a lightweight, efficient array structure with named dimensions that can provide convenient broadcasting and indexing. + +Existing solutions like Xarray's Variable, [Pytorch Named Tensor](https://github.com/pytorch/pytorch/issues/60832), [Levanter](https://crfm.stanford.edu/2023/06/16/levanter-1_0-release.html), and [Larray](https://larray.readthedocs.io/en/stable/tutorial/getting_started.html) have their own strengths and weaknesses. Xarray's Variable is an efficient data structure, but it depends on the relatively heavy-weight library Pandas, which limits its use in other projects. Pytorch Named Tensor offers named dimensions, but it lacks support for many operations, making it less user-friendly. Levanter is a powerful tool with a named tensor module (Haliax) that makes deep learning code easier to read, understand, and write, but it is not as lightweight or generic as desired. Larry offers labeled N-dimensional arrays, but it may not provide the level of seamless interoperability with other scientific Python libraries that some users need. + +named-array aims to solve these issues by exposing the core functionality of Xarray's Variable class as a standalone package. + +## Motivation and Scope + +The Python ecosystem boasts a wealth of scientific libraries that enable efficient computations on large, multi-dimensional arrays. Libraries like PyTorch, Xarray, and NumPy have revolutionized scientific computing by offering robust data structures for array manipulations. Despite this wealth of tools, a gap exists in the Python landscape for a lightweight, efficient array structure with named dimensions that can provide convenient broadcasting and indexing. + +Xarray internally maintains a data structure that meets this need, referred to as [`xarray.Variable`](https://docs.xarray.dev/en/latest/generated/xarray.Variable.html) . However, Xarray's dependency on Pandas, a relatively heavy-weight library, restricts other projects from leveraging this efficient data structure (, , ). + +We propose the creation of a standalone Python package, "named-array". This package is envisioned to be a version of the `xarray.Variable` data structure, cleanly separated from the heavier dependencies of Xarray. named-array will provide a lightweight, user-friendly array-like data structure with named dimensions, facilitating convenient indexing and broadcasting. The package will use existing scientific Python community standards such as established array protocols and the new [Python array API standard](https://data-apis.org/array-api/latest), allowing users to wrap multiple duck-array objects, including, but not limited to, NumPy, Dask, Sparse, Pint, CuPy, and Pytorch. + +The development of named-array is projected to meet a key community need and expected to broaden Xarray's user base. By making the core `xarray.Variable` more accessible, we anticipate an increase in contributors and a reduction in the developer burden on current Xarray maintainers. + +### Goals + +1. **Simple and minimal**: named-array will expose Xarray's [Variable class](https://docs.xarray.dev/en/stable/internals/variable-objects.html) as a standalone object (`NamedArray`) with named axes (dimensions) and arbitrary metadata (attributes) but without coordinate labels. This will make it a lightweight, efficient array data structure that allows convenient broadcasting and indexing. + +2. **Interoperability**: named-array will follow established scientific Python community standards and in doing so, will allow it to wrap multiple duck-array objects, including but not limited to, NumPy, Dask, Sparse, Pint, CuPy, and Pytorch. + +3. **Community Engagement**: By making the core `xarray.Variable` more accessible, we open the door to increased adoption of this fundamental data structure. As such, we hope to see an increase in contributors and reduction in the developer burden on current Xarray maintainers. + +### Non-Goals + +1. **Extensive Data Analysis**: named-array will not provide extensive data analysis features like statistical functions, data cleaning, or visualization. Its primary focus is on providing a data structure that allows users to use dimension names for descriptive array manipulations. + +2. **Support for I/O**: named-array will not bundle file reading functions. Instead users will be expected to handle I/O and then wrap those arrays with the new named-array data structure. + +## Backward Compatibility + +The creation of named-array is intended to separate the `xarray.Variable` from Xarray into a standalone package. This allows it to be used independently, without the need for Xarray's dependencies, like Pandas. This separation has implications for backward compatibility. + +Since the new named-array is envisioned to contain the core features of Xarray's variable, existing code using Variable from Xarray should be able to switch to named-array with minimal changes. However, there are several potential issues related to backward compatibility: + +* **API Changes**: as the Variable is decoupled from Xarray and moved into named-array, some changes to the API may be necessary. These changes might include differences in function signature, etc. These changes could break existing code that relies on the current API and associated utility functions (e.g. `as_variable()`). The `xarray.Variable` object will subclass `NamedArray`, and provide the existing interface for compatibility. + +## Detailed Description + +named-array aims to provide a lightweight, efficient array structure with named dimensions, or axes, that enables convenient broadcasting and indexing. The primary component of named-array is a standalone version of the xarray.Variable data structure, which was previously a part of the Xarray library. +The xarray.Variable data structure in named-array will maintain the core features of its counterpart in Xarray, including: + +* **Named Axes (Dimensions)**: Each axis of the array can be given a name, providing a descriptive and intuitive way to reference the dimensions of the array. + +* **Arbitrary Metadata (Attributes)**: named-array will support the attachment of arbitrary metadata to arrays as a dict, providing a mechanism to store additional information about the data that the array represents. + +* **Convenient Broadcasting and Indexing**: With named dimensions, broadcasting and indexing operations become more intuitive and less error-prone. + +The named-array package is designed to be interoperable with other scientific Python libraries. It will follow established scientific Python community standards and use standard array protocols, as well as the new data-apis standard. This allows named-array to wrap multiple duck-array objects, including, but not limited to, NumPy, Dask, Sparse, Pint, CuPy, and Pytorch. + +## Implementation + +* **Decoupling**: making `variable.py` agnostic to Xarray internals by decoupling it from the rest of the library. This will make the code more modular and easier to maintain. However, this will also make the code more complex, as we will need to define a clear interface for how the functionality in `variable.py` interacts with the rest of the library, particularly the ExplicitlyIndexed subclasses used to enable lazy indexing of data on disk. +* **Move Xarray's internal lazy indexing classes to follow standard Array Protocols**: moving the lazy indexing classes like `ExplicitlyIndexed` to use standard array protocols will be a key step in decoupling. It will also potentially improve interoperability with other libraries that use these protocols, and prepare these classes [for eventual movement out](https://github.com/pydata/xarray/issues/5081) of the Xarray code base. However, this will also require significant changes to the code, and we will need to ensure that all existing functionality is preserved. + * Use [https://data-apis.org/array-api-compat/](https://data-apis.org/array-api-compat/) to handle compatibility issues? +* **Leave lazy indexing classes in Xarray for now** +* **Preserve support for Dask collection protocols**: named-array will preserve existing support for the dask collections protocol namely the __dask_***__ methods +* **Preserve support for ChunkManagerEntrypoint?** Opening variables backed by dask vs cubed arrays currently is [handled within Variable.chunk](https://github.com/pydata/xarray/blob/92c8b33eb464b09d6f8277265b16cae039ab57ee/xarray/core/variable.py#L1272C15-L1272C15). If we are preserving dask support it would be nice to preserve general chunked array type support, but this currently requires an entrypoint. + +### Plan + +1. Create a new baseclass for `xarray.Variable` to its own module e.g. `xarray.core.base_variable` +2. Remove all imports of internal Xarray classes and utils from `base_variable.py`. `base_variable.Variable` should not depend on anything in xarray.core + * Will require moving the lazy indexing classes (subclasses of ExplicitlyIndexed) to be standards compliant containers.` + * an array-api compliant container that provides **array_namespace**` + * Support `.oindex` and `.vindex` for explicit indexing + * Potentially implement this by introducing a new compliant wrapper object? + * Delete the `NON_NUMPY_SUPPORTED_ARRAY_TYPES` variable which special-cases ExplicitlyIndexed and `pd.Index.` + * `ExplicitlyIndexed` class and subclasses should provide `.oindex` and `.vindex` for indexing by `Variable.__getitem__.`: `oindex` and `vindex` were proposed in [NEP21](https://numpy.org/neps/nep-0021-advanced-indexing.html), but have not been implemented yet + * Delete the ExplicitIndexer objects (`BasicIndexer`, `VectorizedIndexer`, `OuterIndexer`) + * Remove explicit support for `pd.Index`. When provided with a `pd.Index` object, Variable will coerce to an array using `np.array(pd.Index)`. For Xarray's purposes, Xarray can use `as_variable` to explicitly wrap these in PandasIndexingAdapter and pass them to `Variable.__init__`. +3. Define a minimal variable interface that the rest of Xarray can use: + 1. `dims`: tuple of dimension names + 2. `data`: numpy/dask/duck arrays` + 3. `attrs``: dictionary of attributes + +4. Implement basic functions & methods for manipulating these objects. These methods will be a cleaned-up subset (for now) of functionality on xarray.Variable, with adaptations inspired by the [Python array API](https://data-apis.org/array-api/2022.12/API_specification/index.html). +5. Existing Variable structures + 1. Keep Variable object which subclasses the new structure that adds the `.encoding` attribute and potentially other methods needed for easy refactoring. + 2. IndexVariable will remain in xarray.core.variable and subclass the new named-array data structure pending future deletion. +6. Docstrings and user-facing APIs will need to be updated to reflect the changed methods on Variable objects. + +Further implementation details are in Appendix: [Implementation Details](#appendix-implementation-details). + +## Project Timeline and Milestones + +We have identified the following milestones for the completion of this project: + +1. **Write and publish a design document**: this document will explain the purpose of named-array, the intended audience, and the features it will provide. It will also describe the architecture of named-array and how it will be implemented. This will ensure early community awareness and engagement in the project to promote subsequent uptake. +2. **Refactor `variable.py` to `base_variable.py`** and remove internal Xarray imports. +3. **Break out the package and create continuous integration infrastructure**: this will entail breaking out the named-array project into a Python package and creating a continuous integration (CI) system. This will help to modularize the code and make it easier to manage. Building a CI system will help ensure that codebase changes do not break existing functionality. +4. Incrementally add new functions & methods to the new package, ported from xarray. This will start to make named-array useful on its own. +5. Refactor the existing Xarray codebase to rely on the newly created package (named-array): This will help to demonstrate the usefulness of the new package, and also provide an example for others who may want to use it. +6. Expand tests, add documentation, and write a blog post: expanding the test suite will help to ensure that the code is reliable and that changes do not introduce bugs. Adding documentation will make it easier for others to understand and use the project. +7. Finally, we will write a series of blog posts on [xarray.dev](https://xarray.dev/) to promote the project and attract more contributors. + * Toward the end of the process, write a few blog posts that demonstrate the use of the newly available data structure + * pick the same example applications used by other implementations/applications (e.g. Pytorch, sklearn, and Levanter) to show how it can work. + +## Related Work + +1. [GitHub - deepmind/graphcast](https://github.com/deepmind/graphcast) +2. [Getting Started — LArray 0.34 documentation](https://larray.readthedocs.io/en/stable/tutorial/getting_started.html) +3. [Levanter — Legible, Scalable, Reproducible Foundation Models with JAX](https://crfm.stanford.edu/2023/06/16/levanter-1_0-release.html) +4. [google/xarray-tensorstore](https://github.com/google/xarray-tensorstore) +5. [State of Torch Named Tensors · Issue #60832 · pytorch/pytorch · GitHub](https://github.com/pytorch/pytorch/issues/60832) + * Incomplete support: Many primitive operations result in errors, making it difficult to use NamedTensors in Practice. Users often have to resort to removing the names from tensors to avoid these errors. + * Lack of active development: the development of the NamedTensor feature in PyTorch is not currently active due a lack of bandwidth for resolving ambiguities in the design. + * Usability issues: the current form of NamedTensor is not user-friendly and sometimes raises errors, making it difficult for users to incorporate NamedTensors into their workflows. +6. [Scikit-learn Enhancement Proposals (SLEPs) 8, 12, 14](https://github.com/scikit-learn/enhancement_proposals/pull/18) + * Some of the key points and limitations discussed in these proposals are: + * Inconsistency in feature name handling: Scikit-learn currently lacks a consistent and comprehensive way to handle and propagate feature names through its pipelines and estimators ([SLEP 8](https://github.com/scikit-learn/enhancement_proposals/pull/18),[SLEP 12](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep012/proposal.html)). + * Memory intensive for large feature sets: storing and propagating feature names can be memory intensive, particularly in cases where the entire "dictionary" becomes the features, such as in NLP use cases ([SLEP 8](https://github.com/scikit-learn/enhancement_proposals/pull/18),[GitHub issue #35](https://github.com/scikit-learn/enhancement_proposals/issues/35)) + * Sparse matrices: sparse data structures present a challenge for feature name propagation. For instance, the sparse data structure functionality in Pandas 1.0 only supports converting directly to the coordinate format (COO), which can be an issue with transformers such as the OneHotEncoder.transform that has been optimized to construct a CSR matrix ([SLEP 14](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep014/proposal.html)) + * New Data structures: the introduction of new data structures, such as "InputArray" or "DataArray" could lead to more burden for third-party estimator maintainers and increase the learning curve for users. Xarray's "DataArray" is mentioned as a potential alternative, but the proposal mentions that the conversion from a Pandas dataframe to a Dataset is not lossless ([SLEP 12](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep012/proposal.html),[SLEP 14](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep014/proposal.html),[GitHub issue #35](https://github.com/scikit-learn/enhancement_proposals/issues/35)). + * Dependency on other libraries: solutions that involve using Xarray and/or Pandas to handle feature names come with the challenge of managing dependencies. While a soft dependency approach is suggested, this means users would be able to have/enable the feature only if they have the dependency installed. Xarra-lite's integration with other scientific Python libraries could potentially help with this issue ([GitHub issue #35](https://github.com/scikit-learn/enhancement_proposals/issues/35)). + +## References and Previous Discussion + +* [[Proposal] Expose Variable without Pandas dependency · Issue #3981 · pydata/xarray · GitHub](https://github.com/pydata/xarray/issues/3981) +* [https://github.com/pydata/xarray/issues/3981#issuecomment-985051449](https://github.com/pydata/xarray/issues/3981#issuecomment-985051449) +* [Lazy indexing arrays as a stand-alone package · Issue #5081 · pydata/xarray · GitHub](https://github.com/pydata/xarray/issues/5081) + +### Appendix: Engagement with the Community + +We plan to publicize this document on : + +* [x] `Xarray dev call` +* [ ] `Scientific Python discourse` +* [ ] `Xarray Github` +* [ ] `Twitter` +* [ ] `Respond to NamedTensor and Scikit-Learn issues?` +* [ ] `Pangeo Discourse` +* [ ] `Numpy, SciPy email lists?` +* [ ] `Xarray blog` + +Additionally, We plan on writing a series of blog posts to effectively showcase the implementation and potential of the newly available functionality. To illustrate this, we will use the same example applications as other established libraries (such as Pytorch, sklearn), providing practical demonstrations of how these new data structures can be leveraged. + +### Appendix: API Surface + +Questions: + +1. Document Xarray indexing rules +2. Document use of .oindex and .vindex protocols +3. Do we use `.mean` and `.nanmean` or `.mean(skipna=...)`? + * Default behavior in named-array should mirror NumPy / the array API standard, not pandas. + * nanmean is not (yet) in the [array API](https://github.com/pydata/xarray/pull/7424#issuecomment-1373979208). There are a handful of other key functions (e.g., median) that are are also missing. I think that should be OK, as long as what we support is a strict superset of the array API. +4. What methods need to be exposed on Variable? + * `Variable.concat` classmethod: create two functions, one as the equivalent of `np.stack` and other for `np.concat` + * `.rolling_window` and `.coarsen_reshape` ? + * `named-array.apply_ufunc`: used in astype, clip, quantile, isnull, notnull` + +#### methods to be preserved from xarray.Variable + +```python +# Sorting + Variable.argsort + Variable.searchsorted + +# NaN handling + Variable.fillna + Variable.isnull + Variable.notnull + +# Lazy data handling + Variable.chunk # Could instead have accessor interface and recommend users use `Variable.dask.chunk` and `Variable.cubed.chunk`? + Variable.to_numpy() + Variable.as_numpy() + +# Xarray-specific + Variable.get_axis_num + Variable.isel + Variable.to_dict + +# Reductions + Variable.reduce + Variable.all + Variable.any + Variable.argmax + Variable.argmin + Variable.count + Variable.max + Variable.mean + Variable.median + Variable.min + Variable.prod + Variable.quantile + Variable.std + Variable.sum + Variable.var + +# Accumulate + Variable.cumprod + Variable.cumsum + +# numpy-like Methods + Variable.astype + Variable.copy + Variable.clip + Variable.round + Variable.item + Variable.where + +# Reordering/Reshaping + Variable.squeeze + Variable.pad + Variable.roll + Variable.shift + +``` + +#### methods to be renamed from xarray.Variable + +```python +# Xarray-specific + Variable.concat # create two functions, one as the equivalent of `np.stack` and other for `np.concat` + + # Given how niche these are, these would be better as functions than methods. + # We could also keep these in Xarray, at least for now. If we don't think people will use functionality outside of Xarray it probably is not worth the trouble of porting it (including documentation, etc). + Variable.coarsen # This should probably be called something like coarsen_reduce. + Variable.coarsen_reshape + Variable.rolling_window + + Variable.set_dims # split this into broadcas_to and expand_dims + + +# Reordering/Reshaping + Variable.stack # To avoid confusion with np.stack, let's call this stack_dims. + Variable.transpose # Could consider calling this permute_dims, like the [array API standard](https://data-apis.org/array-api/2022.12/API_specification/manipulation_functions.html#objects-in-api) + Variable.unstack # Likewise, maybe call this unstack_dims? +``` + +#### methods to be removed from xarray.Variable + +```python +# Testing + Variable.broadcast_equals + Variable.equals + Variable.identical + Variable.no_conflicts + +# Lazy data handling + Variable.compute # We can probably omit this method for now, too, given that dask.compute() uses a protocol. The other concern is that different array libraries have different notions of "compute" and this one is rather Dask specific, including conversion from Dask to NumPy arrays. For example, in JAX every operation executes eagerly, but in a non-blocking fashion, and you need to call jax.block_until_ready() to ensure computation is finished. + Variable.load # Could remove? compute vs load is a common source of confusion. + +# Xarray-specific + Variable.to_index + Variable.to_index_variable + Variable.to_variable + Variable.to_base_variable + Variable.to_coord + + Variable.rank # Uses bottleneck. Delete? Could use https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.rankdata.html instead + + +# numpy-like Methods + Variable.conjugate # .conj is enough + Variable.__array_wrap__ # This is a very old NumPy protocol for duck arrays. We don't need it now that we have `__array_ufunc__` and `__array_function__` + +# Encoding + Variable.reset_encoding + +``` + +#### Attributes to be preserved from xarray.Variable + +```python +# Properties + Variable.attrs + Variable.chunks + Variable.data + Variable.dims + Variable.dtype + + Variable.nbytes + Variable.ndim + Variable.shape + Variable.size + Variable.sizes + + Variable.T + Variable.real + Variable.imag + Variable.conj +``` + +#### Attributes to be renamed from xarray.Variable + +```python +``` + +#### Attributes to be removed from xarray.Variable + +```python + + Variable.values # Probably also remove -- this is a legacy from before Xarray supported dask arrays. ".data" is enough. + +# Encoding + Variable.encoding + +``` + +### Appendix: Implementation Details + +* Merge in VariableArithmetic's parent classes: AbstractArray, NdimSizeLenMixin with the new data structure.. + +```python +class VariableArithmetic( + ImplementsArrayReduce, + IncludeReduceMethods, + IncludeCumMethods, + IncludeNumpySameMethods, + SupportsArithmetic, + VariableOpsMixin, +): + __slots__ = () + # prioritize our operations over those of numpy.ndarray (priority=0) + __array_priority__ = 50 + +``` + +* Move over `_typed_ops.VariableOpsMixin` +* Build a list of utility functions used elsewhere : Which of these should become public API? + * `broadcast_variables`: `dataset.py`, `dataarray.py`,`missing.py` + * This could be just called "broadcast" in named-array. + * `Variable._getitem_with_mask` : `alignment.py` + * keep this method/function as private and inside Xarray. +* The Variable constructor will need to be rewritten to no longer accept tuples, encodings, etc. These details should be handled at the Xarray data structure level. +* What happens to `duck_array_ops?` +* What about Variable.chunk and "chunk managers"? + * Could this functionality be left in Xarray proper for now? Alternative array types like JAX also have some notion of "chunks" for parallel arrays, but the details differ in a number of ways from the Dask/Cubed. + * Perhaps variable.chunk/load methods should become functions defined in xarray that convert Variable objects. This is easy so long as xarray can reach in and replace .data + +* Utility functions like `as_variable` should be moved out of `base_variable.py` so they can convert BaseVariable objects to/from DataArray or Dataset containing explicitly indexed arrays. diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 1a2b1d11747..c96b0aa5c3b 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -9,17 +9,42 @@ .. autosummary:: :toctree: generated/ + Coordinates.from_pandas_multiindex + Coordinates.get + Coordinates.items + Coordinates.keys + Coordinates.values + Coordinates.dims + Coordinates.dtypes + Coordinates.variables + Coordinates.xindexes + Coordinates.indexes + Coordinates.to_dataset + Coordinates.to_index + Coordinates.update + Coordinates.assign + Coordinates.merge + Coordinates.copy + Coordinates.equals + Coordinates.identical + core.coordinates.DatasetCoordinates.get core.coordinates.DatasetCoordinates.items core.coordinates.DatasetCoordinates.keys - core.coordinates.DatasetCoordinates.merge - core.coordinates.DatasetCoordinates.to_dataset - core.coordinates.DatasetCoordinates.to_index - core.coordinates.DatasetCoordinates.update core.coordinates.DatasetCoordinates.values core.coordinates.DatasetCoordinates.dims - core.coordinates.DatasetCoordinates.indexes + core.coordinates.DatasetCoordinates.dtypes core.coordinates.DatasetCoordinates.variables + core.coordinates.DatasetCoordinates.xindexes + core.coordinates.DatasetCoordinates.indexes + core.coordinates.DatasetCoordinates.to_dataset + core.coordinates.DatasetCoordinates.to_index + core.coordinates.DatasetCoordinates.update + core.coordinates.DatasetCoordinates.assign + core.coordinates.DatasetCoordinates.merge + core.coordinates.DatasetCoordinates.copy + core.coordinates.DatasetCoordinates.equals + core.coordinates.DatasetCoordinates.identical core.rolling.DatasetCoarsen.boundary core.rolling.DatasetCoarsen.coord_func @@ -47,14 +72,20 @@ core.coordinates.DataArrayCoordinates.get core.coordinates.DataArrayCoordinates.items core.coordinates.DataArrayCoordinates.keys - core.coordinates.DataArrayCoordinates.merge - core.coordinates.DataArrayCoordinates.to_dataset - core.coordinates.DataArrayCoordinates.to_index - core.coordinates.DataArrayCoordinates.update core.coordinates.DataArrayCoordinates.values core.coordinates.DataArrayCoordinates.dims - core.coordinates.DataArrayCoordinates.indexes + core.coordinates.DataArrayCoordinates.dtypes core.coordinates.DataArrayCoordinates.variables + core.coordinates.DataArrayCoordinates.xindexes + core.coordinates.DataArrayCoordinates.indexes + core.coordinates.DataArrayCoordinates.to_dataset + core.coordinates.DataArrayCoordinates.to_index + core.coordinates.DataArrayCoordinates.update + core.coordinates.DataArrayCoordinates.assign + core.coordinates.DataArrayCoordinates.merge + core.coordinates.DataArrayCoordinates.copy + core.coordinates.DataArrayCoordinates.equals + core.coordinates.DataArrayCoordinates.identical core.rolling.DataArrayCoarsen.boundary core.rolling.DataArrayCoarsen.coord_func @@ -234,7 +265,7 @@ Variable.dims Variable.dtype Variable.encoding - Variable.reset_encoding + Variable.drop_encoding Variable.imag Variable.nbytes Variable.ndim @@ -320,6 +351,36 @@ IndexVariable.sizes IndexVariable.values + + namedarray.core.NamedArray.all + namedarray.core.NamedArray.any + namedarray.core.NamedArray.attrs + namedarray.core.NamedArray.chunks + namedarray.core.NamedArray.chunksizes + namedarray.core.NamedArray.copy + namedarray.core.NamedArray.count + namedarray.core.NamedArray.cumprod + namedarray.core.NamedArray.cumsum + namedarray.core.NamedArray.data + namedarray.core.NamedArray.dims + namedarray.core.NamedArray.dtype + namedarray.core.NamedArray.get_axis_num + namedarray.core.NamedArray.max + namedarray.core.NamedArray.mean + namedarray.core.NamedArray.median + namedarray.core.NamedArray.min + namedarray.core.NamedArray.nbytes + namedarray.core.NamedArray.ndim + namedarray.core.NamedArray.prod + namedarray.core.NamedArray.reduce + namedarray.core.NamedArray.shape + namedarray.core.NamedArray.size + namedarray.core.NamedArray.sizes + namedarray.core.NamedArray.std + namedarray.core.NamedArray.sum + namedarray.core.NamedArray.var + + plot.plot plot.line plot.step diff --git a/doc/api.rst b/doc/api.rst index 1cfd44dcff7..a930c9e5466 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -110,9 +110,9 @@ Dataset contents Dataset.drop_indexes Dataset.drop_duplicates Dataset.drop_dims + Dataset.drop_encoding Dataset.set_coords Dataset.reset_coords - Dataset.reset_encoding Dataset.convert_calendar Dataset.interp_calendar Dataset.get_index @@ -303,8 +303,8 @@ DataArray contents DataArray.drop_vars DataArray.drop_indexes DataArray.drop_duplicates + DataArray.drop_encoding DataArray.reset_coords - DataArray.reset_encoding DataArray.copy DataArray.convert_calendar DataArray.interp_calendar @@ -1107,6 +1107,7 @@ Advanced API .. autosummary:: :toctree: generated/ + Coordinates Dataset.variables DataArray.variable Variable diff --git a/doc/conf.py b/doc/conf.py index 8f8fdebd169..8b425e15bae 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -84,6 +84,7 @@ "sphinx_copybutton", "sphinxext.rediraffe", "sphinx_design", + "sphinx_inline_tabs", ] @@ -230,6 +231,7 @@ # canonical_url="", repository_url="https://github.com/pydata/xarray", repository_branch="main", + navigation_with_keys=False, # pydata/pydata-sphinx-theme#1492 path_to_docs="doc", use_edit_page_button=True, use_repository_button=True, @@ -325,6 +327,8 @@ "sparse": ("https://sparse.pydata.org/en/latest/", None), "hypothesis": ("https://hypothesis.readthedocs.io/en/latest/", None), "cubed": ("https://tom-e-white.com/cubed/", None), + "datatree": ("https://xarray-datatree.readthedocs.io/en/latest/", None), + # "opt_einsum": ("https://dgasmith.github.io/opt_einsum/", None), } diff --git a/doc/contributing.rst b/doc/contributing.rst index 3cc43314d9a..3cdd7dd9933 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -518,7 +518,7 @@ See the `Installation `_ Including figures and files ---------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~ Image files can be directly included in pages with the ``image::`` directive. diff --git a/doc/ecosystem.rst b/doc/ecosystem.rst index e6e970c6239..fc5ae963a1d 100644 --- a/doc/ecosystem.rst +++ b/doc/ecosystem.rst @@ -41,6 +41,7 @@ Geosciences harmonic wind analysis in Python. - `wradlib `_: An Open Source Library for Weather Radar Data Processing. - `wrf-python `_: A collection of diagnostic and interpolation routines for use with output of the Weather Research and Forecasting (WRF-ARW) Model. +- `xarray-regrid `_: xarray extension for regridding rectilinear data. - `xarray-simlab `_: xarray extension for computer model simulations. - `xarray-spatial `_: Numba-accelerated raster-based spatial processing tools (NDVI, curvature, zonal-statistics, proximity, hillshading, viewshed, etc.) - `xarray-topo `_: xarray extension for topographic analysis and modelling. diff --git a/doc/examples/apply_ufunc_vectorize_1d.ipynb b/doc/examples/apply_ufunc_vectorize_1d.ipynb index 68d011d0725..c2ab7271873 100644 --- a/doc/examples/apply_ufunc_vectorize_1d.ipynb +++ b/doc/examples/apply_ufunc_vectorize_1d.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This example will illustrate how to conveniently apply an unvectorized function `func` to xarray objects using `apply_ufunc`. `func` expects 1D numpy arrays and returns a 1D numpy array. Our goal is to coveniently apply this function along a dimension of xarray objects that may or may not wrap dask arrays with a signature.\n", + "This example will illustrate how to conveniently apply an unvectorized function `func` to xarray objects using `apply_ufunc`. `func` expects 1D numpy arrays and returns a 1D numpy array. Our goal is to conveniently apply this function along a dimension of xarray objects that may or may not wrap dask arrays with a signature.\n", "\n", "We will illustrate this using `np.interp`: \n", "\n", diff --git a/doc/getting-started-guide/installing.rst b/doc/getting-started-guide/installing.rst index 9fee849a341..e8c498b6664 100644 --- a/doc/getting-started-guide/installing.rst +++ b/doc/getting-started-guide/installing.rst @@ -7,7 +7,7 @@ Required dependencies --------------------- - Python (3.9 or later) -- `numpy `__ (1.21 or later) +- `numpy `__ (1.22 or later) - `packaging `__ (21.3 or later) - `pandas `__ (1.4 or later) @@ -135,13 +135,13 @@ We also maintain other dependency sets for different subsets of functionality:: The above commands should install most of the `optional dependencies`_. However, some packages which are either not listed on PyPI or require extra installation steps are excluded. To know which dependencies would be -installed, take a look at the ``[options.extras_require]`` section in -``setup.cfg``: +installed, take a look at the ``[project.optional-dependencies]`` section in +``pyproject.toml``: -.. literalinclude:: ../../setup.cfg - :language: ini - :start-at: [options.extras_require] - :end-before: [options.package_data] +.. literalinclude:: ../../pyproject.toml + :language: toml + :start-at: [project.optional-dependencies] + :end-before: [build-system] Development versions -------------------- diff --git a/doc/internals/duck-arrays-integration.rst b/doc/internals/duck-arrays-integration.rst index 1f1f57974df..a674acb04fe 100644 --- a/doc/internals/duck-arrays-integration.rst +++ b/doc/internals/duck-arrays-integration.rst @@ -35,7 +35,7 @@ Python Array API standard support ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ As an integration library xarray benefits greatly from the standardization of duck-array libraries' APIs, and so is a -big supporter of the `Python Array API Standard `_. . +big supporter of the `Python Array API Standard `_. We aim to support any array libraries that follow the Array API standard out-of-the-box. However, xarray does occasionally call some numpy functions which are not (yet) part of the standard (e.g. :py:meth:`xarray.DataArray.pad` calls :py:func:`numpy.pad`). diff --git a/doc/internals/extending-xarray.rst b/doc/internals/extending-xarray.rst index a180b85044f..cb1b23e78eb 100644 --- a/doc/internals/extending-xarray.rst +++ b/doc/internals/extending-xarray.rst @@ -14,6 +14,11 @@ Xarray is designed as a general purpose library and hence tries to avoid including overly domain specific functionality. But inevitably, the need for more domain specific logic arises. +.. _internals.accessors.composition: + +Composition over Inheritance +---------------------------- + One potential solution to this problem is to subclass Dataset and/or DataArray to add domain specific functionality. However, inheritance is not very robust. It's easy to inadvertently use internal APIs when subclassing, which means that your @@ -23,11 +28,17 @@ only return native xarray objects. The standard advice is to use :issue:`composition over inheritance <706>`, but reimplementing an API as large as xarray's on your own objects can be an onerous task, even if most methods are only forwarding to xarray implementations. +(For an example of a project which took this approach of subclassing see `UXarray `_). If you simply want the ability to call a function with the syntax of a method call, then the builtin :py:meth:`~xarray.DataArray.pipe` method (copied from pandas) may suffice. +.. _internals.accessors.writing accessors: + +Writing Custom Accessors +------------------------ + To resolve this issue for more complex cases, xarray has the :py:func:`~xarray.register_dataset_accessor` and :py:func:`~xarray.register_dataarray_accessor` decorators for adding custom diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst index a106232958e..ca42d60abaf 100644 --- a/doc/internals/how-to-add-new-backend.rst +++ b/doc/internals/how-to-add-new-backend.rst @@ -9,7 +9,8 @@ to integrate any code in Xarray; all you need to do is: - Create a class that inherits from Xarray :py:class:`~xarray.backends.BackendEntrypoint` and implements the method ``open_dataset`` see :ref:`RST backend_entrypoint` -- Declare this class as an external plugin in your ``setup.py``, see :ref:`RST backend_registration` +- Declare this class as an external plugin in your project configuration, see :ref:`RST + backend_registration` If you also want to support lazy loading and dask see :ref:`RST lazy_loading`. @@ -267,42 +268,57 @@ interface only the boolean keywords related to the supported decoders. How to register a backend +++++++++++++++++++++++++ -Define a new entrypoint in your ``setup.py`` (or ``setup.cfg``) with: +Define a new entrypoint in your ``pyproject.toml`` (or ``setup.cfg/setup.py`` for older +configurations), with: - group: ``xarray.backends`` - name: the name to be passed to :py:meth:`~xarray.open_dataset` as ``engine`` - object reference: the reference of the class that you have implemented. -You can declare the entrypoint in ``setup.py`` using the following syntax: +You can declare the entrypoint in your project configuration like so: -.. code-block:: +.. tab:: pyproject.toml - setuptools.setup( - entry_points={ - "xarray.backends": ["my_engine=my_package.my_module:MyBackendEntryClass"], - }, - ) + .. code:: toml + + [project.entry-points."xarray-backends"] + my_engine = "my_package.my_module:MyBackendEntrypoint" + +.. tab:: pyproject.toml [Poetry] + + .. code-block:: toml + + [tool.poetry.plugins."xarray.backends"] + my_engine = "my_package.my_module:MyBackendEntrypoint" -in ``setup.cfg``: +.. tab:: setup.cfg -.. code-block:: cfg + .. code-block:: cfg - [options.entry_points] - xarray.backends = - my_engine = my_package.my_module:MyBackendEntryClass + [options.entry_points] + xarray.backends = + my_engine = my_package.my_module:MyBackendEntrypoint +.. tab:: setup.py -See https://packaging.python.org/specifications/entry-points/#data-model -for more information + .. code-block:: -If you are using `Poetry `_ for your build system, you can accomplish the same thing using "plugins". In this case you would need to add the following to your ``pyproject.toml`` file: + setuptools.setup( + entry_points={ + "xarray.backends": [ + "my_engine=my_package.my_module:MyBackendEntrypoint" + ], + }, + ) -.. code-block:: toml - [tool.poetry.plugins."xarray.backends"] - "my_engine" = "my_package.my_module:MyBackendEntryClass" +See the `Python Packaging User Guide +`_ for more +information on entrypoints and details of the syntax. -See https://python-poetry.org/docs/pyproject/#plugins for more information on Poetry plugins. +If you're using Poetry, note that table name in ``pyproject.toml`` is slightly different. +See `the Poetry docs `_ for more +information on plugins. .. _RST lazy_loading: diff --git a/doc/internals/how-to-create-custom-index.rst b/doc/internals/how-to-create-custom-index.rst index 93805229db1..90b3412c2cb 100644 --- a/doc/internals/how-to-create-custom-index.rst +++ b/doc/internals/how-to-create-custom-index.rst @@ -1,5 +1,7 @@ .. currentmodule:: xarray +.. _internals.custom indexes: + How to create a custom index ============================ diff --git a/doc/internals/index.rst b/doc/internals/index.rst index 7e13f0cfe95..b2a37900338 100644 --- a/doc/internals/index.rst +++ b/doc/internals/index.rst @@ -1,6 +1,6 @@ .. _internals: -xarray Internals +Xarray Internals ================ Xarray builds upon two of the foundational libraries of the scientific Python @@ -11,18 +11,18 @@ compiled code to :ref:`optional dependencies`. The pages in this section are intended for: * Contributors to xarray who wish to better understand some of the internals, -* Developers who wish to extend xarray with domain-specific logic, perhaps to support a new scientific community of users, -* Developers who wish to interface xarray with their existing tooling, e.g. by creating a plugin for reading a new file format, or wrapping a custom array type. - +* Developers from other fields who wish to extend xarray with domain-specific logic, perhaps to support a new scientific community of users, +* Developers of other packages who wish to interface xarray with their existing tools, e.g. by creating a backend for reading a new file format, or wrapping a custom array type. .. toctree:: :maxdepth: 2 :hidden: - variable-objects + internal-design + interoperability duck-arrays-integration chunked-arrays extending-xarray - zarr-encoding-spec how-to-add-new-backend how-to-create-custom-index + zarr-encoding-spec diff --git a/doc/internals/internal-design.rst b/doc/internals/internal-design.rst new file mode 100644 index 00000000000..55ab2d79dbe --- /dev/null +++ b/doc/internals/internal-design.rst @@ -0,0 +1,224 @@ +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + np.set_printoptions(threshold=20) + +.. _internal design: + +Internal Design +=============== + +This page gives an overview of the internal design of xarray. + +In totality, the Xarray project defines 4 key data structures. +In order of increasing complexity, they are: + +- :py:class:`xarray.Variable`, +- :py:class:`xarray.DataArray`, +- :py:class:`xarray.Dataset`, +- :py:class:`datatree.DataTree`. + +The user guide lists only :py:class:`xarray.DataArray` and :py:class:`xarray.Dataset`, +but :py:class:`~xarray.Variable` is the fundamental object internally, +and :py:class:`~datatree.DataTree` is a natural generalisation of :py:class:`xarray.Dataset`. + +.. note:: + + Our :ref:`roadmap` includes plans both to document :py:class:`~xarray.Variable` as fully public API, + and to merge the `xarray-datatree `_ package into xarray's main repository. + +Internally private :ref:`lazy indexing classes ` are used to avoid loading more data than necessary, +and flexible indexes classes (derived from :py:class:`~xarray.indexes.Index`) provide performant label-based lookups. + + +.. _internal design.data structures: + +Data Structures +--------------- + +The :ref:`data structures` page in the user guide explains the basics and concentrates on user-facing behavior, +whereas this section explains how xarray's data structure classes actually work internally. + + +.. _internal design.data structures.variable: + +Variable Objects +~~~~~~~~~~~~~~~~ + +The core internal data structure in xarray is the :py:class:`~xarray.Variable`, +which is used as the basic building block behind xarray's +:py:class:`~xarray.Dataset`, :py:class:`~xarray.DataArray` types. A +:py:class:`~xarray.Variable` consists of: + +- ``dims``: A tuple of dimension names. +- ``data``: The N-dimensional array (typically a NumPy or Dask array) storing + the Variable's data. It must have the same number of dimensions as the length + of ``dims``. +- ``attrs``: A dictionary of metadata associated with this array. By + convention, xarray's built-in operations never use this metadata. +- ``encoding``: Another dictionary used to store information about how + these variable's data is represented on disk. See :ref:`io.encoding` for more + details. + +:py:class:`~xarray.Variable` has an interface similar to NumPy arrays, but extended to make use +of named dimensions. For example, it uses ``dim`` in preference to an ``axis`` +argument for methods like ``mean``, and supports :ref:`compute.broadcasting`. + +However, unlike ``Dataset`` and ``DataArray``, the basic ``Variable`` does not +include coordinate labels along each axis. + +:py:class:`~xarray.Variable` is public API, but because of its incomplete support for labeled +data, it is mostly intended for advanced uses, such as in xarray itself, for +writing new backends, or when creating custom indexes. +You can access the variable objects that correspond to xarray objects via the (readonly) +:py:attr:`Dataset.variables ` and +:py:attr:`DataArray.variable ` attributes. + + +.. _internal design.dataarray: + +DataArray Objects +~~~~~~~~~~~~~~~~~ + +The simplest data structure used by most users is :py:class:`~xarray.DataArray`. +A :py:class:`~xarray.DataArray` is a composite object consisting of multiple +:py:class:`~xarray.core.variable.Variable` objects which store related data. + +A single :py:class:`~xarray.core.Variable` is referred to as the "data variable", and stored under the :py:attr:`~xarray.DataArray.variable`` attribute. +A :py:class:`~xarray.DataArray` inherits all of the properties of this data variable, i.e. ``dims``, ``data``, ``attrs`` and ``encoding``, +all of which are implemented by forwarding on to the underlying ``Variable`` object. + +In addition, a :py:class:`~xarray.DataArray` stores additional ``Variable`` objects stored in a dict under the private ``_coords`` attribute, +each of which is referred to as a "Coordinate Variable". These coordinate variable objects are only allowed to have ``dims`` that are a subset of the data variable's ``dims``, +and each dim has a specific length. This means that the full :py:attr:`~xarray.DataArray.size` of the dataarray can be represented by a dictionary mapping dimension names to integer sizes. +The underlying data variable has this exact same size, and the attached coordinate variables have sizes which are some subset of the size of the data variable. +Another way of saying this is that all coordinate variables must be "alignable" with the data variable. + +When a coordinate is accessed by the user (e.g. via the dict-like :py:class:`~xarray.DataArray.__getitem__` syntax), +then a new ``DataArray`` is constructed by finding all coordinate variables that have compatible dimensions and re-attaching them before the result is returned. +This is why most users never see the ``Variable`` class underlying each coordinate variable - it is always promoted to a ``DataArray`` before returning. + +Lookups are performed by special :py:class:`~xarray.indexes.Index` objects, which are stored in a dict under the private ``_indexes`` attribute. +Indexes must be associated with one or more coordinates, and essentially act by translating a query given in physical coordinate space +(typically via the :py:meth:`~xarray.DataArray.sel` method) into a set of integer indices in array index space that can be used to index the underlying n-dimensional array-like ``data``. +Indexing in array index space (typically performed via the :py:meth:`~xarray.DataArray.isel` method) does not require consulting an ``Index`` object. + +Finally a :py:class:`~xarray.DataArray` defines a :py:attr:`~xarray.DataArray.name` attribute, which refers to its data +variable but is stored on the wrapping ``DataArray`` class. +The ``name`` attribute is primarily used when one or more :py:class:`~xarray.DataArray` objects are promoted into a :py:class:`~xarray.Dataset` +(e.g. via :py:meth:`~xarray.DataArray.to_dataset`). +Note that the underlying :py:class:`~xarray.core.Variable` objects are all unnamed, so they can always be referred to uniquely via a +dict-like mapping. + +.. _internal design.dataset: + +Dataset Objects +~~~~~~~~~~~~~~~ + +The :py:class:`~xarray.Dataset` class is a generalization of the :py:class:`~xarray.DataArray` class that can hold multiple data variables. +Internally all data variables and coordinate variables are stored under a single ``variables`` dict, and coordinates are +specified by storing their names in a private ``_coord_names`` dict. + +The dataset's ``dims`` are the set of all dims present across any variable, but (similar to in dataarrays) coordinate +variables cannot have a dimension that is not present on any data variable. + +When a data variable or coordinate variable is accessed, a new ``DataArray`` is again constructed from all compatible +coordinates before returning. + +.. _internal design.subclassing: + +.. note:: + + The way that selecting a variable from a ``DataArray`` or ``Dataset`` actually involves internally wrapping the + ``Variable`` object back up into a ``DataArray``/``Dataset`` is the primary reason :ref:`we recommend against subclassing ` + Xarray objects. The main problem it creates is that we currently cannot easily guarantee that for example selecting + a coordinate variable from your ``SubclassedDataArray`` would return an instance of ``SubclassedDataArray`` instead + of just an :py:class:`xarray.DataArray`. See `GH issue `_ for more details. + +.. _internal design.lazy indexing: + +Lazy Indexing Classes +--------------------- + +Lazy Loading +~~~~~~~~~~~~ + +If we open a ``Variable`` object from disk using :py:func:`~xarray.open_dataset` we can see that the actual values of +the array wrapped by the data variable are not displayed. + +.. ipython:: python + + da = xr.tutorial.open_dataset("air_temperature")["air"] + var = da.variable + var + +We can see the size, and the dtype of the underlying array, but not the actual values. +This is because the values have not yet been loaded. + +If we look at the private attribute :py:meth:`~xarray.Variable._data` containing the underlying array object, we see +something interesting: + +.. ipython:: python + + var._data + +You're looking at one of xarray's internal `Lazy Indexing Classes`. These powerful classes are hidden from the user, +but provide important functionality. + +Calling the public :py:attr:`~xarray.Variable.data` property loads the underlying array into memory. + +.. ipython:: python + + var.data + +This array is now cached, which we can see by accessing the private attribute again: + +.. ipython:: python + + var._data + +Lazy Indexing +~~~~~~~~~~~~~ + +The purpose of these lazy indexing classes is to prevent more data being loaded into memory than is necessary for the +subsequent analysis, by deferring loading data until after indexing is performed. + +Let's open the data from disk again. + +.. ipython:: python + + da = xr.tutorial.open_dataset("air_temperature")["air"] + var = da.variable + +Now, notice how even after subsetting the data has does not get loaded: + +.. ipython:: python + + var.isel(time=0) + +The shape has changed, but the values are still not shown. + +Looking at the private attribute again shows how this indexing information was propagated via the hidden lazy indexing classes: + +.. ipython:: python + + var.isel(time=0)._data + +.. note:: + + Currently only certain indexing operations are lazy, not all array operations. For discussion of making all array + operations lazy see `GH issue #5081 `_. + + +Lazy Dask Arrays +~~~~~~~~~~~~~~~~ + +Note that xarray's implementation of Lazy Indexing classes is completely separate from how :py:class:`dask.array.Array` +objects evaluate lazily. Dask-backed xarray objects delay almost all operations until :py:meth:`~xarray.DataArray.compute` +is called (either explicitly or implicitly via :py:meth:`~xarray.DataArray.plot` for example). The exceptions to this +laziness are operations whose output shape is data-dependent, such as when calling :py:meth:`~xarray.DataArray.where`. diff --git a/doc/internals/interoperability.rst b/doc/internals/interoperability.rst new file mode 100644 index 00000000000..cbd96362e35 --- /dev/null +++ b/doc/internals/interoperability.rst @@ -0,0 +1,45 @@ +.. _interoperability: + +Interoperability of Xarray +========================== + +Xarray is designed to be extremely interoperable, in many orthogonal ways. +Making xarray as flexible as possible is the common theme of most of the goals on our :ref:`roadmap`. + +This interoperability comes via a set of flexible abstractions into which the user can plug in. The current full list is: + +- :ref:`Custom file backends ` via the :py:class:`~xarray.backends.BackendEntrypoint` system, +- Numpy-like :ref:`"duck" array wrapping `, which supports the `Python Array API Standard `_, +- :ref:`Chunked distributed array computation ` via the :py:class:`~xarray.core.parallelcompat.ChunkManagerEntrypoint` system, +- Custom :py:class:`~xarray.Index` objects for :ref:`flexible label-based lookups `, +- Extending xarray objects with domain-specific methods via :ref:`custom accessors `. + +.. warning:: + + One obvious way in which xarray could be more flexible is that whilst subclassing xarray objects is possible, we + currently don't support it in most transformations, instead recommending composition over inheritance. See the + :ref:`internal design page ` for the rationale and look at the corresponding `GH issue `_ + if you're interested in improving support for subclassing! + +.. note:: + + If you think there is another way in which xarray could become more generically flexible then please + tell us your ideas by `raising an issue to request the feature `_! + + +Whilst xarray was originally designed specifically to open ``netCDF4`` files as :py:class:`numpy.ndarray` objects labelled by :py:class:`pandas.Index` objects, +it is entirely possible today to: + +- lazily open an xarray object directly from a custom binary file format (e.g. using ``xarray.open_dataset(path, engine='my_custom_format')``, +- handle the data as any API-compliant numpy-like array type (e.g. sparse or GPU-backed), +- distribute out-of-core computation across that array type in parallel (e.g. via :ref:`dask`), +- track the physical units of the data through computations (e.g via `pint-xarray `_), +- query the data via custom index logic optimized for specific applications (e.g. an :py:class:`~xarray.Index` object backed by a KDTree structure), +- attach domain-specific logic via accessor methods (e.g. to understand geographic Coordinate Reference System metadata), +- organize hierarchical groups of xarray data in a :py:class:`~datatree.DataTree` (e.g. to treat heterogenous simulation and observational data together during analysis). + +All of these features can be provided simultaneously, using libaries compatible with the rest of the scientific python ecosystem. +In this situation xarray would be essentially a thin wrapper acting as pure-python framework, providing a common interface and +separation of concerns via various domain-agnostic abstractions. + +Most of the remaining pages in the documentation of xarray's internals describe these various types of interoperability in more detail. diff --git a/doc/internals/variable-objects.rst b/doc/internals/variable-objects.rst deleted file mode 100644 index 6ae3c2f7e6d..00000000000 --- a/doc/internals/variable-objects.rst +++ /dev/null @@ -1,31 +0,0 @@ -Variable objects -================ - -The core internal data structure in xarray is the :py:class:`~xarray.Variable`, -which is used as the basic building block behind xarray's -:py:class:`~xarray.Dataset` and :py:class:`~xarray.DataArray` types. A -``Variable`` consists of: - -- ``dims``: A tuple of dimension names. -- ``data``: The N-dimensional array (typically, a NumPy or Dask array) storing - the Variable's data. It must have the same number of dimensions as the length - of ``dims``. -- ``attrs``: An ordered dictionary of metadata associated with this array. By - convention, xarray's built-in operations never use this metadata. -- ``encoding``: Another ordered dictionary used to store information about how - these variable's data is represented on disk. See :ref:`io.encoding` for more - details. - -``Variable`` has an interface similar to NumPy arrays, but extended to make use -of named dimensions. For example, it uses ``dim`` in preference to an ``axis`` -argument for methods like ``mean``, and supports :ref:`compute.broadcasting`. - -However, unlike ``Dataset`` and ``DataArray``, the basic ``Variable`` does not -include coordinate labels along each axis. - -``Variable`` is public API, but because of its incomplete support for labeled -data, it is mostly intended for advanced uses, such as in xarray itself or for -writing new backends. You can access the variable objects that correspond to -xarray objects via the (readonly) :py:attr:`Dataset.variables -` and -:py:attr:`DataArray.variable ` attributes. diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst index dc495b9f285..9656a2ba973 100644 --- a/doc/user-guide/io.rst +++ b/doc/user-guide/io.rst @@ -115,10 +115,7 @@ you try to perform some sort of actual computation. For an example of how these lazy arrays work, see the OPeNDAP section below. There may be minor differences in the :py:class:`Dataset` object returned -when reading a NetCDF file with different engines. For example, -single-valued attributes are returned as scalars by the default -``engine=netcdf4``, but as arrays of size ``(1,)`` when reading with -``engine=h5netcdf``. +when reading a NetCDF file with different engines. It is important to note that when you modify values of a Dataset, even one linked to files on disk, only the in-memory copy you are manipulating in xarray @@ -263,12 +260,12 @@ Note that all operations that manipulate variables other than indexing will remove encoding information. In some cases it is useful to intentionally reset a dataset's original encoding values. -This can be done with either the :py:meth:`Dataset.reset_encoding` or -:py:meth:`DataArray.reset_encoding` methods. +This can be done with either the :py:meth:`Dataset.drop_encoding` or +:py:meth:`DataArray.drop_encoding` methods. .. ipython:: python - ds_no_encoding = ds_disk.reset_encoding() + ds_no_encoding = ds_disk.drop_encoding() ds_no_encoding.encoding .. _combining multiple files: @@ -559,6 +556,67 @@ and currently raises a warning unless ``invalid_netcdf=True`` is set: Note that this produces a file that is likely to be not readable by other netCDF libraries! +.. _io.hdf5: + +HDF5 +---- +`HDF5`_ is both a file format and a data model for storing information. HDF5 stores +data hierarchically, using groups to create a nested structure. HDF5 is a more +general version of the netCDF4 data model, so the nested structure is one of many +similarities between the two data formats. + +Reading HDF5 files in xarray requires the ``h5netcdf`` engine, which can be installed +with ``conda install h5netcdf``. Once installed we can use xarray to open HDF5 files: + +.. code:: python + + xr.open_dataset("/path/to/my/file.h5") + +The similarities between HDF5 and netCDF4 mean that HDF5 data can be written with the +same :py:meth:`Dataset.to_netcdf` method as used for netCDF4 data: + +.. ipython:: python + + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.rand(4, 5))}, + coords={ + "x": [10, 20, 30, 40], + "y": pd.date_range("2000-01-01", periods=5), + "z": ("x", list("abcd")), + }, + ) + + ds.to_netcdf("saved_on_disk.h5") + +Groups +~~~~~~ + +If you have multiple or highly nested groups, xarray by default may not read the group +that you want. A particular group of an HDF5 file can be specified using the ``group`` +argument: + +.. code:: python + + xr.open_dataset("/path/to/my/file.h5", group="/my/group") + +While xarray cannot interrogate an HDF5 file to determine which groups are available, +the HDF5 Python reader `h5py`_ can be used instead. + +Natively the xarray data structures can only handle one level of nesting, organized as +DataArrays inside of Datasets. If your HDF5 file has additional levels of hierarchy you +can only access one group and a time and will need to specify group names. + +.. note:: + + For native handling of multiple HDF5 groups with xarray, including I/O, you might be + interested in the experimental + `xarray-datatree `_ package. + + +.. _HDF5: https://hdfgroup.github.io/hdf5/index.html +.. _h5py: https://www.h5py.org/ + + .. _io.zarr: Zarr @@ -761,7 +819,7 @@ with ``mode='a'`` on a Dataset containing the new variables, passing in an existing Zarr store or path to a Zarr store. To resize and then append values along an existing dimension in a store, set -``append_dim``. This is a good option if data always arives in a particular +``append_dim``. This is a good option if data always arrives in a particular order, e.g., for time-stepping a simulation: .. ipython:: python diff --git a/doc/user-guide/reshaping.rst b/doc/user-guide/reshaping.rst index 2281106e7ec..d0b72322218 100644 --- a/doc/user-guide/reshaping.rst +++ b/doc/user-guide/reshaping.rst @@ -274,7 +274,7 @@ Sort ---- One may sort a DataArray/Dataset via :py:meth:`~xarray.DataArray.sortby` and -:py:meth:`~xarray.DataArray.sortby`. The input can be an individual or list of +:py:meth:`~xarray.Dataset.sortby`. The input can be an individual or list of 1D ``DataArray`` objects: .. ipython:: python diff --git a/doc/user-guide/terminology.rst b/doc/user-guide/terminology.rst index 24e6ab69927..ce7e55546a4 100644 --- a/doc/user-guide/terminology.rst +++ b/doc/user-guide/terminology.rst @@ -54,23 +54,22 @@ complete examples, please consult the relevant documentation.* Coordinate An array that labels a dimension or set of dimensions of another ``DataArray``. In the usual one-dimensional case, the coordinate array's - values can loosely be thought of as tick labels along a dimension. There - are two types of coordinate arrays: *dimension coordinates* and - *non-dimension coordinates* (see below). A coordinate named ``x`` can be - retrieved from ``arr.coords[x]``. A ``DataArray`` can have more - coordinates than dimensions because a single dimension can be labeled by - multiple coordinate arrays. However, only one coordinate array can be a - assigned as a particular dimension's dimension coordinate array. As a + values can loosely be thought of as tick labels along a dimension. We + distinguish :term:`Dimension coordinate` vs. :term:`Non-dimension + coordinate` and :term:`Indexed coordinate` vs. :term:`Non-indexed + coordinate`. A coordinate named ``x`` can be retrieved from + ``arr.coords[x]``. A ``DataArray`` can have more coordinates than + dimensions because a single dimension can be labeled by multiple + coordinate arrays. However, only one coordinate array can be a assigned + as a particular dimension's dimension coordinate array. As a consequence, ``len(arr.dims) <= len(arr.coords)`` in general. Dimension coordinate A one-dimensional coordinate array assigned to ``arr`` with both a name - and dimension name in ``arr.dims``. Dimension coordinates are used for - label-based indexing and alignment, like the index found on a - :py:class:`pandas.DataFrame` or :py:class:`pandas.Series`. In fact, - dimension coordinates use :py:class:`pandas.Index` objects under the - hood for efficient computation. Dimension coordinates are marked by - ``*`` when printing a ``DataArray`` or ``Dataset``. + and dimension name in ``arr.dims``. Usually (but not always), a + dimension coordinate is also an :term:`Indexed coordinate` so that it can + be used for label-based indexing and alignment, like the index found on + a :py:class:`pandas.DataFrame` or :py:class:`pandas.Series`. Non-dimension coordinate A coordinate array assigned to ``arr`` with a name in ``arr.coords`` but @@ -79,20 +78,40 @@ complete examples, please consult the relevant documentation.* example, multidimensional coordinates are often used in geoscience datasets when :doc:`the data's physical coordinates (such as latitude and longitude) differ from their logical coordinates - <../examples/multidimensional-coords>`. However, non-dimension coordinates - are not indexed, and any operation on non-dimension coordinates that - leverages indexing will fail. Printing ``arr.coords`` will print all of - ``arr``'s coordinate names, with the corresponding dimension(s) in - parentheses. For example, ``coord_name (dim_name) 1 2 3 ...``. + <../examples/multidimensional-coords>`. Printing ``arr.coords`` will + print all of ``arr``'s coordinate names, with the corresponding + dimension(s) in parentheses. For example, ``coord_name (dim_name) 1 2 3 + ...``. + + Indexed coordinate + A coordinate which has an associated :term:`Index`. Generally this means + that the coordinate labels can be used for indexing (selection) and/or + alignment. An indexed coordinate may have one or more arbitrary + dimensions although in most cases it is also a :term:`Dimension + coordinate`. It may or may not be grouped with other indexed coordinates + depending on whether they share the same index. Indexed coordinates are + marked by an asterisk ``*`` when printing a ``DataArray`` or ``Dataset``. + + Non-indexed coordinate + A coordinate which has no associated :term:`Index`. It may still + represent fixed labels along one or more dimensions but it cannot be + used for label-based indexing and alignment. Index - An *index* is a data structure optimized for efficient selecting and - slicing of an associated array. Xarray creates indexes for dimension - coordinates so that operations along dimensions are fast, while - non-dimension coordinates are not indexed. Under the hood, indexes are - implemented as :py:class:`pandas.Index` objects. The index associated - with dimension name ``x`` can be retrieved by ``arr.indexes[x]``. By - construction, ``len(arr.dims) == len(arr.indexes)`` + An *index* is a data structure optimized for efficient data selection + and alignment within a discrete or continuous space that is defined by + coordinate labels (unless it is a functional index). By default, Xarray + creates a :py:class:`~xarray.indexes.PandasIndex` object (i.e., a + :py:class:`pandas.Index` wrapper) for each :term:`Dimension coordinate`. + For more advanced use cases (e.g., staggered or irregular grids, + geospatial indexes), Xarray also accepts any instance of a specialized + :py:class:`~xarray.indexes.Index` subclass that is associated to one or + more arbitrary coordinates. The index associated with the coordinate + ``x`` can be retrieved by ``arr.xindexes[x]`` (or ``arr.indexes["x"]`` + if the index is convertible to a :py:class:`pandas.Index` object). If + two coordinates ``x`` and ``y`` share the same index, + ``arr.xindexes[x]`` and ``arr.xindexes[y]`` both return the same + :py:class:`~xarray.indexes.Index` object. name The names of dimensions, coordinates, DataArray objects and data @@ -112,3 +131,128 @@ complete examples, please consult the relevant documentation.* ``__array_ufunc__`` and ``__array_function__`` protocols are also required. __ https://numpy.org/neps/nep-0022-ndarray-duck-typing-overview.html + + .. ipython:: python + :suppress: + + import numpy as np + import xarray as xr + + Aligning + Aligning refers to the process of ensuring that two or more DataArrays or Datasets + have the same dimensions and coordinates, so that they can be combined or compared properly. + + .. ipython:: python + + x = xr.DataArray( + [[25, 35], [10, 24]], + dims=("lat", "lon"), + coords={"lat": [35.0, 40.0], "lon": [100.0, 120.0]}, + ) + y = xr.DataArray( + [[20, 5], [7, 13]], + dims=("lat", "lon"), + coords={"lat": [35.0, 42.0], "lon": [100.0, 120.0]}, + ) + x + y + + Broadcasting + A technique that allows operations to be performed on arrays with different shapes and dimensions. + When performing operations on arrays with different shapes and dimensions, xarray will automatically attempt to broadcast the + arrays to a common shape before the operation is applied. + + .. ipython:: python + + # 'a' has shape (3,) and 'b' has shape (4,) + a = xr.DataArray(np.array([1, 2, 3]), dims=["x"]) + b = xr.DataArray(np.array([4, 5, 6, 7]), dims=["y"]) + + # 2D array with shape (3, 4) + a + b + + Merging + Merging is used to combine two or more Datasets or DataArrays that have different variables or coordinates along + the same dimensions. When merging, xarray aligns the variables and coordinates of the different datasets along + the specified dimensions and creates a new ``Dataset`` containing all the variables and coordinates. + + .. ipython:: python + + # create two 1D arrays with names + arr1 = xr.DataArray( + [1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]}, name="arr1" + ) + arr2 = xr.DataArray( + [4, 5, 6], dims=["x"], coords={"x": [20, 30, 40]}, name="arr2" + ) + + # merge the two arrays into a new dataset + merged_ds = xr.Dataset({"arr1": arr1, "arr2": arr2}) + merged_ds + + Concatenating + Concatenating is used to combine two or more Datasets or DataArrays along a dimension. When concatenating, + xarray arranges the datasets or dataarrays along a new dimension, and the resulting ``Dataset`` or ``Dataarray`` + will have the same variables and coordinates along the other dimensions. + + .. ipython:: python + + a = xr.DataArray([[1, 2], [3, 4]], dims=("x", "y")) + b = xr.DataArray([[5, 6], [7, 8]], dims=("x", "y")) + c = xr.concat([a, b], dim="c") + c + + Combining + Combining is the process of arranging two or more DataArrays or Datasets into a single ``DataArray`` or + ``Dataset`` using some combination of merging and concatenation operations. + + .. ipython:: python + + ds1 = xr.Dataset( + {"data": xr.DataArray([[1, 2], [3, 4]], dims=("x", "y"))}, + coords={"x": [1, 2], "y": [3, 4]}, + ) + ds2 = xr.Dataset( + {"data": xr.DataArray([[5, 6], [7, 8]], dims=("x", "y"))}, + coords={"x": [2, 3], "y": [4, 5]}, + ) + + # combine the datasets + combined_ds = xr.combine_by_coords([ds1, ds2]) + combined_ds + + lazy + Lazily-evaluated operations do not load data into memory until necessary.Instead of doing calculations + right away, xarray lets you plan what calculations you want to do, like finding the + average temperature in a dataset.This planning is called "lazy evaluation." Later, when + you're ready to see the final result, you tell xarray, "Okay, go ahead and do those calculations now!" + That's when xarray starts working through the steps you planned and gives you the answer you wanted.This + lazy approach helps save time and memory because xarray only does the work when you actually need the + results. + + labeled + Labeled data has metadata describing the context of the data, not just the raw data values. + This contextual information can be labels for array axes (i.e. dimension names) tick labels along axes (stored as Coordinate variables) or unique names for each array. These labels + provide context and meaning to the data, making it easier to understand and work with. If you have + temperature data for different cities over time. Using xarray, you can label the dimensions: one for + cities and another for time. + + serialization + Serialization is the process of converting your data into a format that makes it easy to save and share. + When you serialize data in xarray, you're taking all those temperature measurements, along with their + labels and other information, and turning them into a format that can be stored in a file or sent over + the internet. xarray objects can be serialized into formats which store the labels alongside the data. + Some supported serialization formats are files that can then be stored or transferred (e.g. netCDF), + whilst others are protocols that allow for data access over a network (e.g. Zarr). + + indexing + :ref:`Indexing` is how you select subsets of your data which you are interested in. + + - Label-based Indexing: Selecting data by passing a specific label and comparing it to the labels + stored in the associated coordinates. You can use labels to specify what you want like "Give me the + temperature for New York on July 15th." + + - Positional Indexing: You can use numbers to refer to positions in the data like "Give me the third temperature value" This is useful when you know the order of your data but don't need to remember the exact labels. + + - Slicing: You can take a "slice" of your data, like you might want all temperatures from July 1st + to July 10th. xarray supports slicing for both positional and label-based indexing. diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1fbe1a75ee2..bcf05a2bb72 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,9 +14,9 @@ What's New np.random.seed(123456) -.. _whats-new.2023.07.1: +.. _whats-new.2023.10.2: -v2023.07.1 (unreleased) +v2023.10.2 (unreleased) ----------------------- New Features @@ -35,6 +35,8 @@ New Features (:issue:`2233`, :pull:`7989`) By `Deepak Cherian `_ and `Benoit Bovy `_. +- Use `opt_einsum `_ for :py:func:`xarray.dot` by default if installed. + By `Deepak Cherian `_. (:issue:`7764`, :pull:`8373`). Breaking changes ~~~~~~~~~~~~~~~~ @@ -43,23 +45,352 @@ Breaking changes Deprecations ~~~~~~~~~~~~ +- Supplying dimension-ordered sequences to :py:meth:`DataArray.chunk` & + :py:meth:`Dataset.chunk` is deprecated in favor of supplying a dictionary of + dimensions, or a single ``int`` or ``"auto"`` argument covering all + dimensions. Xarray favors using dimensions names rather than positions, and + this was one place in the API where dimension positions were used. + (:pull:`8341`) + By `Maximilian Roos `_. + +Bug fixes +~~~~~~~~~ + +- Port `bug fix from pandas `_ + to eliminate the adjustment of resample bin edges in the case that the + resampling frequency has units of days and is greater than one day + (e.g. ``"2D"``, ``"3D"`` etc.) and the ``closed`` argument is set to + ``"right"`` to xarray's implementation of resample for data indexed by a + :py:class:`CFTimeIndex` (:pull:`8393`). + By `Spencer Clark `_. + +Documentation +~~~~~~~~~~~~~ + + +Internal Changes +~~~~~~~~~~~~~~~~ + + +.. _whats-new.2023.10.1: + +v2023.10.1 (19 Oct, 2023) +------------------------- + +This release updates our minimum numpy version in ``pyproject.toml`` to 1.22, +consistent with our documentation below. + +.. _whats-new.2023.10.0: + +v2023.10.0 (19 Oct, 2023) +------------------------- + +This release brings performance enhancements to reading Zarr datasets, the ability to use `numbagg `_ for reductions, +an expansion in API for ``rolling_exp``, fixes two regressions with datetime decoding, +and many other bugfixes and improvements. Groupby reductions will also use ``numbagg`` if ``flox>=0.8.1`` and ``numbagg`` are both installed. + +Thanks to our 13 contributors: +Anderson Banihirwe, Bart Schilperoort, Deepak Cherian, Illviljan, Kai Mühlbauer, Mathias Hauser, Maximilian Roos, Michael Niklas, Pieter Eendebak, Simon Høxbro Hansen, Spencer Clark, Tom White, olimcc + +New Features +~~~~~~~~~~~~ +- Support high-performance reductions with `numbagg `_. + This is enabled by default if ``numbagg`` is installed. + By `Deepak Cherian `_. (:pull:`8316`) +- Add ``corr``, ``cov``, ``std`` & ``var`` to ``.rolling_exp``. + By `Maximilian Roos `_. (:pull:`8307`) +- :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for + the ``other`` parameter, passing the object as the only argument. Previously, + this was only valid for the ``cond`` parameter. (:issue:`8255`) + By `Maximilian Roos `_. +- ``.rolling_exp`` functions can now take a ``min_weight`` parameter, to only + output values when there are sufficient recent non-nan values. + ``numbagg>=0.3.1`` is required. (:pull:`8285`) + By `Maximilian Roos `_. +- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for + the ``variables`` parameter, passing the object as the only argument. + By `Maximilian Roos `_. +- ``.rolling_exp`` functions can now operate on dask-backed arrays, assuming the + core dim has exactly one chunk. (:pull:`8284`). + By `Maximilian Roos `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Made more arguments keyword-only (e.g. ``keep_attrs``, ``skipna``) for many :py:class:`xarray.DataArray` and + :py:class:`xarray.Dataset` methods (:pull:`6403`). By `Mathias Hauser `_. +- :py:meth:`Dataset.to_zarr` & :py:meth:`DataArray.to_zarr` require keyword + arguments after the initial 7 positional arguments. + By `Maximilian Roos `_. + + +Deprecations +~~~~~~~~~~~~ +- Rename :py:meth:`Dataset.reset_encoding` & :py:meth:`DataArray.reset_encoding` + to :py:meth:`Dataset.drop_encoding` & :py:meth:`DataArray.drop_encoding` for + consistency with other ``drop`` & ``reset`` methods — ``drop`` generally + removes something, while ``reset`` generally resets to some default or + standard value. (:pull:`8287`, :issue:`8259`) + By `Maximilian Roos `_. + +Bug fixes +~~~~~~~~~ + +- :py:meth:`DataArray.rename` & :py:meth:`Dataset.rename` would emit a warning + when the operation was a no-op. (:issue:`8266`) + By `Simon Hansen `_. +- Fixed a regression introduced in the previous release checking time-like units + when encoding/decoding masked data (:issue:`8269`, :pull:`8277`). + By `Kai Mühlbauer `_. + +- Fix datetime encoding precision loss regression introduced in the previous + release for datetimes encoded with units requiring floating point values, and + a reference date not equal to the first value of the datetime array + (:issue:`8271`, :pull:`8272`). By `Spencer Clark + `_. + +- Fix excess metadata requests when using a Zarr store. Prior to this, metadata + was re-read every time data was retrieved from the array, now metadata is retrieved only once + when they array is initialized. + (:issue:`8290`, :pull:`8297`). + By `Oliver McCormack `_. + + +Documentation +~~~~~~~~~~~~~ + +- Added page on the interoperability of xarray objects. + (:pull:`7992`) By `Tom Nicholas `_. +- Added xarray-regrid to the list of xarray related projects (:pull:`8272`). + By `Bart Schilperoort `_. + + +Internal Changes +~~~~~~~~~~~~~~~~ + +- More improvements to support the Python `array API standard `_ + by using duck array ops in more places in the codebase. (:pull:`8267`) + By `Tom White `_. + + +.. _whats-new.2023.09.0: + +v2023.09.0 (Sep 26, 2023) +------------------------- + +This release continues work on the new :py:class:`xarray.Coordinates` object, allows to provide `preferred_chunks` when +reading from netcdf files, enables :py:func:`xarray.apply_ufunc` to handle missing core dimensions and fixes several bugs. + +Thanks to the 24 contributors to this release: Alexander Fischer, Amrest Chinkamol, Benoit Bovy, Darsh Ranjan, Deepak Cherian, +Gianfranco Costamagna, Gregorio L. Trevisan, Illviljan, Joe Hamman, JR, Justus Magin, Kai Mühlbauer, Kian-Meng Ang, Kyle Sunden, +Martin Raspaud, Mathias Hauser, Mattia Almansi, Maximilian Roos, András Gunyhó, Michael Niklas, Richard Kleijn, Riulinchen, +Tom Nicholas and Wiktor Kraśnicki. + +We welcome the following new contributors to Xarray!: Alexander Fischer, Amrest Chinkamol, Darsh Ranjan, Gianfranco Costamagna, Gregorio L. Trevisan, +Kian-Meng Ang, Riulinchen and Wiktor Kraśnicki. + +New Features +~~~~~~~~~~~~ + +- Added the :py:meth:`Coordinates.assign` method that can be used to combine + different collections of coordinates prior to assign them to a Dataset or + DataArray (:pull:`8102`) at once. + By `Benoît Bovy `_. +- Provide `preferred_chunks` for data read from netcdf files (:issue:`1440`, :pull:`7948`). + By `Martin Raspaud `_. +- Added `on_missing_core_dims` to :py:meth:`apply_ufunc` to allow for copying or + dropping a :py:class:`Dataset`'s variables with missing core dimensions (:pull:`8138`). + By `Maximilian Roos `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The :py:class:`Coordinates` constructor now creates a (pandas) index by + default for each dimension coordinate. To keep the previous behavior (no index + created), pass an empty dictionary to ``indexes``. The constructor now also + extracts and add the indexes from another :py:class:`Coordinates` object + passed via ``coords`` (:pull:`8107`). + By `Benoît Bovy `_. +- Static typing of ``xlim`` and ``ylim`` arguments in plotting functions now must + be ``tuple[float, float]`` to align with matplotlib requirements. (:issue:`7802`, :pull:`8030`). + By `Michael Niklas `_. + +Deprecations +~~~~~~~~~~~~ + +- Deprecate passing a :py:class:`pandas.MultiIndex` object directly to the + :py:class:`Dataset` and :py:class:`DataArray` constructors as well as to + :py:meth:`Dataset.assign` and :py:meth:`Dataset.assign_coords`. + A new Xarray :py:class:`Coordinates` object has to be created first using + :py:meth:`Coordinates.from_pandas_multiindex` (:pull:`8094`). + By `Benoît Bovy `_. Bug fixes ~~~~~~~~~ +- Improved static typing of reduction methods (:pull:`6746`). + By `Richard Kleijn `_. +- Fix bug where empty attrs would generate inconsistent tokens (:issue:`6970`, :pull:`8101`). + By `Mattia Almansi `_. +- Improved handling of multi-coordinate indexes when updating coordinates, including bug fixes + (and improved warnings for deprecated features) for pandas multi-indexes (:pull:`8094`). + By `Benoît Bovy `_. +- Fixed a bug in :py:func:`merge` with ``compat='minimal'`` where the coordinate + names were not updated properly internally (:issue:`7405`, :issue:`7588`, + :pull:`8104`). + By `Benoît Bovy `_. +- Fix bug where :py:class:`DataArray` instances on the right-hand side + of :py:meth:`DataArray.__setitem__` lose dimension names (:issue:`7030`, :pull:`8067`). + By `Darsh Ranjan `_. +- Return ``float64`` in presence of ``NaT`` in :py:class:`~core.accessor_dt.DatetimeAccessor` and + special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar` + (:issue:`7928`, :pull:`8084`). + By `Kai Mühlbauer `_. +- Fix :py:meth:`~core.rolling.DatasetRolling.construct` with stride on Datasets without indexes. + (:issue:`7021`, :pull:`7578`). + By `Amrest Chinkamol `_ and `Michael Niklas `_. +- Calling plot with kwargs ``col``, ``row`` or ``hue`` no longer squeezes dimensions passed via these arguments + (:issue:`7552`, :pull:`8174`). + By `Wiktor Kraśnicki `_. +- Fixed a bug where casting from ``float`` to ``int64`` (undefined for ``NaN``) led to varying issues (:issue:`7817`, :issue:`7942`, :issue:`7790`, :issue:`6191`, :issue:`7096`, + :issue:`1064`, :pull:`7827`). + By `Kai Mühlbauer `_. +- Fixed a bug where inaccurate ``coordinates`` silently failed to decode variable (:issue:`1809`, :pull:`8195`). + By `Kai Mühlbauer `_ +- ``.rolling_exp`` functions no longer mistakenly lose non-dimensioned coords + (:issue:`6528`, :pull:`8114`). + By `Maximilian Roos `_. +- In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`). + By `Kai Mühlbauer `_. +- Static typing of dunder ops methods (like :py:meth:`DataArray.__eq__`) has been fixed. + Remaining issues are upstream problems (:issue:`7780`, :pull:`8204`). + By `Michael Niklas `_. +- Fix type annotation for ``center`` argument of plotting methods (like :py:meth:`xarray.plot.dataarray_plot.pcolormesh`) (:pull:`8261`). + By `Pieter Eendebak `_. + +Documentation +~~~~~~~~~~~~~ + +- Make documentation of :py:meth:`DataArray.where` clearer (:issue:`7767`, :pull:`7955`). + By `Riulinchen `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Many error messages related to invalid dimensions or coordinates now always show the list of valid dims/coords (:pull:`8079`). + By `András Gunyhó `_. +- Refactor of encoding and decoding times/timedeltas to preserve nanosecond resolution in arrays that contain missing values (:pull:`7827`). + By `Kai Mühlbauer `_. +- Transition ``.rolling_exp`` functions to use `.apply_ufunc` internally rather + than `.reduce`, as the start of a broader effort to move non-reducing + functions away from ```.reduce``, (:pull:`8114`). + By `Maximilian Roos `_. +- Test range of fill_value's in test_interpolate_pd_compat (:issue:`8146`, :pull:`8189`). + By `Kai Mühlbauer `_. + +.. _whats-new.2023.08.0: + +v2023.08.0 (Aug 18, 2023) +------------------------- + +This release brings changes to minimum dependencies, allows reading of datasets where a dimension name is +associated with a multidimensional variable (e.g. finite volume ocean model output), and introduces +a new :py:class:`xarray.Coordinates` object. + +Thanks to the 16 contributors to this release: Anderson Banihirwe, Articoking, Benoit Bovy, Deepak Cherian, Harshitha, Ian Carroll, +Joe Hamman, Justus Magin, Peter Hill, Rachel Wegener, Riley Kuttruff, Thomas Nicholas, Tom Nicholas, ilgast, quantsnus, vallirep + +Announcements +~~~~~~~~~~~~~ + +The :py:class:`xarray.Variable` class is being refactored out to a new project title 'namedarray'. +See the `design doc `_ for more +details. Reach out to us on this [discussion topic](https://github.com/pydata/xarray/discussions/8080) if you have any thoughts. + +New Features +~~~~~~~~~~~~ + +- :py:class:`Coordinates` can now be constructed independently of any Dataset or + DataArray (it is also returned by the :py:attr:`Dataset.coords` and + :py:attr:`DataArray.coords` properties). ``Coordinates`` objects are useful for + passing both coordinate variables and indexes to new Dataset / DataArray objects, + e.g., via their constructor or via :py:meth:`Dataset.assign_coords`. We may also + wrap coordinate variables in a ``Coordinates`` object in order to skip + the automatic creation of (pandas) indexes for dimension coordinates. + The :py:class:`Coordinates.from_pandas_multiindex` constructor may be used to + create coordinates directly from a :py:class:`pandas.MultiIndex` object (it is + preferred over passing it directly as coordinate data, which may be deprecated soon). + Like Dataset and DataArray objects, ``Coordinates`` objects may now be used in + :py:func:`align` and :py:func:`merge`. + (:issue:`6392`, :pull:`7368`). + By `Benoît Bovy `_. +- Visually group together coordinates with the same indexes in the index section of the text repr (:pull:`7225`). + By `Justus Magin `_. +- Allow creating Xarray objects where a multidimensional variable shares its name + with a dimension. Examples include output from finite volume models like FVCOM. + (:issue:`2233`, :pull:`7989`) + By `Deepak Cherian `_ and `Benoit Bovy `_. +- When outputting :py:class:`Dataset` objects as Zarr via :py:meth:`Dataset.to_zarr`, + user can now specify that chunks that will contain no valid data will not be written. + Originally, this could be done by specifying ``"write_empty_chunks": True`` in the + ``encoding`` parameter; however, this setting would not carry over when appending new + data to an existing dataset. (:issue:`8009`) Requires ``zarr>=2.11``. + + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The minimum versions of some dependencies were changed (:pull:`8022`): + + ===================== ========= ======== + Package Old New + ===================== ========= ======== + boto3 1.20 1.24 + cftime 1.5 1.6 + dask-core 2022.1 2022.7 + distributed 2022.1 2022.7 + hfnetcdf 0.13 1.0 + iris 3.1 3.2 + lxml 4.7 4.9 + netcdf4 1.5.7 1.6.0 + numpy 1.21 1.22 + pint 0.18 0.19 + pydap 3.2 3.3 + rasterio 1.2 1.3 + scipy 1.7 1.8 + toolz 0.11 0.12 + typing_extensions 4.0 4.3 + zarr 2.10 2.12 + numbagg 0.1 0.2.1 + ===================== ========= ======== Documentation ~~~~~~~~~~~~~ +- Added page on the internal design of xarray objects. + (:pull:`7991`) By `Tom Nicholas `_. +- Added examples to docstrings of :py:meth:`Dataset.assign_attrs`, :py:meth:`Dataset.broadcast_equals`, + :py:meth:`Dataset.equals`, :py:meth:`Dataset.identical`, :py:meth:`Dataset.expand_dims`,:py:meth:`Dataset.drop_vars` + (:issue:`6793`, :pull:`7937`) By `Harshitha `_. - Add docstrings for the :py:class:`Index` base class and add some documentation on how to create custom, Xarray-compatible indexes (:pull:`6975`) By `Benoît Bovy `_. +- Added a page clarifying the role of Xarray core team members. + (:pull:`7999`) By `Tom Nicholas `_. +- Fixed broken links in "See also" section of :py:meth:`Dataset.count` (:issue:`8055`, :pull:`8057`) + By `Articoking `_. +- Extended the glossary by adding terms Aligning, Broadcasting, Merging, Concatenating, Combining, lazy, + labeled, serialization, indexing (:issue:`3355`, :pull:`7732`) + By `Harshitha `_. Internal Changes ~~~~~~~~~~~~~~~~ - :py:func:`as_variable` now consistently includes the variable name in any exceptions raised. (:pull:`7995`). By `Peter Hill `_ +- :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to + `coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`). + `By Ian Carroll `_. .. _whats-new.2023.07.0: @@ -83,9 +414,9 @@ Bug fixes Documentation ~~~~~~~~~~~~~ -- Added examples to docstrings of :py:meth:`Dataset.tail`, :py:meth:`Dataset.head`, :py:meth:`Dataset.dropna`, - :py:meth:`Dataset.ffill`, :py:meth:`Dataset.bfill`, :py:meth:`Dataset.set_coords`, :py:meth:`Dataset.reset_coords` - (:issue:`6793`, :pull:`7936`) By `Harshitha `_ . +- Added examples to docstrings of :py:meth:`Dataset.assign_attrs`, :py:meth:`Dataset.broadcast_equals`, + :py:meth:`Dataset.equals`, :py:meth:`Dataset.identical`, :py:meth:`Dataset.expand_dims`,:py:meth:`Dataset.drop_vars` + (:issue:`6793`, :pull:`7937`) By `Harshitha `_. - Added page on wrapping chunked numpy-like arrays as alternatives to dask arrays. (:pull:`7951`) By `Tom Nicholas `_. - Expanded the page on wrapping numpy-like "duck" arrays. @@ -775,7 +1106,7 @@ Bug fixes By `Michael Niklas `_. - Fix side effects on index coordinate metadata after aligning objects. (:issue:`6852`, :pull:`6857`) By `Benoît Bovy `_. -- Make FacetGrid.set_titles send kwargs correctly using `handle.udpate(kwargs)`. (:issue:`6839`, :pull:`6843`) +- Make FacetGrid.set_titles send kwargs correctly using `handle.update(kwargs)`. (:issue:`6839`, :pull:`6843`) By `Oliver Lopez `_. - Fix bug where index variables would be changed inplace. (:issue:`6931`, :pull:`6938`) By `Michael Niklas `_. @@ -4638,7 +4969,7 @@ Bug fixes - Corrected a bug with incorrect coordinates for non-georeferenced geotiff files (:issue:`1686`). Internally, we now use the rasterio coordinate transform tool instead of doing the computations ourselves. A - ``parse_coordinates`` kwarg has beed added to :py:func:`~open_rasterio` + ``parse_coordinates`` kwarg has been added to :py:func:`~open_rasterio` (set to ``True`` per default). By `Fabien Maussion `_. - The colors of discrete colormaps are now the same regardless if `seaborn` diff --git a/pyproject.toml b/pyproject.toml index 4d63fd564ba..b16063e0370 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,49 @@ +[project] +authors = [ + {name = "xarray Developers", email = "xarray@googlegroups.com"}, +] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Intended Audience :: Science/Research", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", +] +description = "N-D labeled arrays and datasets in Python" +dynamic = ["version"] +license = {text = "Apache-2.0"} +name = "xarray" +readme = "README.md" +requires-python = ">=3.9" + +dependencies = [ + "numpy>=1.22", + "packaging>=21.3", + "pandas>=1.4", +] + +[project.urls] +Documentation = "https://docs.xarray.dev" +SciPy2015-talk = "https://www.youtube.com/watch?v=X0pAhJgySxk" +homepage = "https://xarray.dev/" +issue-tracker = "https://github.com/pydata/xarray/issues" +source-code = "https://github.com/pydata/xarray" + +[project.entry-points."xarray.chunkmanagers"] +dask = "xarray.core.daskmanager:DaskManager" + +[project.optional-dependencies] +accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"] +complete = ["xarray[accel,io,parallel,viz]"] +io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"] +parallel = ["dask[complete]"] +viz = ["matplotlib", "seaborn", "nc-time-axis"] + [build-system] build-backend = "setuptools.build_meta" requires = [ @@ -5,8 +51,11 @@ requires = [ "setuptools-scm>=7", ] +[tool.setuptools] +packages = ["xarray"] + [tool.setuptools_scm] -fallback_version = "999" +fallback_version = "9999" [tool.coverage.run] omit = [ @@ -23,12 +72,16 @@ source = ["xarray"] exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"] [tool.mypy] +enable_error_code = "redundant-self" exclude = 'xarray/util/generate_.*\.py' files = "xarray" show_error_codes = true +show_error_context = true +warn_redundant_casts = true +warn_unused_configs = true warn_unused_ignores = true -# Most of the numerical computing stack doesn't have type annotations yet. +# Much of the numerical computing stack doesn't have type annotations yet. [[tool.mypy.overrides]] ignore_missing_imports = true module = [ @@ -41,10 +94,10 @@ module = [ "cftime.*", "cubed.*", "cupy.*", + "dask.types.*", "fsspec.*", "h5netcdf.*", "h5py.*", - "importlib_metadata.*", "iris.*", "matplotlib.*", "mpl_toolkits.*", @@ -53,6 +106,7 @@ module = [ "numbagg.*", "netCDF4.*", "netcdftime.*", + "opt_einsum.*", "pandas.*", "pooch.*", "PseudoNetCDF.*", @@ -64,39 +118,154 @@ module = [ "sparse.*", "toolz.*", "zarr.*", + "numpy.exceptions.*", # remove once support for `numpy<2.0` has been dropped ] +# Gradually we want to add more modules to this list, ratcheting up our total +# coverage. Once a module is here, functions are checked by mypy regardless of +# whether they have type annotations. It would be especially useful to have test +# files listed here, because without them being checked, we don't have a great +# way of testing our annotations. [[tool.mypy.overrides]] -ignore_errors = true -module = [] +check_untyped_defs = true +module = [ + "xarray.core.accessor_dt", + "xarray.core.accessor_str", + "xarray.core.alignment", + "xarray.core.computation", + "xarray.core.rolling_exp", + "xarray.indexes.*", + "xarray.tests.*", +] +# This then excludes some modules from the above list. (So ideally we remove +# from here in time...) +[[tool.mypy.overrides]] +check_untyped_defs = false +module = [ + "xarray.tests.test_coarsen", + "xarray.tests.test_coding_times", + "xarray.tests.test_combine", + "xarray.tests.test_computation", + "xarray.tests.test_concat", + "xarray.tests.test_coordinates", + "xarray.tests.test_dask", + "xarray.tests.test_dataarray", + "xarray.tests.test_duck_array_ops", + "xarray.tests.test_groupby", + "xarray.tests.test_indexing", + "xarray.tests.test_merge", + "xarray.tests.test_missing", + "xarray.tests.test_parallelcompat", + "xarray.tests.test_plot", + "xarray.tests.test_sparse", + "xarray.tests.test_ufuncs", + "xarray.tests.test_units", + "xarray.tests.test_utils", + "xarray.tests.test_variable", + "xarray.tests.test_weighted", +] + +# Use strict = true whenever namedarray has become standalone. In the meantime +# don't forget to add all new files related to namedarray here: +# ref: https://mypy.readthedocs.io/en/stable/existing_code.html#introduce-stricter-options +[[tool.mypy.overrides]] +# Start off with these +warn_unused_ignores = true + +# Getting these passing should be easy +strict_concatenate = true +strict_equality = true + +# Strongly recommend enabling this one as soon as you can +check_untyped_defs = true + +# These shouldn't be too much additional work, but may be tricky to +# get passing if you use a lot of untyped libraries +disallow_any_generics = true +disallow_subclassing_any = true +disallow_untyped_decorators = true + +# These next few are various gradations of forcing use of type annotations +disallow_incomplete_defs = true +disallow_untyped_calls = true +disallow_untyped_defs = true + +# This one isn't too hard to get passing, but return on investment is lower +no_implicit_reexport = true + +# This one can be tricky to get passing if you use a lot of untyped libraries +warn_return_any = true + +module = ["xarray.namedarray.*", "xarray.tests.test_namedarray"] + +[tool.pyright] +# include = ["src"] +# exclude = ["**/node_modules", +# "**/__pycache__", +# "src/experimental", +# "src/typestubs" +# ] +# ignore = ["src/oldstuff"] +defineConstant = {DEBUG = true} +# stubPath = "src/stubs" +# venv = "env367" + +# Enabling this means that developers who have disabled the warning locally — +# because not all dependencies are installable — are overridden +# reportMissingImports = true +reportMissingTypeStubs = false + +# pythonVersion = "3.6" +# pythonPlatform = "Linux" + +# executionEnvironments = [ +# { root = "src/web", pythonVersion = "3.5", pythonPlatform = "Windows", extraPaths = [ "src/service_libs" ] }, +# { root = "src/sdk", pythonVersion = "3.0", extraPaths = [ "src/backend" ] }, +# { root = "src/tests", extraPaths = ["src/tests/e2e", "src/sdk" ]}, +# { root = "src" } +# ] [tool.ruff] -target-version = "py39" builtins = ["ellipsis"] exclude = [ - ".eggs", - "doc", - "_typed_ops.pyi", + ".eggs", + "doc", + "_typed_ops.pyi", ] +target-version = "py39" # E402: module level import not at top of file # E501: line too long - let black worry about that # E731: do not assign a lambda expression, use a def ignore = [ - "E402", - "E501", - "E731", + "E402", + "E501", + "E731", ] select = [ - # Pyflakes - "F", - # Pycodestyle - "E", - "W", - # isort - "I", - # Pyupgrade - "UP", + "F", # Pyflakes + "E", # Pycodestyle + "W", + "I", # isort + "UP", # Pyupgrade ] [tool.ruff.isort] known-first-party = ["xarray"] + +[tool.pytest.ini_options] +addopts = ["--strict-config", "--strict-markers"] +filterwarnings = [ + "ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning", +] +log_cli_level = "INFO" +markers = [ + "flaky: flaky tests", + "network: tests requiring a network connection", + "slow: slow tests", +] +minversion = "7" +python_files = "test_*.py" +testpaths = ["xarray/tests", "properties"] + +[tool.aliases] +test = "pytest" diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 952710518d0..00000000000 --- a/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -# This file is redundant with setup.cfg; -# it exists to let GitHub build the repository dependency graph -# https://help.github.com/en/github/visualizing-repository-data-with-graphs/listing-the-packages-that-a-repository-depends-on - -numpy >= 1.21 -packaging >= 21.3 -pandas >= 1.4 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 85ac8e259e5..00000000000 --- a/setup.cfg +++ /dev/null @@ -1,154 +0,0 @@ -[metadata] -name = xarray -author = xarray Developers -author_email = xarray@googlegroups.com -license = Apache-2.0 -description = N-D labeled arrays and datasets in Python -long_description_content_type=text/x-rst -long_description = - **xarray** (formerly **xray**) is an open source project and Python package - that makes working with labelled multi-dimensional arrays simple, - efficient, and fun! - - xarray introduces labels in the form of dimensions, coordinates and - attributes on top of raw NumPy_-like arrays, which allows for a more - intuitive, more concise, and less error-prone developer experience. - The package includes a large and growing library of domain-agnostic functions - for advanced analytics and visualization with these data structures. - - xarray was inspired by and borrows heavily from pandas_, the popular data - analysis package focused on labelled tabular data. - It is particularly tailored to working with netCDF_ files, which were the - source of xarray's data model, and integrates tightly with dask_ for parallel - computing. - - .. _NumPy: https://www.numpy.org - .. _pandas: https://pandas.pydata.org - .. _dask: https://dask.org - .. _netCDF: https://www.unidata.ucar.edu/software/netcdf - - Why xarray? - ----------- - Multi-dimensional (a.k.a. N-dimensional, ND) arrays (sometimes called - "tensors") are an essential part of computational science. - They are encountered in a wide range of fields, including physics, astronomy, - geoscience, bioinformatics, engineering, finance, and deep learning. - In Python, NumPy_ provides the fundamental data structure and API for - working with raw ND arrays. - However, real-world datasets are usually more than just raw numbers; - they have labels which encode information about how the array values map - to locations in space, time, etc. - - xarray doesn't just keep track of labels on arrays -- it uses them to provide a - powerful and concise interface. For example: - - - Apply operations over dimensions by name: ``x.sum('time')``. - - Select values by label instead of integer location: ``x.loc['2014-01-01']`` or ``x.sel(time='2014-01-01')``. - - Mathematical operations (e.g., ``x - y``) vectorize across multiple dimensions (array broadcasting) based on dimension names, not shape. - - Flexible split-apply-combine operations with groupby: ``x.groupby('time.dayofyear').mean()``. - - Database like alignment based on coordinate labels that smoothly handles missing values: ``x, y = xr.align(x, y, join='outer')``. - - Keep track of arbitrary metadata in the form of a Python dictionary: ``x.attrs``. - - Learn more - ---------- - - Documentation: ``_ - - Issue tracker: ``_ - - Source code: ``_ - - SciPy2015 talk: ``_ - -url = https://github.com/pydata/xarray -classifiers = - Development Status :: 5 - Production/Stable - License :: OSI Approved :: Apache Software License - Operating System :: OS Independent - Intended Audience :: Science/Research - Programming Language :: Python - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 - Topic :: Scientific/Engineering - -[options] -packages = find: -zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.html -include_package_data = True -python_requires = >=3.9 -install_requires = - numpy >= 1.21 # recommended to use >= 1.22 for full quantile method support - pandas >= 1.4 - packaging >= 21.3 - -[options.extras_require] -io = - netCDF4 - h5netcdf - scipy - pydap; python_version<"3.10" # see https://github.com/pydap/pydap/issues/268 - zarr - fsspec - cftime - pooch - ## Scitools packages & dependencies (e.g: cartopy, cf-units) can be hard to install - # scitools-iris - -accel = - scipy - bottleneck - numbagg - flox - -parallel = - dask[complete] - -viz = - matplotlib - seaborn - nc-time-axis - ## Cartopy requires 3rd party libraries and only provides source distributions - ## See: https://github.com/SciTools/cartopy/issues/805 - # cartopy - -complete = - %(io)s - %(accel)s - %(parallel)s - %(viz)s - -docs = - %(complete)s - sphinx-autosummary-accessors - sphinx_rtd_theme - ipython - ipykernel - jupyter-client - nbsphinx - scanpydoc - -[options.package_data] -xarray = - py.typed - tests/data/* - static/css/* - static/html/* - -[options.entry_points] -xarray.chunkmanagers = - dask = xarray.core.daskmanager:DaskManager - -[tool:pytest] -python_files = test_*.py -testpaths = xarray/tests properties -# Fixed upstream in https://github.com/pydata/bottleneck/pull/199 -filterwarnings = - ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning -markers = - flaky: flaky tests - network: tests requiring a network connection - slow: slow tests - -[aliases] -test = pytest - -[pytest-watch] -nobeep = True diff --git a/setup.py b/setup.py index 088d7e4eac6..69343515fd5 100755 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ #!/usr/bin/env python from setuptools import setup -setup(use_scm_version={"fallback_version": "999"}) +setup(use_scm_version={"fallback_version": "9999"}) diff --git a/xarray/__init__.py b/xarray/__init__.py index 87b897cf1ea..1fd3b0c4336 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,3 +1,5 @@ +from importlib.metadata import version as _version + from xarray import testing, tutorial from xarray.backends.api import ( load_dataarray, @@ -26,6 +28,7 @@ where, ) from xarray.core.concat import concat +from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.extensions import ( @@ -37,21 +40,15 @@ from xarray.core.merge import Context, MergeError, merge from xarray.core.options import get_options, set_options from xarray.core.parallel import map_blocks -from xarray.core.variable import Coordinate, IndexVariable, Variable, as_variable +from xarray.core.variable import IndexVariable, Variable, as_variable from xarray.util.print_versions import show_versions -try: - from importlib.metadata import version as _version -except ImportError: - # if the fallback library is missing, we are doomed. - from importlib_metadata import version as _version - try: __version__ = _version("xarray") except Exception: # Local copy or not installed with setuptools. # Disable minimum version checks on downstream libraries. - __version__ = "999" + __version__ = "9999" # A hardcoded __all__ variable is necessary to appease # `mypy --strict` running in projects that import xarray. @@ -100,6 +97,7 @@ "CFTimeIndex", "Context", "Coordinate", + "Coordinates", "DataArray", "Dataset", "Index", diff --git a/xarray/backends/api.py b/xarray/backends/api.py index d992d3999a3..27e155872de 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -488,6 +488,9 @@ def open_dataset( as coordinate variables. - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and other attributes as coordinate variables. + + Only existing variables can be set as coordinates. Missing variables + will be silently ignored. drop_variables: str or iterable of str, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -691,6 +694,9 @@ def open_dataarray( as coordinate variables. - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and other attributes as coordinate variables. + + Only existing variables can be set as coordinates. Missing variables + will be silently ignored. drop_variables: str or iterable of str, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -930,7 +936,9 @@ def open_mfdataset( If a callable, it must expect a sequence of ``attrs`` dicts and a context object as its only parameters. **kwargs : optional - Additional arguments passed on to :py:func:`xarray.open_dataset`. + Additional arguments passed on to :py:func:`xarray.open_dataset`. For an + overview of some of the possible options, see the documentation of + :py:func:`xarray.open_dataset` Returns ------- @@ -965,6 +973,13 @@ def open_mfdataset( ... "file_*.nc", concat_dim="time", preprocess=partial_func ... ) # doctest: +SKIP + It is also possible to use any argument to ``open_dataset`` together + with ``open_mfdataset``, such as for example ``drop_variables``: + + >>> ds = xr.open_mfdataset( + ... "file.nc", drop_variables=["varname_1", "varname_2"] # any list of vars + ... ) # doctest: +SKIP + References ---------- @@ -1047,8 +1062,8 @@ def open_mfdataset( ) else: raise ValueError( - "{} is an invalid option for the keyword argument" - " ``combine``".format(combine) + f"{combine} is an invalid option for the keyword argument" + " ``combine``" ) except ValueError: for ds in datasets: @@ -1513,6 +1528,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: Literal[True] = True, consolidated: bool | None = None, append_dim: Hashable | None = None, @@ -1520,6 +1536,7 @@ def to_zarr( safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, + write_empty_chunks: bool | None = None, chunkmanager_store_kwargs: dict[str, Any] | None = None, ) -> backends.ZarrStore: ... @@ -1543,6 +1560,7 @@ def to_zarr( safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, + write_empty_chunks: bool | None = None, chunkmanager_store_kwargs: dict[str, Any] | None = None, ) -> Delayed: ... @@ -1556,6 +1574,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, @@ -1563,6 +1582,7 @@ def to_zarr( safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, + write_empty_chunks: bool | None = None, chunkmanager_store_kwargs: dict[str, Any] | None = None, ) -> backends.ZarrStore | Delayed: """This function creates an appropriate datastore for writing a dataset to @@ -1656,6 +1676,7 @@ def to_zarr( safe_chunks=safe_chunks, stacklevel=4, # for Dataset.to_zarr() zarr_version=zarr_version, + write_empty=write_empty_chunks, ) if mode in ["a", "r+"]: diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 1ac988c6b4f..5b8f9a6840f 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -247,7 +247,7 @@ def sync(self, compute=True, chunkmanager_store_kwargs=None): chunkmanager = get_chunked_array_type(*self.sources) # TODO: consider wrapping targets with dask.delayed, if this makes - # for any discernible difference in perforance, e.g., + # for any discernible difference in performance, e.g., # targets = [dask.delayed(t) for t in self.targets] if chunkmanager_store_kwargs is None: diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 697ebb8ab92..19748084625 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -6,8 +6,6 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Any -from packaging.version import Version - from xarray.backends.common import ( BACKEND_ENTRYPOINTS, BackendEntrypoint, @@ -200,6 +198,8 @@ def open_store_variable(self, name, var): "fletcher32": var.fletcher32, "shuffle": var.shuffle, } + if var.chunks: + encoding["preferred_chunks"] = dict(zip(var.dimensions, var.chunks)) # Convert h5py-style compression options to NetCDF4-Python # style, if possible if var.compression == "gzip": @@ -233,28 +233,14 @@ def get_attrs(self): return FrozenDict(_read_attributes(self.ds)) def get_dimensions(self): - import h5netcdf - - if Version(h5netcdf.__version__) >= Version("0.14.0.dev0"): - return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items()) - else: - return self.ds.dimensions + return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items()) def get_encoding(self): - import h5netcdf - - if Version(h5netcdf.__version__) >= Version("0.14.0.dev0"): - return { - "unlimited_dims": { - k for k, v in self.ds.dimensions.items() if v.isunlimited() - } - } - else: - return { - "unlimited_dims": { - k for k, v in self.ds.dimensions.items() if v is None - } + return { + "unlimited_dims": { + k for k, v in self.ds.dimensions.items() if v.isunlimited() } + } def set_dimension(self, name, length, is_unlimited=False): _ensure_no_forward_slash_in_name(name) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index b5c3413e7f8..f21f15bf795 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -207,7 +207,7 @@ def _ensure_fill_value_valid(data, attributes): # work around for netCDF4/scipy issue where _FillValue has the wrong type: # https://github.com/Unidata/netcdf4-python/issues/271 if data.dtype.kind == "S" and "_FillValue" in attributes: - attributes["_FillValue"] = np.string_(attributes["_FillValue"]) + attributes["_FillValue"] = np.bytes_(attributes["_FillValue"]) def _force_native_endianness(var): @@ -426,6 +426,7 @@ def open_store_variable(self, name, var): else: encoding["contiguous"] = False encoding["chunksizes"] = tuple(chunking) + encoding["preferred_chunks"] = dict(zip(var.dimensions, chunking)) # TODO: figure out how to round-trip "endian-ness" without raising # warnings from netCDF4 # encoding['endian'] = var.endian() @@ -538,7 +539,7 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint): """ Backend for netCDF files based on the netCDF4 package. - It can open ".nc", ".nc4", ".cdf" files and will be choosen + It can open ".nc", ".nc4", ".cdf" files and will be chosen as default for these files. Additionally it can open valid HDF5 files, see diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py index ef389eefc90..db00ef1972b 100644 --- a/xarray/backends/netcdf3.py +++ b/xarray/backends/netcdf3.py @@ -88,13 +88,30 @@ def encode_nc3_attrs(attrs): return {k: encode_nc3_attr_value(v) for k, v in attrs.items()} +def _maybe_prepare_times(var): + # checks for integer-based time-like and + # replaces np.iinfo(np.int64).min with _FillValue or np.nan + # this keeps backwards compatibility + + data = var.data + if data.dtype.kind in "iu": + units = var.attrs.get("units", None) + if units is not None: + if coding.variables._is_time_like(units): + mask = data == np.iinfo(np.int64).min + if mask.any(): + data = np.where(mask, var.attrs.get("_FillValue", np.nan), data) + return data + + def encode_nc3_variable(var): for coder in [ coding.strings.EncodedStringCoder(allows_unicode=False), coding.strings.CharacterArrayCoder(), ]: var = coder.encode(var) - data = coerce_nc3_dtype(var.data) + data = _maybe_prepare_times(var) + data = coerce_nc3_dtype(data) attrs = encode_nc3_attrs(var.attrs) return Variable(var.dims, data, attrs, var.encoding) diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 116c48f5692..9b5bcc82e6f 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Any import numpy as np -from packaging.version import Version from xarray.backends.common import ( BACKEND_ENTRYPOINTS, @@ -123,11 +122,10 @@ def open( "output_grid": output_grid or True, "timeout": timeout, } - if Version(pydap.lib.__version__) >= Version("3.3.0"): - if verify is not None: - kwargs.update({"verify": verify}) - if user_charset is not None: - kwargs.update({"user_charset": user_charset}) + if verify is not None: + kwargs.update({"verify": verify}) + if user_charset is not None: + kwargs.update({"user_charset": user_charset}) ds = pydap.client.open_url(**kwargs) return cls(ds) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 214581bba27..d6ad15f4f87 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -61,27 +61,29 @@ def encode_zarr_attr_value(value): class ZarrArrayWrapper(BackendArray): - __slots__ = ("datastore", "dtype", "shape", "variable_name") + __slots__ = ("datastore", "dtype", "shape", "variable_name", "_array") def __init__(self, variable_name, datastore): self.datastore = datastore self.variable_name = variable_name - array = self.get_array() - self.shape = array.shape + # some callers attempt to evaluate an array if an `array` property exists on the object. + # we prefix with _ to avoid this inference. + self._array = self.datastore.zarr_group[self.variable_name] + self.shape = self._array.shape # preserve vlen string object dtype (GH 7328) - if array.filters is not None and any( - [filt.codec_id == "vlen-utf8" for filt in array.filters] + if self._array.filters is not None and any( + [filt.codec_id == "vlen-utf8" for filt in self._array.filters] ): dtype = coding.strings.create_vlen_dtype(str) else: - dtype = array.dtype + dtype = self._array.dtype self.dtype = dtype def get_array(self): - return self.datastore.zarr_group[self.variable_name] + return self._array def _oindex(self, key): return self.get_array().oindex[key] @@ -368,6 +370,7 @@ class ZarrStore(AbstractWritableDataStore): "_synchronizer", "_write_region", "_safe_chunks", + "_write_empty", ) @classmethod @@ -386,6 +389,7 @@ def open_group( safe_chunks=True, stacklevel=2, zarr_version=None, + write_empty: bool | None = None, ): import zarr @@ -457,6 +461,7 @@ def open_group( append_dim, write_region, safe_chunks, + write_empty, ) def __init__( @@ -467,6 +472,7 @@ def __init__( append_dim=None, write_region=None, safe_chunks=True, + write_empty: bool | None = None, ): self.zarr_group = zarr_group self._read_only = self.zarr_group.read_only @@ -477,6 +483,7 @@ def __init__( self._append_dim = append_dim self._write_region = write_region self._safe_chunks = safe_chunks + self._write_empty = write_empty @property def ds(self): @@ -648,6 +655,8 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No dimensions. """ + import zarr + for vn, v in variables.items(): name = _encode_variable_name(vn) check = vn in check_encoding_set @@ -665,7 +674,14 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No # TODO: if mode="a", consider overriding the existing variable # metadata. This would need some case work properly with region # and append_dim. - zarr_array = self.zarr_group[name] + if self._write_empty is not None: + zarr_array = zarr.open( + store=self.zarr_group.store, + path=f"{self.zarr_group.name}/{name}", + write_empty_chunks=self._write_empty, + ) + else: + zarr_array = self.zarr_group[name] else: # new variable encoding = extract_zarr_variable_encoding( @@ -679,8 +695,25 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No if coding.strings.check_vlen_dtype(dtype) == str: dtype = str + + if self._write_empty is not None: + if ( + "write_empty_chunks" in encoding + and encoding["write_empty_chunks"] != self._write_empty + ): + raise ValueError( + 'Differing "write_empty_chunks" values in encoding and parameters' + f'Got {encoding["write_empty_chunks"] = } and {self._write_empty = }' + ) + else: + encoding["write_empty_chunks"] = self._write_empty + zarr_array = self.zarr_group.create( - name, shape=shape, dtype=dtype, fill_value=fill_value, **encoding + name, + shape=shape, + dtype=dtype, + fill_value=fill_value, + **encoding, ) zarr_array = _put_attrs(zarr_array, encoded_attrs) diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index a746163c3fd..0b469ae26fc 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -105,7 +105,7 @@ def __init__(self, n: int = 1): if not isinstance(n, int): raise TypeError( "The provided multiple 'n' must be an integer. " - "Instead a value of type {!r} was provided.".format(type(n)) + f"Instead a value of type {type(n)!r} was provided." ) self.n = n @@ -353,13 +353,13 @@ def _validate_month(month, default_month): raise TypeError( "'self.month' must be an integer value between 1 " "and 12. Instead, it was set to a value of " - "{!r}".format(result_month) + f"{result_month!r}" ) elif not (1 <= result_month <= 12): raise ValueError( "'self.month' must be an integer value between 1 " "and 12. Instead, it was set to a value of " - "{!r}".format(result_month) + f"{result_month!r}" ) return result_month @@ -771,7 +771,7 @@ def to_cftime_datetime(date_str_or_date, calendar=None): raise TypeError( "date_str_or_date must be a string or a " "subclass of cftime.datetime. Instead got " - "{!r}.".format(date_str_or_date) + f"{date_str_or_date!r}." ) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index c6a7b9f8763..a0800db445a 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -228,12 +228,12 @@ def assert_all_valid_date_type(data): if not isinstance(sample, cftime.datetime): raise TypeError( "CFTimeIndex requires cftime.datetime " - "objects. Got object of {}.".format(date_type) + f"objects. Got object of {date_type}." ) if not all(isinstance(value, date_type) for value in data): raise TypeError( "CFTimeIndex requires using datetime " - "objects of all the same type. Got\n{}.".format(data) + f"objects of all the same type. Got\n{data}." ) @@ -470,13 +470,9 @@ def get_loc(self, key): else: return super().get_loc(key) - def _maybe_cast_slice_bound(self, label, side, kind=None): + def _maybe_cast_slice_bound(self, label, side): """Adapted from pandas.tseries.index.DatetimeIndex._maybe_cast_slice_bound - - Note that we have never used the kind argument in CFTimeIndex and it is - deprecated as of pandas version 1.3.0. It exists only for compatibility - reasons. We can remove it when our minimum version of pandas is 1.3.0. """ if not isinstance(label, str): return label @@ -557,8 +553,7 @@ def shift(self, n: int | float, freq: str | timedelta): return self + n * to_offset(freq) else: raise TypeError( - "'freq' must be of type " - "str or datetime.timedelta, got {}.".format(freq) + "'freq' must be of type " f"str or datetime.timedelta, got {freq}." ) def __add__(self, other): @@ -640,10 +635,10 @@ def to_datetimeindex(self, unsafe=False): if calendar not in _STANDARD_CALENDARS and not unsafe: warnings.warn( "Converting a CFTimeIndex with dates from a non-standard " - "calendar, {!r}, to a pandas.DatetimeIndex, which uses dates " + f"calendar, {calendar!r}, to a pandas.DatetimeIndex, which uses dates " "from the standard calendar. This may lead to subtle errors " "in operations that depend on the length of time between " - "dates.".format(calendar), + "dates.", RuntimeWarning, stacklevel=2, ) diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index d0bfb1a7a63..89ceaddd93b 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -59,9 +59,9 @@ def encode(self, variable, name=None): if contains_unicode and (encode_as_char or not self.allows_unicode): if "_FillValue" in attrs: raise NotImplementedError( - "variable {!r} has a _FillValue specified, but " + f"variable {name!r} has a _FillValue specified, but " "_FillValue is not yet supported on unicode strings: " - "https://github.com/pydata/xarray/issues/1647".format(name) + "https://github.com/pydata/xarray/issues/1647" ) string_encoding = encoding.pop("_Encoding", "utf-8") @@ -100,7 +100,7 @@ def ensure_fixed_length_bytes(var): dims, data, attrs, encoding = unpack_for_encoding(var) if check_vlen_dtype(data.dtype) == bytes: # TODO: figure out how to handle this with dask - data = np.asarray(data, dtype=np.string_) + data = np.asarray(data, dtype=np.bytes_) return Variable(dims, data, attrs, encoding) @@ -151,7 +151,7 @@ def bytes_to_char(arr): def _numpy_bytes_to_char(arr): """Like netCDF4.stringtochar, but faster and more flexible.""" # ensure the array is contiguous - arr = np.array(arr, copy=False, order="C", dtype=np.string_) + arr = np.array(arr, copy=False, order="C", dtype=np.bytes_) return arr.reshape(arr.shape + (1,)).view("S1") @@ -168,7 +168,7 @@ def char_to_bytes(arr): if not size: # can't make an S0 dtype - return np.zeros(arr.shape[:-1], dtype=np.string_) + return np.zeros(arr.shape[:-1], dtype=np.bytes_) if is_chunked_array(arr): chunkmanager = get_chunked_array_type(arr) @@ -176,7 +176,7 @@ def char_to_bytes(arr): if len(arr.chunks[-1]) > 1: raise ValueError( "cannot stacked dask character array with " - "multiple chunks in the last dimension: {}".format(arr) + f"multiple chunks in the last dimension: {arr}" ) dtype = np.dtype("S" + str(arr.shape[-1])) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 3745d61acc0..039fe371100 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -25,6 +25,7 @@ from xarray.core.formatting import first_n_items, format_timestamp, last_item from xarray.core.pdcompat import nanosecond_precision_timestamp from xarray.core.pycompat import is_duck_dask_array +from xarray.core.utils import emit_user_level_warning from xarray.core.variable import Variable try: @@ -122,6 +123,18 @@ def _netcdf_to_numpy_timeunit(units: str) -> str: }[units] +def _numpy_to_netcdf_timeunit(units: str) -> str: + return { + "ns": "nanoseconds", + "us": "microseconds", + "ms": "milliseconds", + "s": "seconds", + "m": "minutes", + "h": "hours", + "D": "days", + }[units] + + def _ensure_padded_year(ref_date: str) -> str: # Reference dates without a padded year (e.g. since 1-1-1 or since 2-3-4) # are ambiguous (is it YMD or DMY?). This can lead to some very odd @@ -171,6 +184,20 @@ def _unpack_netcdf_time_units(units: str) -> tuple[str, str]: return delta_units, ref_date +def _unpack_time_units_and_ref_date(units: str) -> tuple[str, pd.Timestamp]: + # same us _unpack_netcdf_time_units but finalizes ref_date for + # processing in encode_cf_datetime + time_units, _ref_date = _unpack_netcdf_time_units(units) + # TODO: the strict enforcement of nanosecond precision Timestamps can be + # relaxed when addressing GitHub issue #7493. + ref_date = nanosecond_precision_timestamp(_ref_date) + # If the ref_date Timestamp is timezone-aware, convert to UTC and + # make it timezone-naive (GH 2649). + if ref_date.tz is not None: + ref_date = ref_date.tz_convert(None) + return time_units, ref_date + + def _decode_cf_datetime_dtype( data, units: str, calendar: str, use_cftime: bool | None ) -> np.dtype: @@ -218,12 +245,12 @@ def _decode_datetime_with_pandas( ) -> np.ndarray: if not _is_standard_calendar(calendar): raise OutOfBoundsDatetime( - "Cannot decode times from a non-standard calendar, {!r}, using " - "pandas.".format(calendar) + f"Cannot decode times from a non-standard calendar, {calendar!r}, using " + "pandas." ) - delta, ref_date = _unpack_netcdf_time_units(units) - delta = _netcdf_to_numpy_timeunit(delta) + time_units, ref_date = _unpack_netcdf_time_units(units) + time_units = _netcdf_to_numpy_timeunit(time_units) try: # TODO: the strict enforcement of nanosecond precision Timestamps can be # relaxed when addressing GitHub issue #7493. @@ -237,8 +264,8 @@ def _decode_datetime_with_pandas( warnings.filterwarnings("ignore", "invalid value encountered", RuntimeWarning) if flat_num_dates.size > 0: # avoid size 0 datetimes GH1329 - pd.to_timedelta(flat_num_dates.min(), delta) + ref_date - pd.to_timedelta(flat_num_dates.max(), delta) + ref_date + pd.to_timedelta(flat_num_dates.min(), time_units) + ref_date + pd.to_timedelta(flat_num_dates.max(), time_units) + ref_date # To avoid integer overflow when converting to nanosecond units for integer # dtypes smaller than np.int64 cast all integer and unsigned integer dtype @@ -251,9 +278,12 @@ def _decode_datetime_with_pandas( # Cast input ordinals to integers of nanoseconds because pd.to_timedelta # works much faster when dealing with integers (GH 1399). - flat_num_dates_ns_int = (flat_num_dates * _NS_PER_TIME_DELTA[delta]).astype( - np.int64 - ) + # properly handle NaN/NaT to prevent casting NaN to int + nan = np.isnan(flat_num_dates) | (flat_num_dates == np.iinfo(np.int64).min) + flat_num_dates = flat_num_dates * _NS_PER_TIME_DELTA[time_units] + flat_num_dates_ns_int = np.zeros_like(flat_num_dates, dtype=np.int64) + flat_num_dates_ns_int[nan] = np.iinfo(np.int64).min + flat_num_dates_ns_int[~nan] = flat_num_dates[~nan].astype(np.int64) # Use pd.to_timedelta to safely cast integer values to timedeltas, # and add those to a Timestamp to safely produce a DatetimeIndex. This @@ -364,6 +394,10 @@ def _infer_time_units_from_diff(unique_timedeltas) -> str: return "seconds" +def _time_units_to_timedelta64(units: str) -> np.timedelta64: + return np.timedelta64(1, _netcdf_to_numpy_timeunit(units)).astype("timedelta64[ns]") + + def infer_calendar_name(dates) -> CFCalendar: """Given an array of datetimes, infer the CF calendar name""" if is_np_datetime_like(dates.dtype): @@ -452,8 +486,8 @@ def cftime_to_nptime(times, raise_on_invalid: bool = True) -> np.ndarray: except ValueError as e: if raise_on_invalid: raise ValueError( - "Cannot convert date {} to a date in the " - "standard calendar. Reason: {}.".format(t, e) + f"Cannot convert date {t} to a date in the " + f"standard calendar. Reason: {e}." ) else: dt = "NaT" @@ -467,7 +501,7 @@ def convert_times(times, date_type, raise_on_invalid: bool = True) -> np.ndarray Useful to convert between calendars in numpy and cftime or between cftime calendars. If raise_on_valid is True (default), invalid dates trigger a ValueError. - Otherwise, the invalid element is replaced by np.NaN for cftime types and np.NaT for np.datetime64. + Otherwise, the invalid element is replaced by np.nan for cftime types and np.NaT for np.datetime64. """ if date_type in (pd.Timestamp, np.datetime64) and not is_np_datetime_like( times.dtype @@ -485,13 +519,11 @@ def convert_times(times, date_type, raise_on_invalid: bool = True) -> np.ndarray except ValueError as e: if raise_on_invalid: raise ValueError( - "Cannot convert date {} to a date in the " - "{} calendar. Reason: {}.".format( - t, date_type(2000, 1, 1).calendar, e - ) + f"Cannot convert date {t} to a date in the " + f"{date_type(2000, 1, 1).calendar} calendar. Reason: {e}." ) else: - dt = np.NaN + dt = np.nan new[i] = dt return new @@ -574,9 +606,12 @@ def _should_cftime_be_used( def _cleanup_netcdf_time_units(units: str) -> str: - delta, ref_date = _unpack_netcdf_time_units(units) + time_units, ref_date = _unpack_netcdf_time_units(units) + time_units = time_units.lower() + if not time_units.endswith("s"): + time_units = f"{time_units}s" try: - units = f"{delta} since {format_timestamp(ref_date)}" + units = f"{time_units} since {format_timestamp(ref_date)}" except (OutOfBoundsDatetime, ValueError): # don't worry about reifying the units if they're out of bounds or # formatted badly @@ -621,8 +656,22 @@ def cast_to_int_if_safe(num) -> np.ndarray: return num +def _division(deltas, delta, floor): + if floor: + # calculate int64 floor division + # to preserve integer dtype if possible (GH 4045, GH7817). + num = deltas // delta.astype(np.int64) + num = num.astype(np.int64, copy=False) + else: + num = deltas / delta + return num + + def encode_cf_datetime( - dates, units: str | None = None, calendar: str | None = None + dates, + units: str | None = None, + calendar: str | None = None, + dtype: np.dtype | None = None, ) -> tuple[np.ndarray, str, str]: """Given an array of datetime objects, returns the tuple `(num, units, calendar)` suitable for a CF compliant time variable. @@ -635,32 +684,24 @@ def encode_cf_datetime( """ dates = np.asarray(dates) + data_units = infer_datetime_units(dates) + if units is None: - units = infer_datetime_units(dates) + units = data_units else: units = _cleanup_netcdf_time_units(units) if calendar is None: calendar = infer_calendar_name(dates) - delta, _ref_date = _unpack_netcdf_time_units(units) try: if not _is_standard_calendar(calendar) or dates.dtype.kind == "O": # parse with cftime instead raise OutOfBoundsDatetime assert dates.dtype == "datetime64[ns]" - delta_units = _netcdf_to_numpy_timeunit(delta) - time_delta = np.timedelta64(1, delta_units).astype("timedelta64[ns]") - - # TODO: the strict enforcement of nanosecond precision Timestamps can be - # relaxed when addressing GitHub issue #7493. - ref_date = nanosecond_precision_timestamp(_ref_date) - - # If the ref_date Timestamp is timezone-aware, convert to UTC and - # make it timezone-naive (GH 2649). - if ref_date.tz is not None: - ref_date = ref_date.tz_convert(None) + time_units, ref_date = _unpack_time_units_and_ref_date(units) + time_delta = _time_units_to_timedelta64(time_units) # Wrap the dates in a DatetimeIndex to do the subtraction to ensure # an OverflowError is raised if the ref_date is too far away from @@ -668,29 +709,94 @@ def encode_cf_datetime( dates_as_index = pd.DatetimeIndex(dates.ravel()) time_deltas = dates_as_index - ref_date - # Use floor division if time_delta evenly divides all differences - # to preserve integer dtype if possible (GH 4045). - if np.all(time_deltas % time_delta == np.timedelta64(0, "ns")): - num = time_deltas // time_delta - else: - num = time_deltas / time_delta + # retrieve needed units to faithfully encode to int64 + needed_units, data_ref_date = _unpack_time_units_and_ref_date(data_units) + if data_units != units: + # this accounts for differences in the reference times + ref_delta = abs(data_ref_date - ref_date).to_timedelta64() + data_delta = _time_units_to_timedelta64(needed_units) + if (ref_delta % data_delta) > np.timedelta64(0, "ns"): + needed_units = _infer_time_units_from_diff(ref_delta) + + # needed time delta to encode faithfully to int64 + needed_time_delta = _time_units_to_timedelta64(needed_units) + + floor_division = True + if time_delta > needed_time_delta: + floor_division = False + if dtype is None: + emit_user_level_warning( + f"Times can't be serialized faithfully to int64 with requested units {units!r}. " + f"Resolution of {needed_units!r} needed. Serializing times to floating point instead. " + f"Set encoding['dtype'] to integer dtype to serialize to int64. " + f"Set encoding['dtype'] to floating point dtype to silence this warning." + ) + elif np.issubdtype(dtype, np.integer): + new_units = f"{needed_units} since {format_timestamp(ref_date)}" + emit_user_level_warning( + f"Times can't be serialized faithfully to int64 with requested units {units!r}. " + f"Serializing with units {new_units!r} instead. " + f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. " + f"Set encoding['units'] to {new_units!r} to silence this warning ." + ) + units = new_units + time_delta = needed_time_delta + floor_division = True + + num = _division(time_deltas, time_delta, floor_division) num = num.values.reshape(dates.shape) except (OutOfBoundsDatetime, OverflowError, ValueError): num = _encode_datetime_with_cftime(dates, units, calendar) + # do it now only for cftime-based flow + # we already covered for this in pandas-based flow + num = cast_to_int_if_safe(num) - num = cast_to_int_if_safe(num) return (num, units, calendar) -def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarray, str]: +def encode_cf_timedelta( + timedeltas, units: str | None = None, dtype: np.dtype | None = None +) -> tuple[np.ndarray, str]: + data_units = infer_timedelta_units(timedeltas) + if units is None: - units = infer_timedelta_units(timedeltas) + units = data_units + + time_delta = _time_units_to_timedelta64(units) + time_deltas = pd.TimedeltaIndex(timedeltas.ravel()) + + # retrieve needed units to faithfully encode to int64 + needed_units = data_units + if data_units != units: + needed_units = _infer_time_units_from_diff(np.unique(time_deltas.dropna())) + + # needed time delta to encode faithfully to int64 + needed_time_delta = _time_units_to_timedelta64(needed_units) + + floor_division = True + if time_delta > needed_time_delta: + floor_division = False + if dtype is None: + emit_user_level_warning( + f"Timedeltas can't be serialized faithfully to int64 with requested units {units!r}. " + f"Resolution of {needed_units!r} needed. Serializing timeseries to floating point instead. " + f"Set encoding['dtype'] to integer dtype to serialize to int64. " + f"Set encoding['dtype'] to floating point dtype to silence this warning." + ) + elif np.issubdtype(dtype, np.integer): + emit_user_level_warning( + f"Timedeltas can't be serialized faithfully with requested units {units!r}. " + f"Serializing with units {needed_units!r} instead. " + f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. " + f"Set encoding['units'] to {needed_units!r} to silence this warning ." + ) + units = needed_units + time_delta = needed_time_delta + floor_division = True - np_unit = _netcdf_to_numpy_timeunit(units) - num = 1.0 * timedeltas / np.timedelta64(1, np_unit) - num = np.where(pd.isnull(timedeltas), np.nan, num) - num = cast_to_int_if_safe(num) + num = _division(time_deltas, time_delta, floor_division) + num = num.values.reshape(timedeltas.shape) return (num, units) @@ -704,9 +810,11 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: ) or contains_cftime_datetimes(variable): dims, data, attrs, encoding = unpack_for_encoding(variable) - (data, units, calendar) = encode_cf_datetime( - data, encoding.pop("units", None), encoding.pop("calendar", None) - ) + units = encoding.pop("units", None) + calendar = encoding.pop("calendar", None) + dtype = encoding.get("dtype", None) + (data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype) + safe_setitem(attrs, "units", units, name=name) safe_setitem(attrs, "calendar", calendar, name=name) @@ -740,7 +848,9 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: if np.issubdtype(variable.data.dtype, np.timedelta64): dims, data, attrs, encoding = unpack_for_encoding(variable) - data, units = encode_cf_timedelta(data, encoding.pop("units", None)) + data, units = encode_cf_timedelta( + data, encoding.pop("units", None), encoding.get("dtype", None) + ) safe_setitem(attrs, "units", units, name=name) return Variable(dims, data, attrs, encoding, fastpath=True) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 8ba7dcbb0e2..df660f90d9e 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -179,10 +179,10 @@ def safe_setitem(dest, key: Hashable, value, name: T_Name = None): if key in dest: var_str = f" on variable {name!r}" if name else "" raise ValueError( - "failed to prevent overwriting existing key {} in attrs{}. " + f"failed to prevent overwriting existing key {key} in attrs{var_str}. " "This is probably an encoding field used by xarray to describe " "how a variable is serialized. To proceed, remove this key from " - "the variable's attributes manually.".format(key, var_str) + "the variable's attributes manually." ) dest[key] = value @@ -215,6 +215,34 @@ def _apply_mask( return np.where(condition, decoded_fill_value, data) +def _is_time_like(units): + # test for time-like + if units is None: + return False + time_strings = [ + "days", + "hours", + "minutes", + "seconds", + "milliseconds", + "microseconds", + "nanoseconds", + ] + units = str(units) + # to prevent detecting units like `days accumulated` as time-like + # special casing for datetime-units and timedelta-units (GH-8269) + if "since" in units: + from xarray.coding.times import _unpack_netcdf_time_units + + try: + _unpack_netcdf_time_units(units) + except ValueError: + return False + return True + else: + return any(tstr == units for tstr in time_strings) + + class CFMaskCoder(VariableCoder): """Mask or unmask fill values according to CF conventions.""" @@ -236,19 +264,32 @@ def encode(self, variable: Variable, name: T_Name = None): f"Variable {name!r} has conflicting _FillValue ({fv}) and missing_value ({mv}). Cannot encode data." ) + # special case DateTime to properly handle NaT + is_time_like = _is_time_like(attrs.get("units")) + if fv_exists: # Ensure _FillValue is cast to same dtype as data's encoding["_FillValue"] = dtype.type(fv) fill_value = pop_to(encoding, attrs, "_FillValue", name=name) if not pd.isnull(fill_value): - data = duck_array_ops.fillna(data, fill_value) + if is_time_like and data.dtype.kind in "iu": + data = duck_array_ops.where( + data != np.iinfo(np.int64).min, data, fill_value + ) + else: + data = duck_array_ops.fillna(data, fill_value) if mv_exists: # Ensure missing_value is cast to same dtype as data's encoding["missing_value"] = dtype.type(mv) fill_value = pop_to(encoding, attrs, "missing_value", name=name) if not pd.isnull(fill_value) and not fv_exists: - data = duck_array_ops.fillna(data, fill_value) + if is_time_like and data.dtype.kind in "iu": + data = duck_array_ops.where( + data != np.iinfo(np.int64).min, data, fill_value + ) + else: + data = duck_array_ops.fillna(data, fill_value) return Variable(dims, data, attrs, encoding, fastpath=True) @@ -275,7 +316,13 @@ def decode(self, variable: Variable, name: T_Name = None): stacklevel=3, ) - dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype) + # special case DateTime to properly handle NaT + dtype: np.typing.DTypeLike + decoded_fill_value: Any + if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu": + dtype, decoded_fill_value = np.int64, np.iinfo(np.int64).min + else: + dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype) if encoded_fill_values: transform = partial( diff --git a/xarray/conventions.py b/xarray/conventions.py index 053863ace2a..cf207f0c37a 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -1,9 +1,8 @@ from __future__ import annotations -import warnings from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping, MutableMapping -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Literal, Union import numpy as np import pandas as pd @@ -16,6 +15,7 @@ contains_cftime_datetimes, ) from xarray.core.pycompat import is_duck_dask_array +from xarray.core.utils import emit_user_level_warning from xarray.core.variable import IndexVariable, Variable CF_RELATED_DATA = ( @@ -75,20 +75,18 @@ def _infer_dtype(array, name: T_Name = None) -> np.dtype: return dtype raise ValueError( - "unable to infer dtype on variable {!r}; xarray " - "cannot serialize arbitrary Python objects".format(name) + f"unable to infer dtype on variable {name!r}; xarray " + "cannot serialize arbitrary Python objects" ) def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None: if isinstance(var, IndexVariable) and isinstance(var.to_index(), pd.MultiIndex): raise NotImplementedError( - "variable {!r} is a MultiIndex, which cannot yet be " + f"variable {name!r} is a MultiIndex, which cannot yet be " "serialized to netCDF files. Instead, either use reset_index() " "to convert MultiIndex levels into coordinate variables instead " - "or use https://cf-xarray.readthedocs.io/en/latest/coding.html.".format( - name - ) + "or use https://cf-xarray.readthedocs.io/en/latest/coding.html." ) @@ -113,13 +111,13 @@ def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: return var if is_duck_dask_array(data): - warnings.warn( - "variable {} has data in the form of a dask array with " + emit_user_level_warning( + f"variable {name} has data in the form of a dask array with " "dtype=object, which means it is being loaded into memory " "to determine a data type that can be safely stored on disk. " "To avoid this, coerce this variable to a fixed-size dtype " - "with astype() before saving it.".format(name), - SerializationWarning, + "with astype() before saving it.", + category=SerializationWarning, ) data = data.compute() @@ -356,15 +354,14 @@ def _update_bounds_encoding(variables: T_Variables) -> None: and "bounds" in attrs and attrs["bounds"] in variables ): - warnings.warn( - "Variable '{0}' has datetime type and a " - "bounds variable but {0}.encoding does not have " - "units specified. The units encodings for '{0}' " - "and '{1}' will be determined independently " + emit_user_level_warning( + f"Variable {name:s} has datetime type and a " + f"bounds variable but {name:s}.encoding does not have " + f"units specified. The units encodings for {name:s} " + f"and {attrs['bounds']} will be determined independently " "and may not be equal, counter to CF-conventions. " "If this is a concern, specify a units encoding for " - "'{0}' before writing to a file.".format(name, attrs["bounds"]), - UserWarning, + f"{name:s} before writing to a file.", ) if has_date_units and "bounds" in attrs: @@ -381,7 +378,7 @@ def decode_cf_variables( concat_characters: bool = True, mask_and_scale: bool = True, decode_times: bool = True, - decode_coords: bool = True, + decode_coords: bool | Literal["coordinates", "all"] = True, drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, @@ -443,11 +440,14 @@ def stackable(dim: Hashable) -> bool: if decode_coords in [True, "coordinates", "all"]: var_attrs = new_vars[k].attrs if "coordinates" in var_attrs: - coord_str = var_attrs["coordinates"] - var_coord_names = coord_str.split() - if all(k in variables for k in var_coord_names): - new_vars[k].encoding["coordinates"] = coord_str - del var_attrs["coordinates"] + var_coord_names = [ + c for c in var_attrs["coordinates"].split() if c in variables + ] + # propagate as is + new_vars[k].encoding["coordinates"] = var_attrs["coordinates"] + del var_attrs["coordinates"] + # but only use as coordinate if existing + if var_coord_names: coord_names.update(var_coord_names) if decode_coords == "all": @@ -463,8 +463,8 @@ def stackable(dim: Hashable) -> bool: for role_or_name in part.split() ] if len(roles_and_names) % 2 == 1: - warnings.warn( - f"Attribute {attr_name:s} malformed", stacklevel=5 + emit_user_level_warning( + f"Attribute {attr_name:s} malformed" ) var_names = roles_and_names[1::2] if all(var_name in variables for var_name in var_names): @@ -476,9 +476,8 @@ def stackable(dim: Hashable) -> bool: for proj_name in var_names if proj_name not in variables ] - warnings.warn( + emit_user_level_warning( f"Variable(s) referenced in {attr_name:s} not in variables: {referenced_vars_not_in_variables!s}", - stacklevel=5, ) del var_attrs[attr_name] @@ -495,7 +494,7 @@ def decode_cf( concat_characters: bool = True, mask_and_scale: bool = True, decode_times: bool = True, - decode_coords: bool = True, + decode_coords: bool | Literal["coordinates", "all"] = True, drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, @@ -634,12 +633,11 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): for name in list(non_dim_coord_names): if isinstance(name, str) and " " in name: - warnings.warn( - "coordinate {!r} has a space in its name, which means it " + emit_user_level_warning( + f"coordinate {name!r} has a space in its name, which means it " "cannot be marked as a coordinate on disk and will be " - "saved as a data variable instead".format(name), - SerializationWarning, - stacklevel=6, + "saved as a data variable instead", + category=SerializationWarning, ) non_dim_coord_names.discard(name) @@ -694,7 +692,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): if not coords_str and variable_coordinates[name]: coordinates_text = " ".join( str(coord_name) - for coord_name in variable_coordinates[name] + for coord_name in sorted(variable_coordinates[name]) if coord_name not in not_technically_coordinates ) if coordinates_text: @@ -712,14 +710,14 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): if global_coordinates: attributes = dict(attributes) if "coordinates" in attributes: - warnings.warn( + emit_user_level_warning( f"cannot serialize global coordinates {global_coordinates!r} because the global " f"attribute 'coordinates' already exists. This may prevent faithful roundtripping" f"of xarray datasets", - SerializationWarning, + category=SerializationWarning, ) else: - attributes["coordinates"] = " ".join(map(str, global_coordinates)) + attributes["coordinates"] = " ".join(sorted(map(str, global_coordinates))) return variables, attributes diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index 4f6dce1a04a..89cec94e24f 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -8,7 +8,7 @@ from xarray.core import duck_array_ops from xarray.core.options import OPTIONS -from xarray.core.types import Dims +from xarray.core.types import Dims, Self from xarray.core.utils import contains_only_chunked_or_numpy, module_available if TYPE_CHECKING: @@ -30,7 +30,7 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> Dataset: + ) -> Self: raise NotImplementedError() def count( @@ -39,7 +39,7 @@ def count( *, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``count`` along some dimension(s). @@ -65,8 +65,8 @@ def count( See Also -------- - numpy.count - dask.array.count + pandas.DataFrame.count + dask.dataframe.DataFrame.count DataArray.count :ref:`agg` User guide on reduction or aggregation operations. @@ -74,7 +74,7 @@ def count( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -89,7 +89,7 @@ def count( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.count() @@ -111,7 +111,7 @@ def all( *, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``all`` along some dimension(s). @@ -183,7 +183,7 @@ def any( *, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``any`` along some dimension(s). @@ -256,7 +256,7 @@ def max( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``max`` along some dimension(s). @@ -296,7 +296,7 @@ def max( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -311,7 +311,7 @@ def max( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.max() @@ -343,7 +343,7 @@ def min( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``min`` along some dimension(s). @@ -383,7 +383,7 @@ def min( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -398,13 +398,13 @@ def min( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.min() Dimensions: () Data variables: - da float64 1.0 + da float64 0.0 Use ``skipna`` to control whether NaNs are ignored. @@ -430,7 +430,7 @@ def mean( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``mean`` along some dimension(s). @@ -474,7 +474,7 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -489,13 +489,13 @@ def mean( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.mean() Dimensions: () Data variables: - da float64 1.8 + da float64 1.6 Use ``skipna`` to control whether NaNs are ignored. @@ -522,7 +522,7 @@ def prod( min_count: int | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``prod`` along some dimension(s). @@ -572,7 +572,7 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -587,13 +587,13 @@ def prod( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.prod() Dimensions: () Data variables: - da float64 12.0 + da float64 0.0 Use ``skipna`` to control whether NaNs are ignored. @@ -609,7 +609,7 @@ def prod( Dimensions: () Data variables: - da float64 12.0 + da float64 0.0 """ return self.reduce( duck_array_ops.prod, @@ -629,7 +629,7 @@ def sum( min_count: int | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``sum`` along some dimension(s). @@ -679,7 +679,7 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -694,13 +694,13 @@ def sum( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.sum() Dimensions: () Data variables: - da float64 9.0 + da float64 8.0 Use ``skipna`` to control whether NaNs are ignored. @@ -716,7 +716,7 @@ def sum( Dimensions: () Data variables: - da float64 9.0 + da float64 8.0 """ return self.reduce( duck_array_ops.sum, @@ -736,7 +736,7 @@ def std( ddof: int = 0, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``std`` along some dimension(s). @@ -783,7 +783,7 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -798,13 +798,13 @@ def std( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.std() Dimensions: () Data variables: - da float64 0.7483 + da float64 1.02 Use ``skipna`` to control whether NaNs are ignored. @@ -820,7 +820,7 @@ def std( Dimensions: () Data variables: - da float64 0.8367 + da float64 1.14 """ return self.reduce( duck_array_ops.std, @@ -840,7 +840,7 @@ def var( ddof: int = 0, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``var`` along some dimension(s). @@ -887,7 +887,7 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -902,13 +902,13 @@ def var( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.var() Dimensions: () Data variables: - da float64 0.56 + da float64 1.04 Use ``skipna`` to control whether NaNs are ignored. @@ -924,7 +924,7 @@ def var( Dimensions: () Data variables: - da float64 0.7 + da float64 1.3 """ return self.reduce( duck_array_ops.var, @@ -943,7 +943,7 @@ def median( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``median`` along some dimension(s). @@ -987,7 +987,7 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -1002,7 +1002,7 @@ def median( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.median() @@ -1034,7 +1034,7 @@ def cumsum( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``cumsum`` along some dimension(s). @@ -1078,7 +1078,7 @@ def cumsum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -1093,14 +1093,14 @@ def cumsum( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.cumsum() Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 3.0 6.0 7.0 9.0 9.0 + da (time) float64 1.0 3.0 6.0 6.0 8.0 8.0 Use ``skipna`` to control whether NaNs are ignored. @@ -1109,7 +1109,7 @@ def cumsum( Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 3.0 6.0 7.0 9.0 nan + da (time) float64 1.0 3.0 6.0 6.0 8.0 nan """ return self.reduce( duck_array_ops.cumsum, @@ -1127,7 +1127,7 @@ def cumprod( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``cumprod`` along some dimension(s). @@ -1171,7 +1171,7 @@ def cumprod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -1186,14 +1186,14 @@ def cumprod( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.cumprod() Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 6.0 6.0 12.0 12.0 + da (time) float64 1.0 2.0 6.0 0.0 0.0 0.0 Use ``skipna`` to control whether NaNs are ignored. @@ -1202,7 +1202,7 @@ def cumprod( Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 6.0 6.0 12.0 nan + da (time) float64 1.0 2.0 6.0 0.0 0.0 nan """ return self.reduce( duck_array_ops.cumprod, @@ -1226,7 +1226,7 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> DataArray: + ) -> Self: raise NotImplementedError() def count( @@ -1235,7 +1235,7 @@ def count( *, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``count`` along some dimension(s). @@ -1261,8 +1261,8 @@ def count( See Also -------- - numpy.count - dask.array.count + pandas.DataFrame.count + dask.dataframe.DataFrame.count Dataset.count :ref:`agg` User guide on reduction or aggregation operations. @@ -1270,7 +1270,7 @@ def count( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -1279,7 +1279,7 @@ def count( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``all`` along some dimension(s). @@ -1367,7 +1367,7 @@ def any( *, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``any`` along some dimension(s). @@ -1434,7 +1434,7 @@ def max( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``max`` along some dimension(s). @@ -1474,7 +1474,7 @@ def max( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -1483,7 +1483,7 @@ def max( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``min`` along some dimension(s). @@ -1553,7 +1553,7 @@ def min( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -1562,14 +1562,14 @@ def min( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.min() - array(1.) + array(0.) Use ``skipna`` to control whether NaNs are ignored. @@ -1592,7 +1592,7 @@ def mean( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``mean`` along some dimension(s). @@ -1636,7 +1636,7 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -1645,14 +1645,14 @@ def mean( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.mean() - array(1.8) + array(1.6) Use ``skipna`` to control whether NaNs are ignored. @@ -1676,7 +1676,7 @@ def prod( min_count: int | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``prod`` along some dimension(s). @@ -1726,7 +1726,7 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -1735,14 +1735,14 @@ def prod( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.prod() - array(12.) + array(0.) Use ``skipna`` to control whether NaNs are ignored. @@ -1754,7 +1754,7 @@ def prod( >>> da.prod(skipna=True, min_count=2) - array(12.) + array(0.) """ return self.reduce( duck_array_ops.prod, @@ -1773,7 +1773,7 @@ def sum( min_count: int | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``sum`` along some dimension(s). @@ -1823,7 +1823,7 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -1832,14 +1832,14 @@ def sum( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.sum() - array(9.) + array(8.) Use ``skipna`` to control whether NaNs are ignored. @@ -1851,7 +1851,7 @@ def sum( >>> da.sum(skipna=True, min_count=2) - array(9.) + array(8.) """ return self.reduce( duck_array_ops.sum, @@ -1870,7 +1870,7 @@ def std( ddof: int = 0, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``std`` along some dimension(s). @@ -1917,7 +1917,7 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -1926,14 +1926,14 @@ def std( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.std() - array(0.74833148) + array(1.0198039) Use ``skipna`` to control whether NaNs are ignored. @@ -1945,7 +1945,7 @@ def std( >>> da.std(skipna=True, ddof=1) - array(0.83666003) + array(1.14017543) """ return self.reduce( duck_array_ops.std, @@ -1964,7 +1964,7 @@ def var( ddof: int = 0, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``var`` along some dimension(s). @@ -2011,7 +2011,7 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -2020,14 +2020,14 @@ def var( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.var() - array(0.56) + array(1.04) Use ``skipna`` to control whether NaNs are ignored. @@ -2039,7 +2039,7 @@ def var( >>> da.var(skipna=True, ddof=1) - array(0.7) + array(1.3) """ return self.reduce( duck_array_ops.var, @@ -2057,7 +2057,7 @@ def median( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``median`` along some dimension(s). @@ -2101,7 +2101,7 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -2110,7 +2110,7 @@ def median( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``cumsum`` along some dimension(s). @@ -2184,7 +2184,7 @@ def cumsum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -2193,14 +2193,14 @@ def cumsum( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.cumsum() - array([1., 3., 6., 7., 9., 9.]) + array([1., 3., 6., 6., 8., 8.]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.cumsum(skipna=False) - array([ 1., 3., 6., 7., 9., nan]) + array([ 1., 3., 6., 6., 8., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``cumprod`` along some dimension(s). @@ -2273,7 +2273,7 @@ def cumprod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -2282,14 +2282,14 @@ def cumprod( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.cumprod() - array([ 1., 2., 6., 6., 12., 12.]) + array([1., 2., 6., 0., 0., 0.]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.cumprod(skipna=False) - array([ 1., 2., 6., 6., 12., nan]) + array([ 1., 2., 6., 0., 0., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -2400,7 +2400,7 @@ def count( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.groupby("labels").count() @@ -2685,7 +2685,7 @@ def max( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -2700,7 +2700,7 @@ def max( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.groupby("labels").max() @@ -2801,7 +2801,7 @@ def min( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -2816,7 +2816,7 @@ def min( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.groupby("labels").min() @@ -2824,7 +2824,7 @@ def min( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 1.0 2.0 1.0 + da (labels) float64 1.0 2.0 0.0 Use ``skipna`` to control whether NaNs are ignored. @@ -2834,7 +2834,7 @@ def min( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 nan 2.0 1.0 + da (labels) float64 nan 2.0 0.0 """ if ( flox_available @@ -2919,7 +2919,7 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -2934,7 +2934,7 @@ def mean( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.groupby("labels").mean() @@ -2942,7 +2942,7 @@ def mean( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 1.0 2.0 2.0 + da (labels) float64 1.0 2.0 1.5 Use ``skipna`` to control whether NaNs are ignored. @@ -2952,7 +2952,7 @@ def mean( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 nan 2.0 2.0 + da (labels) float64 nan 2.0 1.5 """ if ( flox_available @@ -3044,7 +3044,7 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -3059,7 +3059,7 @@ def prod( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.groupby("labels").prod() @@ -3067,7 +3067,7 @@ def prod( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 1.0 4.0 3.0 + da (labels) float64 1.0 4.0 0.0 Use ``skipna`` to control whether NaNs are ignored. @@ -3077,7 +3077,7 @@ def prod( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 nan 4.0 3.0 + da (labels) float64 nan 4.0 0.0 Specify ``min_count`` for finer control over when NaNs are ignored. @@ -3087,7 +3087,7 @@ def prod( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 nan 4.0 3.0 + da (labels) float64 nan 4.0 0.0 """ if ( flox_available @@ -3181,7 +3181,7 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -3196,7 +3196,7 @@ def sum( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.groupby("labels").sum() @@ -3204,7 +3204,7 @@ def sum( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 1.0 4.0 4.0 + da (labels) float64 1.0 4.0 3.0 Use ``skipna`` to control whether NaNs are ignored. @@ -3214,7 +3214,7 @@ def sum( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 nan 4.0 4.0 + da (labels) float64 nan 4.0 3.0 Specify ``min_count`` for finer control over when NaNs are ignored. @@ -3224,7 +3224,7 @@ def sum( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 nan 4.0 4.0 + da (labels) float64 nan 4.0 3.0 """ if ( flox_available @@ -3315,7 +3315,7 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -3330,7 +3330,7 @@ def std( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.groupby("labels").std() @@ -3338,7 +3338,7 @@ def std( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 0.0 0.0 1.0 + da (labels) float64 0.0 0.0 1.5 Use ``skipna`` to control whether NaNs are ignored. @@ -3348,7 +3348,7 @@ def std( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 nan 0.0 1.0 + da (labels) float64 nan 0.0 1.5 Specify ``ddof=1`` for an unbiased estimate. @@ -3358,7 +3358,7 @@ def std( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 nan 0.0 1.414 + da (labels) float64 nan 0.0 2.121 """ if ( flox_available @@ -3449,7 +3449,7 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -3464,7 +3464,7 @@ def var( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.groupby("labels").var() @@ -3472,7 +3472,7 @@ def var( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 0.0 0.0 1.0 + da (labels) float64 0.0 0.0 2.25 Use ``skipna`` to control whether NaNs are ignored. @@ -3482,7 +3482,7 @@ def var( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 nan 0.0 1.0 + da (labels) float64 nan 0.0 2.25 Specify ``ddof=1`` for an unbiased estimate. @@ -3492,7 +3492,7 @@ def var( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 nan 0.0 2.0 + da (labels) float64 nan 0.0 4.5 """ if ( flox_available @@ -3579,7 +3579,7 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -3594,7 +3594,7 @@ def median( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.groupby("labels").median() @@ -3602,7 +3602,7 @@ def median( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 1.0 2.0 2.0 + da (labels) float64 1.0 2.0 1.5 Use ``skipna`` to control whether NaNs are ignored. @@ -3612,7 +3612,7 @@ def median( Coordinates: * labels (labels) object 'a' 'b' 'c' Data variables: - da (labels) float64 nan 2.0 2.0 + da (labels) float64 nan 2.0 1.5 """ return self.reduce( duck_array_ops.median, @@ -3682,7 +3682,7 @@ def cumsum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -3697,14 +3697,14 @@ def cumsum( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.groupby("labels").cumsum() Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 3.0 4.0 4.0 1.0 + da (time) float64 1.0 2.0 3.0 3.0 4.0 1.0 Use ``skipna`` to control whether NaNs are ignored. @@ -3713,7 +3713,7 @@ def cumsum( Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 3.0 4.0 4.0 nan + da (time) float64 1.0 2.0 3.0 3.0 4.0 nan """ return self.reduce( duck_array_ops.cumsum, @@ -3783,7 +3783,7 @@ def cumprod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -3798,14 +3798,14 @@ def cumprod( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.groupby("labels").cumprod() Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 3.0 3.0 4.0 1.0 + da (time) float64 1.0 2.0 3.0 0.0 4.0 1.0 Use ``skipna`` to control whether NaNs are ignored. @@ -3814,7 +3814,7 @@ def cumprod( Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 3.0 3.0 4.0 nan + da (time) float64 1.0 2.0 3.0 0.0 4.0 nan """ return self.reduce( duck_array_ops.cumprod, @@ -3881,8 +3881,8 @@ def count( See Also -------- - numpy.count - dask.array.count + pandas.DataFrame.count + dask.dataframe.DataFrame.count Dataset.count :ref:`resampling` User guide on resampling operations. @@ -3899,7 +3899,7 @@ def count( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -3914,7 +3914,7 @@ def count( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.resample(time="3M").count() @@ -4199,7 +4199,7 @@ def max( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -4214,7 +4214,7 @@ def max( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.resample(time="3M").max() @@ -4315,7 +4315,7 @@ def min( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -4330,7 +4330,7 @@ def min( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.resample(time="3M").min() @@ -4338,7 +4338,7 @@ def min( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 1.0 2.0 + da (time) float64 1.0 0.0 2.0 Use ``skipna`` to control whether NaNs are ignored. @@ -4348,7 +4348,7 @@ def min( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 1.0 nan + da (time) float64 1.0 0.0 nan """ if ( flox_available @@ -4433,7 +4433,7 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -4448,7 +4448,7 @@ def mean( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.resample(time="3M").mean() @@ -4456,7 +4456,7 @@ def mean( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 2.0 2.0 + da (time) float64 1.0 1.667 2.0 Use ``skipna`` to control whether NaNs are ignored. @@ -4466,7 +4466,7 @@ def mean( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 2.0 nan + da (time) float64 1.0 1.667 nan """ if ( flox_available @@ -4558,7 +4558,7 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -4573,7 +4573,7 @@ def prod( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.resample(time="3M").prod() @@ -4581,7 +4581,7 @@ def prod( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 6.0 2.0 + da (time) float64 1.0 0.0 2.0 Use ``skipna`` to control whether NaNs are ignored. @@ -4591,7 +4591,7 @@ def prod( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 6.0 nan + da (time) float64 1.0 0.0 nan Specify ``min_count`` for finer control over when NaNs are ignored. @@ -4601,7 +4601,7 @@ def prod( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 nan 6.0 nan + da (time) float64 nan 0.0 nan """ if ( flox_available @@ -4695,7 +4695,7 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -4710,7 +4710,7 @@ def sum( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.resample(time="3M").sum() @@ -4718,7 +4718,7 @@ def sum( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 6.0 2.0 + da (time) float64 1.0 5.0 2.0 Use ``skipna`` to control whether NaNs are ignored. @@ -4728,7 +4728,7 @@ def sum( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 6.0 nan + da (time) float64 1.0 5.0 nan Specify ``min_count`` for finer control over when NaNs are ignored. @@ -4738,7 +4738,7 @@ def sum( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 nan 6.0 nan + da (time) float64 nan 5.0 nan """ if ( flox_available @@ -4829,7 +4829,7 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -4844,7 +4844,7 @@ def std( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.resample(time="3M").std() @@ -4852,7 +4852,7 @@ def std( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 0.0 0.8165 0.0 + da (time) float64 0.0 1.247 0.0 Use ``skipna`` to control whether NaNs are ignored. @@ -4862,7 +4862,7 @@ def std( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 0.0 0.8165 nan + da (time) float64 0.0 1.247 nan Specify ``ddof=1`` for an unbiased estimate. @@ -4872,7 +4872,7 @@ def std( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 nan 1.0 nan + da (time) float64 nan 1.528 nan """ if ( flox_available @@ -4963,7 +4963,7 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -4978,7 +4978,7 @@ def var( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.resample(time="3M").var() @@ -4986,7 +4986,7 @@ def var( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 0.0 0.6667 0.0 + da (time) float64 0.0 1.556 0.0 Use ``skipna`` to control whether NaNs are ignored. @@ -4996,7 +4996,7 @@ def var( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 0.0 0.6667 nan + da (time) float64 0.0 1.556 nan Specify ``ddof=1`` for an unbiased estimate. @@ -5006,7 +5006,7 @@ def var( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 nan 1.0 nan + da (time) float64 nan 2.333 nan """ if ( flox_available @@ -5093,7 +5093,7 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -5108,7 +5108,7 @@ def median( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.resample(time="3M").median() @@ -5196,7 +5196,7 @@ def cumsum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -5211,14 +5211,14 @@ def cumsum( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.resample(time="3M").cumsum() Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 5.0 6.0 2.0 2.0 + da (time) float64 1.0 2.0 5.0 5.0 2.0 2.0 Use ``skipna`` to control whether NaNs are ignored. @@ -5227,7 +5227,7 @@ def cumsum( Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 5.0 6.0 2.0 nan + da (time) float64 1.0 2.0 5.0 5.0 2.0 nan """ return self.reduce( duck_array_ops.cumsum, @@ -5297,7 +5297,7 @@ def cumprod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -5312,14 +5312,14 @@ def cumprod( * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> ds.resample(time="3M").cumprod() Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 6.0 6.0 2.0 2.0 + da (time) float64 1.0 2.0 6.0 0.0 2.0 2.0 Use ``skipna`` to control whether NaNs are ignored. @@ -5328,7 +5328,7 @@ def cumprod( Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 6.0 6.0 2.0 nan + da (time) float64 1.0 2.0 6.0 0.0 2.0 nan """ return self.reduce( duck_array_ops.cumprod, @@ -5395,8 +5395,8 @@ def count( See Also -------- - numpy.count - dask.array.count + pandas.DataFrame.count + dask.dataframe.DataFrame.count DataArray.count :ref:`groupby` User guide on groupby operations. @@ -5413,7 +5413,7 @@ def count( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -5422,7 +5422,7 @@ def count( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -5701,7 +5701,7 @@ def max( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -5808,14 +5808,14 @@ def min( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.groupby("labels").min() - array([1., 2., 1.]) + array([1., 2., 0.]) Coordinates: * labels (labels) object 'a' 'b' 'c' @@ -5823,7 +5823,7 @@ def min( >>> da.groupby("labels").min(skipna=False) - array([nan, 2., 1.]) + array([nan, 2., 0.]) Coordinates: * labels (labels) object 'a' 'b' 'c' """ @@ -5908,7 +5908,7 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -5917,14 +5917,14 @@ def mean( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.groupby("labels").mean() - array([1., 2., 2.]) + array([1. , 2. , 1.5]) Coordinates: * labels (labels) object 'a' 'b' 'c' @@ -5932,7 +5932,7 @@ def mean( >>> da.groupby("labels").mean(skipna=False) - array([nan, 2., 2.]) + array([nan, 2. , 1.5]) Coordinates: * labels (labels) object 'a' 'b' 'c' """ @@ -6024,7 +6024,7 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -6033,14 +6033,14 @@ def prod( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.groupby("labels").prod() - array([1., 4., 3.]) + array([1., 4., 0.]) Coordinates: * labels (labels) object 'a' 'b' 'c' @@ -6048,7 +6048,7 @@ def prod( >>> da.groupby("labels").prod(skipna=False) - array([nan, 4., 3.]) + array([nan, 4., 0.]) Coordinates: * labels (labels) object 'a' 'b' 'c' @@ -6056,7 +6056,7 @@ def prod( >>> da.groupby("labels").prod(skipna=True, min_count=2) - array([nan, 4., 3.]) + array([nan, 4., 0.]) Coordinates: * labels (labels) object 'a' 'b' 'c' """ @@ -6150,7 +6150,7 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -6159,14 +6159,14 @@ def sum( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.groupby("labels").sum() - array([1., 4., 4.]) + array([1., 4., 3.]) Coordinates: * labels (labels) object 'a' 'b' 'c' @@ -6174,7 +6174,7 @@ def sum( >>> da.groupby("labels").sum(skipna=False) - array([nan, 4., 4.]) + array([nan, 4., 3.]) Coordinates: * labels (labels) object 'a' 'b' 'c' @@ -6182,7 +6182,7 @@ def sum( >>> da.groupby("labels").sum(skipna=True, min_count=2) - array([nan, 4., 4.]) + array([nan, 4., 3.]) Coordinates: * labels (labels) object 'a' 'b' 'c' """ @@ -6273,7 +6273,7 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -6282,14 +6282,14 @@ def std( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.groupby("labels").std() - array([0., 0., 1.]) + array([0. , 0. , 1.5]) Coordinates: * labels (labels) object 'a' 'b' 'c' @@ -6297,7 +6297,7 @@ def std( >>> da.groupby("labels").std(skipna=False) - array([nan, 0., 1.]) + array([nan, 0. , 1.5]) Coordinates: * labels (labels) object 'a' 'b' 'c' @@ -6305,7 +6305,7 @@ def std( >>> da.groupby("labels").std(skipna=True, ddof=1) - array([ nan, 0. , 1.41421356]) + array([ nan, 0. , 2.12132034]) Coordinates: * labels (labels) object 'a' 'b' 'c' """ @@ -6396,7 +6396,7 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -6405,14 +6405,14 @@ def var( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.groupby("labels").var() - array([0., 0., 1.]) + array([0. , 0. , 2.25]) Coordinates: * labels (labels) object 'a' 'b' 'c' @@ -6420,7 +6420,7 @@ def var( >>> da.groupby("labels").var(skipna=False) - array([nan, 0., 1.]) + array([ nan, 0. , 2.25]) Coordinates: * labels (labels) object 'a' 'b' 'c' @@ -6428,7 +6428,7 @@ def var( >>> da.groupby("labels").var(skipna=True, ddof=1) - array([nan, 0., 2.]) + array([nan, 0. , 4.5]) Coordinates: * labels (labels) object 'a' 'b' 'c' """ @@ -6515,7 +6515,7 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -6524,14 +6524,14 @@ def median( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.groupby("labels").median() - array([1., 2., 2.]) + array([1. , 2. , 1.5]) Coordinates: * labels (labels) object 'a' 'b' 'c' @@ -6539,7 +6539,7 @@ def median( >>> da.groupby("labels").median(skipna=False) - array([nan, 2., 2.]) + array([nan, 2. , 1.5]) Coordinates: * labels (labels) object 'a' 'b' 'c' """ @@ -6610,7 +6610,7 @@ def cumsum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -6619,14 +6619,14 @@ def cumsum( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.groupby("labels").cumsum() - array([1., 2., 3., 4., 4., 1.]) + array([1., 2., 3., 3., 4., 1.]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.groupby("labels").cumsum(skipna=False) - array([ 1., 2., 3., 4., 4., nan]) + array([ 1., 2., 3., 3., 4., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -6716,14 +6716,14 @@ def cumprod( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.groupby("labels").cumprod() - array([1., 2., 3., 3., 4., 1.]) + array([1., 2., 3., 0., 4., 1.]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.groupby("labels").cumprod(skipna=False) - array([ 1., 2., 3., 3., 4., nan]) + array([ 1., 2., 3., 0., 4., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -6828,7 +6828,7 @@ def count( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -7107,7 +7107,7 @@ def max( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -7214,14 +7214,14 @@ def min( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.resample(time="3M").min() - array([1., 1., 2.]) + array([1., 0., 2.]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 @@ -7229,7 +7229,7 @@ def min( >>> da.resample(time="3M").min(skipna=False) - array([ 1., 1., nan]) + array([ 1., 0., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ @@ -7314,7 +7314,7 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -7323,14 +7323,14 @@ def mean( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.resample(time="3M").mean() - array([1., 2., 2.]) + array([1. , 1.66666667, 2. ]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 @@ -7338,7 +7338,7 @@ def mean( >>> da.resample(time="3M").mean(skipna=False) - array([ 1., 2., nan]) + array([1. , 1.66666667, nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ @@ -7430,7 +7430,7 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -7439,14 +7439,14 @@ def prod( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.resample(time="3M").prod() - array([1., 6., 2.]) + array([1., 0., 2.]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 @@ -7454,7 +7454,7 @@ def prod( >>> da.resample(time="3M").prod(skipna=False) - array([ 1., 6., nan]) + array([ 1., 0., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 @@ -7462,7 +7462,7 @@ def prod( >>> da.resample(time="3M").prod(skipna=True, min_count=2) - array([nan, 6., nan]) + array([nan, 0., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ @@ -7556,7 +7556,7 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -7565,14 +7565,14 @@ def sum( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.resample(time="3M").sum() - array([1., 6., 2.]) + array([1., 5., 2.]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 @@ -7580,7 +7580,7 @@ def sum( >>> da.resample(time="3M").sum(skipna=False) - array([ 1., 6., nan]) + array([ 1., 5., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 @@ -7588,7 +7588,7 @@ def sum( >>> da.resample(time="3M").sum(skipna=True, min_count=2) - array([nan, 6., nan]) + array([nan, 5., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ @@ -7679,7 +7679,7 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -7688,14 +7688,14 @@ def std( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.resample(time="3M").std() - array([0. , 0.81649658, 0. ]) + array([0. , 1.24721913, 0. ]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 @@ -7703,7 +7703,7 @@ def std( >>> da.resample(time="3M").std(skipna=False) - array([0. , 0.81649658, nan]) + array([0. , 1.24721913, nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 @@ -7711,7 +7711,7 @@ def std( >>> da.resample(time="3M").std(skipna=True, ddof=1) - array([nan, 1., nan]) + array([ nan, 1.52752523, nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ @@ -7802,7 +7802,7 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -7811,14 +7811,14 @@ def var( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.resample(time="3M").var() - array([0. , 0.66666667, 0. ]) + array([0. , 1.55555556, 0. ]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 @@ -7826,7 +7826,7 @@ def var( >>> da.resample(time="3M").var(skipna=False) - array([0. , 0.66666667, nan]) + array([0. , 1.55555556, nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 @@ -7834,7 +7834,7 @@ def var( >>> da.resample(time="3M").var(skipna=True, ddof=1) - array([nan, 1., nan]) + array([ nan, 2.33333333, nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ @@ -7921,7 +7921,7 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -7930,7 +7930,7 @@ def median( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -8025,14 +8025,14 @@ def cumsum( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.resample(time="3M").cumsum() - array([1., 2., 5., 6., 2., 2.]) + array([1., 2., 5., 5., 2., 2.]) Coordinates: labels (time) >> da.resample(time="3M").cumsum(skipna=False) - array([ 1., 2., 5., 6., 2., nan]) + array([ 1., 2., 5., 5., 2., nan]) Coordinates: labels (time) >> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), @@ -8122,14 +8122,14 @@ def cumprod( ... ) >>> da - array([ 1., 2., 3., 1., 2., nan]) + array([ 1., 2., 3., 0., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) >> da.resample(time="3M").cumprod() - array([1., 2., 6., 6., 2., 2.]) + array([1., 2., 6., 0., 2., 2.]) Coordinates: labels (time) >> da.resample(time="3M").cumprod(skipna=False) - array([ 1., 2., 6., 6., 2., nan]) + array([ 1., 2., 6., 0., 2., nan]) Coordinates: labels (time) Self: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: DsCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: DsCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: DsCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -215,157 +232,159 @@ def conjugate(self, *args, **kwargs): class DataArrayOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: DaCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: DaCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: DaCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -421,157 +440,303 @@ def conjugate(self, *args, **kwargs): class VariableOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: VarCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + @overload + def __add__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __add__(self, other: VarCompatible) -> Self: + ... + + def __add__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.add) - def __sub__(self, other): + @overload + def __sub__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __sub__(self, other: VarCompatible) -> Self: + ... + + def __sub__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.sub) - def __mul__(self, other): + @overload + def __mul__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __mul__(self, other: VarCompatible) -> Self: + ... + + def __mul__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.mul) - def __pow__(self, other): + @overload + def __pow__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __pow__(self, other: VarCompatible) -> Self: + ... + + def __pow__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + @overload + def __truediv__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __truediv__(self, other: VarCompatible) -> Self: + ... + + def __truediv__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + @overload + def __floordiv__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __floordiv__(self, other: VarCompatible) -> Self: + ... + + def __floordiv__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + @overload + def __mod__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __mod__(self, other: VarCompatible) -> Self: + ... + + def __mod__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.mod) - def __and__(self, other): + @overload + def __and__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __and__(self, other: VarCompatible) -> Self: + ... + + def __and__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.and_) - def __xor__(self, other): + @overload + def __xor__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __xor__(self, other: VarCompatible) -> Self: + ... + + def __xor__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.xor) - def __or__(self, other): + @overload + def __or__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __or__(self, other: VarCompatible) -> Self: + ... + + def __or__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + @overload + def __lshift__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __lshift__(self, other: VarCompatible) -> Self: + ... + + def __lshift__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + @overload + def __rshift__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __rshift__(self, other: VarCompatible) -> Self: + ... + + def __rshift__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + @overload + def __lt__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __lt__(self, other: VarCompatible) -> Self: + ... + + def __lt__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.lt) - def __le__(self, other): + @overload + def __le__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __le__(self, other: VarCompatible) -> Self: + ... + + def __le__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.le) - def __gt__(self, other): + @overload + def __gt__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __gt__(self, other: VarCompatible) -> Self: + ... + + def __gt__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.gt) - def __ge__(self, other): + @overload + def __ge__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __ge__(self, other: VarCompatible) -> Self: + ... + + def __ge__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.ge) - def __eq__(self, other): + @overload # type:ignore[override] + def __eq__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __eq__(self, other: VarCompatible) -> Self: + ... + + def __eq__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + @overload # type:ignore[override] + def __ne__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __ne__(self, other: VarCompatible) -> Self: + ... + + def __ne__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: VarCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -627,91 +792,93 @@ def conjugate(self, *args, **kwargs): class DatasetGroupByOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: GroupByCompatible, f: Callable, reflexive: bool = False + ) -> Dataset: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ @@ -747,91 +914,93 @@ def __ror__(self, other): class DataArrayGroupByOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: T_Xarray, f: Callable, reflexive: bool = False + ) -> T_Xarray: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi deleted file mode 100644 index 9e2ba2d3a06..00000000000 --- a/xarray/core/_typed_ops.pyi +++ /dev/null @@ -1,782 +0,0 @@ -"""Stub file for mixin classes with arithmetic operators.""" -# This file was generated using xarray.util.generate_ops. Do not edit manually. - -from typing import NoReturn, TypeVar, overload - -import numpy as np -from numpy.typing import ArrayLike - -from .dataarray import DataArray -from .dataset import Dataset -from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy -from .types import ( - DaCompatible, - DsCompatible, - GroupByIncompatible, - ScalarOrArray, - VarCompatible, -) -from .variable import Variable - -try: - from dask.array import Array as DaskArray -except ImportError: - DaskArray = np.ndarray # type: ignore - -# DatasetOpsMixin etc. are parent classes of Dataset etc. -# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally -# we use the ones in `types`. (We're open to refining this, and potentially integrating -# the `py` & `pyi` files to simplify them.) -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") - -class DatasetOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - def __add__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __sub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __mul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __pow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __truediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __floordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __mod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __and__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __xor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __or__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __lshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __lt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __le__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __gt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __ge__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __eq__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] - def __ne__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] - def __radd__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rsub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rmul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rpow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rtruediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rfloordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rmod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rand__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rxor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __ror__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_Dataset) -> T_Dataset: ... - def __pos__(self: T_Dataset) -> T_Dataset: ... - def __abs__(self: T_Dataset) -> T_Dataset: ... - def __invert__(self: T_Dataset) -> T_Dataset: ... - def round(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def argsort(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def conj(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def conjugate(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - -class DataArrayOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __add__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __sub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __mul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __pow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __truediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __floordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __mod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __and__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __xor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __or__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __lshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __lt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __le__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __gt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ge__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __eq__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ne__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __radd__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rsub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rmul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rpow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rtruediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rfloordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rmod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rand__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rxor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ror__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_DataArray) -> T_DataArray: ... - def __pos__(self: T_DataArray) -> T_DataArray: ... - def __abs__(self: T_DataArray) -> T_DataArray: ... - def __invert__(self: T_DataArray) -> T_DataArray: ... - def round(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def argsort(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def conj(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def conjugate(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - -class VariableOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __add__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __sub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __pow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __truediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __floordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __and__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __xor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __or__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __le__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __gt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ge__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __eq__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ne__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __radd__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rsub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rpow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rtruediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rfloordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rand__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rxor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ror__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_Variable) -> T_Variable: ... - def __pos__(self: T_Variable) -> T_Variable: ... - def __abs__(self: T_Variable) -> T_Variable: ... - def __invert__(self: T_Variable) -> T_Variable: ... - def round(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def argsort(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def conj(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def conjugate(self: T_Variable, *args, **kwargs) -> T_Variable: ... - -class DatasetGroupByOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: "DataArray") -> "Dataset": ... - @overload - def __add__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: "DataArray") -> "Dataset": ... - @overload - def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: "DataArray") -> "Dataset": ... - @overload - def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: "DataArray") -> "Dataset": ... - @overload - def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: "DataArray") -> "Dataset": ... - @overload - def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: "DataArray") -> "Dataset": ... - @overload - def __and__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: "DataArray") -> "Dataset": ... - @overload - def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: "DataArray") -> "Dataset": ... - @overload - def __or__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: "DataArray") -> "Dataset": ... - @overload - def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: "DataArray") -> "Dataset": ... - @overload - def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: "DataArray") -> "Dataset": ... - @overload - def __le__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: "DataArray") -> "Dataset": ... - @overload - def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: "DataArray") -> "Dataset": ... - @overload - def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: "DataArray") -> "Dataset": ... - @overload - def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... - -class DataArrayGroupByOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __add__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __and__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __or__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __le__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py index 77976fe610e..0d4a402cd19 100644 --- a/xarray/core/accessor_dt.py +++ b/xarray/core/accessor_dt.py @@ -7,6 +7,7 @@ import pandas as pd from xarray.coding.times import infer_calendar_name +from xarray.core import duck_array_ops from xarray.core.common import ( _contains_datetime_like_objects, is_np_datetime_like, @@ -50,7 +51,7 @@ def _access_through_cftimeindex(values, name): from xarray.coding.cftimeindex import CFTimeIndex if not isinstance(values, CFTimeIndex): - values_as_cftimeindex = CFTimeIndex(values.ravel()) + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) else: values_as_cftimeindex = values if name == "season": @@ -69,16 +70,30 @@ def _access_through_series(values, name): """Coerce an array of datetime-like values to a pandas Series and access requested datetime component """ - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) if name == "season": months = values_as_series.dt.month.values field_values = _season_from_months(months) elif name == "isocalendar": + # special NaT-handling can be removed when + # https://github.com/pandas-dev/pandas/issues/54657 is resolved + field_values = values_as_series.dt.isocalendar() + # test for and apply needed dtype + hasna = any(field_values.year.isnull()) + if hasna: + field_values = np.dstack( + [ + getattr(field_values, name).astype(np.float64, copy=False).values + for name in ["year", "week", "day"] + ] + ) + else: + field_values = np.array(field_values, dtype=np.int64) # isocalendar returns iso- year, week, and weekday -> reshape - field_values = np.array(values_as_series.dt.isocalendar(), dtype=np.int64) return field_values.T.reshape(3, *values.shape) else: field_values = getattr(values_as_series.dt, name).values + return field_values.reshape(values.shape) @@ -110,7 +125,7 @@ def _get_date_field(values, name, dtype): from dask.array import map_blocks new_axis = chunks = None - # isocalendar adds adds an axis + # isocalendar adds an axis if name == "isocalendar": chunks = (3,) + values.chunksize new_axis = 0 @@ -119,7 +134,12 @@ def _get_date_field(values, name, dtype): access_method, values, name, dtype=dtype, new_axis=new_axis, chunks=chunks ) else: - return access_method(values, name).astype(dtype, copy=False) + out = access_method(values, name) + # cast only for integer types to keep float64 in presence of NaT + # see https://github.com/pydata/xarray/issues/7928 + if np.issubdtype(out.dtype, np.integer): + out = out.astype(dtype, copy=False) + return out def _round_through_series_or_index(values, name, freq): @@ -129,10 +149,10 @@ def _round_through_series_or_index(values, name, freq): from xarray.coding.cftimeindex import CFTimeIndex if is_np_datetime_like(values.dtype): - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) method = getattr(values_as_series.dt, name) else: - values_as_cftimeindex = CFTimeIndex(values.ravel()) + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) method = getattr(values_as_cftimeindex, name) field_values = method(freq=freq).values @@ -176,7 +196,7 @@ def _strftime_through_cftimeindex(values, date_format: str): """ from xarray.coding.cftimeindex import CFTimeIndex - values_as_cftimeindex = CFTimeIndex(values.ravel()) + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) field_values = values_as_cftimeindex.strftime(date_format) return field_values.values.reshape(values.shape) @@ -186,7 +206,7 @@ def _strftime_through_series(values, date_format: str): """Coerce an array of datetime-like values to a pandas Series and apply string formatting """ - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) strs = values_as_series.dt.strftime(date_format) return strs.values.reshape(values.shape) @@ -581,7 +601,7 @@ class CombinedDatetimelikeAccessor( DatetimeAccessor[T_DataArray], TimedeltaAccessor[T_DataArray] ): def __new__(cls, obj: T_DataArray) -> CombinedDatetimelikeAccessor: - # CombinedDatetimelikeAccessor isn't really instatiated. Instead + # CombinedDatetimelikeAccessor isn't really instantiated. Instead # we need to choose which parent (datetime or timedelta) is # appropriate. Since we're checking the dtypes anyway, we'll just # do all the validation here. diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 31028f10350..573200b5c88 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -471,7 +471,7 @@ def cat(self, *others, sep: str | bytes | Any = "") -> T_DataArray: ... ) >>> values_2 = np.array(3.4) >>> values_3 = "" - >>> values_4 = np.array("test", dtype=np.unicode_) + >>> values_4 = np.array("test", dtype=np.str_) Determine the separator to use @@ -2386,7 +2386,7 @@ def _partitioner( # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, axis=-1) # type: ignore[return-value] + return self._obj.copy().expand_dims({dim: 0}, axis=-1) arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index edebccc2534..732ec5d3ea6 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -5,13 +5,12 @@ from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, Final, Generic, TypeVar, cast, overload import numpy as np import pandas as pd from xarray.core import dtypes -from xarray.core.common import DataWithCoords from xarray.core.indexes import ( Index, Indexes, @@ -20,15 +19,20 @@ indexes_all_equal, safe_cast_to_index, ) +from xarray.core.types import T_Alignable from xarray.core.utils import is_dict_like, is_full_slice from xarray.core.variable import Variable, as_compatible_data, calculate_dimensions if TYPE_CHECKING: from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import JoinOptions, T_DataArray, T_Dataset, T_DataWithCoords - -DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords) + from xarray.core.types import ( + Alignable, + JoinOptions, + T_DataArray, + T_Dataset, + T_DuckArray, + ) def reindex_variables( @@ -92,7 +96,7 @@ def reindex_variables( NormalizedIndexVars = dict[MatchingIndexKey, dict[Hashable, Variable]] -class Aligner(Generic[DataAlignable]): +class Aligner(Generic[T_Alignable]): """Implements all the complex logic for the re-indexing and alignment of Xarray objects. @@ -105,8 +109,8 @@ class Aligner(Generic[DataAlignable]): """ - objects: tuple[DataAlignable, ...] - results: tuple[DataAlignable, ...] + objects: tuple[T_Alignable, ...] + results: tuple[T_Alignable, ...] objects_matching_indexes: tuple[dict[MatchingIndexKey, Index], ...] join: str exclude_dims: frozenset[Hashable] @@ -127,10 +131,10 @@ class Aligner(Generic[DataAlignable]): def __init__( self, - objects: Iterable[DataAlignable], + objects: Iterable[T_Alignable], join: str = "inner", indexes: Mapping[Any, Any] | None = None, - exclude_dims: Iterable = frozenset(), + exclude_dims: str | Iterable[Hashable] = frozenset(), exclude_vars: Iterable[Hashable] = frozenset(), method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, @@ -175,7 +179,7 @@ def __init__( def _normalize_indexes( self, - indexes: Mapping[Any, Any], + indexes: Mapping[Any, Any | T_DuckArray], ) -> tuple[NormalizedIndexes, NormalizedIndexVars]: """Normalize the indexes/indexers used for re-indexing or alignment. @@ -196,7 +200,7 @@ def _normalize_indexes( f"Indexer has dimensions {idx.dims} that are different " f"from that to be indexed along '{k}'" ) - data = as_compatible_data(idx) + data: T_DuckArray = as_compatible_data(idx) pd_idx = safe_cast_to_index(data) pd_idx.name = k if isinstance(pd_idx, pd.MultiIndex): @@ -510,7 +514,7 @@ def _get_dim_pos_indexers( def _get_indexes_and_vars( self, - obj: DataAlignable, + obj: T_Alignable, matching_indexes: dict[MatchingIndexKey, Index], ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: new_indexes = {} @@ -533,13 +537,13 @@ def _get_indexes_and_vars( def _reindex_one( self, - obj: DataAlignable, + obj: T_Alignable, matching_indexes: dict[MatchingIndexKey, Index], - ) -> DataAlignable: + ) -> T_Alignable: new_indexes, new_variables = self._get_indexes_and_vars(obj, matching_indexes) dim_pos_indexers = self._get_dim_pos_indexers(matching_indexes) - new_obj = obj._reindex_callback( + return obj._reindex_callback( self, dim_pos_indexers, new_variables, @@ -548,8 +552,6 @@ def _reindex_one( self.exclude_dims, self.exclude_vars, ) - new_obj.encoding = obj.encoding - return new_obj def reindex_all(self) -> None: self.results = tuple( @@ -580,14 +582,113 @@ def align(self) -> None: self.reindex_all() +T_Obj1 = TypeVar("T_Obj1", bound="Alignable") +T_Obj2 = TypeVar("T_Obj2", bound="Alignable") +T_Obj3 = TypeVar("T_Obj3", bound="Alignable") +T_Obj4 = TypeVar("T_Obj4", bound="Alignable") +T_Obj5 = TypeVar("T_Obj5", bound="Alignable") + + +@overload +def align( + obj1: T_Obj1, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1]: + ... + + +@overload +def align( + obj1: T_Obj1, + obj2: T_Obj2, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2]: + ... + + +@overload +def align( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2, T_Obj3]: + ... + + +@overload +def align( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]: + ... + + +@overload +def align( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + obj5: T_Obj5, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]: + ... + + +@overload def align( - *objects: DataAlignable, + *objects: T_Alignable, join: JoinOptions = "inner", copy: bool = True, indexes=None, - exclude=frozenset(), + exclude: str | Iterable[Hashable] = frozenset(), fill_value=dtypes.NA, -) -> tuple[DataAlignable, ...]: +) -> tuple[T_Alignable, ...]: + ... + + +def align( # type: ignore[misc] + *objects: T_Alignable, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Alignable, ...]: """ Given any number of Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -624,7 +725,7 @@ def align( indexes : dict-like, optional Any indexes explicitly provided with the `indexes` argument should be used in preference to the aligned indexes. - exclude : sequence of str, optional + exclude : str, iterable of hashable or None, optional Dimensions that must be excluded from alignment fill_value : scalar or dict-like, optional Value to use for newly missing values. If a dict-like, maps @@ -791,16 +892,17 @@ def align( def deep_align( objects: Iterable[Any], join: JoinOptions = "inner", - copy=True, + copy: bool = True, indexes=None, - exclude=frozenset(), - raise_on_invalid=True, + exclude: str | Iterable[Hashable] = frozenset(), + raise_on_invalid: bool = True, fill_value=dtypes.NA, -): +) -> list[Any]: """Align objects for merging, recursing into dictionary values. This function is not public API. """ + from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -808,14 +910,14 @@ def deep_align( indexes = {} def is_alignable(obj): - return isinstance(obj, (DataArray, Dataset)) - - positions = [] - keys = [] - out = [] - targets = [] - no_key = object() - not_replaced = object() + return isinstance(obj, (Coordinates, DataArray, Dataset)) + + positions: list[int] = [] + keys: list[type[object] | Hashable] = [] + out: list[Any] = [] + targets: list[Alignable] = [] + no_key: Final = object() + not_replaced: Final = object() for position, variables in enumerate(objects): if is_alignable(variables): positions.append(position) @@ -842,7 +944,7 @@ def is_alignable(obj): elif raise_on_invalid: raise ValueError( "object to align is neither an xarray.Dataset, " - "an xarray.DataArray nor a dictionary: {!r}".format(variables) + f"an xarray.DataArray nor a dictionary: {variables!r}" ) else: out.append(variables) @@ -860,13 +962,13 @@ def is_alignable(obj): if key is no_key: out[position] = aligned_obj else: - out[position][key] = aligned_obj # type: ignore[index] # maybe someone can fix this? + out[position][key] = aligned_obj return out def reindex( - obj: DataAlignable, + obj: T_Alignable, indexers: Mapping[Any, Any], method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, @@ -874,7 +976,7 @@ def reindex( fill_value: Any = dtypes.NA, sparse: bool = False, exclude_vars: Iterable[Hashable] = frozenset(), -) -> DataAlignable: +) -> T_Alignable: """Re-index either a Dataset or a DataArray. Not public API. @@ -905,13 +1007,13 @@ def reindex( def reindex_like( - obj: DataAlignable, + obj: T_Alignable, other: Dataset | DataArray, method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = dtypes.NA, -) -> DataAlignable: +) -> T_Alignable: """Re-index either a Dataset or a DataArray like another Dataset/DataArray. Not public API. @@ -953,8 +1055,8 @@ def _get_broadcast_dims_map_common_coords(args, exclude): def _broadcast_helper( - arg: T_DataWithCoords, exclude, dims_map, common_coords -) -> T_DataWithCoords: + arg: T_Alignable, exclude, dims_map, common_coords +) -> T_Alignable: from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -984,16 +1086,76 @@ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset: # remove casts once https://github.com/python/mypy/issues/12800 is resolved if isinstance(arg, DataArray): - return cast("T_DataWithCoords", _broadcast_array(arg)) + return cast(T_Alignable, _broadcast_array(arg)) elif isinstance(arg, Dataset): - return cast("T_DataWithCoords", _broadcast_dataset(arg)) + return cast(T_Alignable, _broadcast_dataset(arg)) else: raise ValueError("all input must be Dataset or DataArray objects") -# TODO: this typing is too restrictive since it cannot deal with mixed -# DataArray and Dataset types...? Is this a problem? -def broadcast(*args: T_DataWithCoords, exclude=None) -> tuple[T_DataWithCoords, ...]: +@overload +def broadcast( + obj1: T_Obj1, /, *, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Obj1]: + ... + + +@overload +def broadcast( + obj1: T_Obj1, obj2: T_Obj2, /, *, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Obj1, T_Obj2]: + ... + + +@overload +def broadcast( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3]: + ... + + +@overload +def broadcast( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]: + ... + + +@overload +def broadcast( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + obj5: T_Obj5, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]: + ... + + +@overload +def broadcast( + *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Alignable, ...]: + ... + + +def broadcast( # type: ignore[misc] + *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Alignable, ...]: """Explicitly broadcast any number of DataArray or Dataset objects against one another. @@ -1007,7 +1169,7 @@ def broadcast(*args: T_DataWithCoords, exclude=None) -> tuple[T_DataWithCoords, ---------- *args : DataArray or Dataset Arrays to broadcast against each other. - exclude : sequence of str, optional + exclude : str, iterable of hashable or None, optional Dimensions that must not be broadcasted Returns diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index 5b2cf38ee2e..d320eef1bbf 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -15,7 +15,6 @@ ) from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce from xarray.core.ops import ( - IncludeCumMethods, IncludeNumpySameMethods, IncludeReduceMethods, ) @@ -56,10 +55,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): if ufunc.signature is not None: raise NotImplementedError( - "{} not supported: xarray objects do not directly implement " + f"{ufunc} not supported: xarray objects do not directly implement " "generalized ufuncs. Instead, use xarray.apply_ufunc or " "explicitly convert to xarray objects to NumPy arrays " - "(e.g., with `.values`).".format(ufunc) + "(e.g., with `.values`)." ) if method != "__call__": @@ -99,8 +98,6 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): class VariableArithmetic( ImplementsArrayReduce, - IncludeReduceMethods, - IncludeCumMethods, IncludeNumpySameMethods, SupportsArithmetic, VariableOpsMixin, diff --git a/xarray/core/combine.py b/xarray/core/combine.py index cee27300beb..eecd01d011e 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -1,7 +1,6 @@ from __future__ import annotations import itertools -import warnings from collections import Counter from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Literal, Union @@ -110,9 +109,9 @@ def _infer_concat_order_from_coords(datasets): ascending = False else: raise ValueError( - "Coordinate variable {} is neither " + f"Coordinate variable {dim} is neither " "monotonically increasing nor " - "monotonically decreasing on all datasets".format(dim) + "monotonically decreasing on all datasets" ) # Assume that any two datasets whose coord along dim starts @@ -222,10 +221,8 @@ def _combine_nd( n_dims = len(example_tile_id) if len(concat_dims) != n_dims: raise ValueError( - "concat_dims has length {} but the datasets " - "passed are nested in a {}-dimensional structure".format( - len(concat_dims), n_dims - ) + f"concat_dims has length {len(concat_dims)} but the datasets " + f"passed are nested in a {n_dims}-dimensional structure" ) # Each iteration of this loop reduces the length of the tile_ids tuples @@ -647,13 +644,12 @@ def _combine_single_variable_hypercube( if not (indexes.is_monotonic_increasing or indexes.is_monotonic_decreasing): raise ValueError( "Resulting object does not have monotonic" - " global indexes along dimension {}".format(dim) + f" global indexes along dimension {dim}" ) return concatenated -# TODO remove empty list default param after version 0.21, see PR4696 def combine_by_coords( data_objects: Iterable[Dataset | DataArray] = [], compat: CompatOptions = "no_conflicts", @@ -662,7 +658,6 @@ def combine_by_coords( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "no_conflicts", - datasets: Iterable[Dataset] | None = None, ) -> Dataset | DataArray: """ @@ -760,8 +755,6 @@ def combine_by_coords( If a callable, it must expect a sequence of ``attrs`` dicts and a context object as its only parameters. - datasets : Iterable of Datasets - Returns ------- combined : xarray.Dataset or xarray.DataArray @@ -918,14 +911,6 @@ def combine_by_coords( DataArrays or Datasets, a ValueError will be raised (as this is an ambiguous operation). """ - # TODO remove after version 0.21, see PR4696 - if datasets is not None: - warnings.warn( - "The datasets argument has been renamed to `data_objects`." - " From 0.21 on passing a value for datasets will raise an error." - ) - data_objects = datasets - if not data_objects: return Dataset() diff --git a/xarray/core/common.py b/xarray/core/common.py index d54c259ae2c..ab8a4d84261 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -14,7 +14,6 @@ from xarray.core.indexing import BasicIndexer, ExplicitlyIndexed from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.parallelcompat import get_chunked_array_type, guess_chunkmanager -from xarray.core.pdcompat import _convert_base_to_offset from xarray.core.pycompat import is_chunked_array from xarray.core.utils import ( Frozen, @@ -46,6 +45,7 @@ DatetimeLike, DTypeLikeSave, ScalarOrArray, + Self, SideOptions, T_Chunks, T_DataWithCoords, @@ -223,7 +223,7 @@ def _get_axis_num(self: Any, dim: Hashable) -> int: raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") @property - def sizes(self: Any) -> Frozen[Hashable, int]: + def sizes(self: Any) -> Mapping[Hashable, int]: """Ordered mapping from dimension names to lengths. Immutable. @@ -382,11 +382,11 @@ class DataWithCoords(AttrAccessMixin): __slots__ = ("_close",) def squeeze( - self: T_DataWithCoords, + self, dim: Hashable | Iterable[Hashable] | None = None, drop: bool = False, axis: int | Iterable[int] | None = None, - ) -> T_DataWithCoords: + ) -> Self: """Return a new object with squeezed data. Parameters @@ -415,12 +415,12 @@ def squeeze( return self.isel(drop=drop, **{d: 0 for d in dims}) def clip( - self: T_DataWithCoords, + self, min: ScalarOrArray | None = None, max: ScalarOrArray | None = None, *, keep_attrs: bool | None = None, - ) -> T_DataWithCoords: + ) -> Self: """ Return an array whose values are limited to ``[min, max]``. At least one of max or min must be given. @@ -473,10 +473,10 @@ def _calc_assign_results( return {k: v(self) if callable(v) else v for k, v in kwargs.items()} def assign_coords( - self: T_DataWithCoords, + self, coords: Mapping[Any, Any] | None = None, **coords_kwargs: Any, - ) -> T_DataWithCoords: + ) -> Self: """Assign new coordinates to this object. Returns a new object with all the original data in addition to the new @@ -607,15 +607,21 @@ def assign_coords( Dataset.swap_dims Dataset.set_coords """ + from xarray.core.coordinates import Coordinates + coords_combined = either_dict_or_kwargs(coords, coords_kwargs, "assign_coords") data = self.copy(deep=False) - results: dict[Hashable, Any] = self._calc_assign_results(coords_combined) + + results: Coordinates | dict[Hashable, Any] + if isinstance(coords, Coordinates): + results = coords + else: + results = self._calc_assign_results(coords_combined) + data.coords.update(results) return data - def assign_attrs( - self: T_DataWithCoords, *args: Any, **kwargs: Any - ) -> T_DataWithCoords: + def assign_attrs(self, *args: Any, **kwargs: Any) -> Self: """Assign new attrs to this object. Returns a new object equivalent to ``self.attrs.update(*args, **kwargs)``. @@ -627,6 +633,36 @@ def assign_attrs( **kwargs keyword arguments passed into ``attrs.update``. + Examples + -------- + >>> dataset = xr.Dataset({"temperature": [25, 30, 27]}) + >>> dataset + + Dimensions: (temperature: 3) + Coordinates: + * temperature (temperature) int64 25 30 27 + Data variables: + *empty* + + >>> new_dataset = dataset.assign_attrs( + ... units="Celsius", description="Temperature data" + ... ) + >>> new_dataset + + Dimensions: (temperature: 3) + Coordinates: + * temperature (temperature) int64 25 30 27 + Data variables: + *empty* + Attributes: + units: Celsius + description: Temperature data + + # Attributes of the new dataset + + >>> new_dataset.attrs + {'units': 'Celsius', 'description': 'Temperature data'} + Returns ------- assigned : same type as caller @@ -950,6 +986,7 @@ def _resample( from xarray.core.dataarray import DataArray from xarray.core.groupby import ResolvedTimeResampleGrouper, TimeResampleGrouper + from xarray.core.pdcompat import _convert_base_to_offset from xarray.core.resample import RESAMPLE_DIM if keep_attrs is not None: @@ -1023,11 +1060,12 @@ def _resample( restore_coord_dims=restore_coord_dims, ) - def where( - self: T_DataWithCoords, cond: Any, other: Any = dtypes.NA, drop: bool = False - ) -> T_DataWithCoords: + def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: """Filter elements from this object according to a condition. + Returns elements from 'DataArray', where 'cond' is True, + otherwise fill in 'other'. + This operation follows the normal broadcasting and alignment rules that xarray uses for binary arithmetic. @@ -1035,10 +1073,12 @@ def where( ---------- cond : DataArray, Dataset, or callable Locations at which to preserve this object's values. dtype must be `bool`. - If a callable, it must expect this object as its only parameter. - other : scalar, DataArray or Dataset, optional + If a callable, the callable is passed this object, and the result is used as + the value for cond. + other : scalar, DataArray, Dataset, or callable, optional Value to use for locations in this object where ``cond`` is False. - By default, these locations filled with NA. + By default, these locations are filled with NA. If a callable, it must + expect this object as its only parameter. drop : bool, default: False If True, coordinate labels that only correspond to False values of the condition are dropped from the result. @@ -1086,7 +1126,16 @@ def where( [15., nan, nan, nan]]) Dimensions without coordinates: x, y - >>> a.where(lambda x: x.x + x.y < 4, drop=True) + >>> a.where(lambda x: x.x + x.y < 4, lambda x: -x) + + array([[ 0, 1, 2, 3, -4], + [ 5, 6, 7, -8, -9], + [ 10, 11, -12, -13, -14], + [ 15, -16, -17, -18, -19], + [-20, -21, -22, -23, -24]]) + Dimensions without coordinates: x, y + + >>> a.where(a.x + a.y < 4, drop=True) array([[ 0., 1., 2., 3.], [ 5., 6., 7., nan], @@ -1094,14 +1143,6 @@ def where( [15., nan, nan, nan]]) Dimensions without coordinates: x, y - >>> a.where(a.x + a.y < 4, -1, drop=True) - - array([[ 0, 1, 2, 3], - [ 5, 6, 7, -1], - [10, 11, -1, -1], - [15, -1, -1, -1]]) - Dimensions without coordinates: x, y - See Also -------- numpy.where : corresponding numpy function @@ -1113,14 +1154,16 @@ def where( if callable(cond): cond = cond(self) + if callable(other): + other = other(self) if drop: if not isinstance(cond, (Dataset, DataArray)): raise TypeError( - f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r}" + f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r} (or a callable than returns one)." ) - self, cond = align(self, cond) # type: ignore[assignment] + self, cond = align(self, cond) def _dataarray_indexer(dim: Hashable) -> DataArray: return cond.any(dim=(d for d in cond.dims if d != dim)) @@ -1167,9 +1210,7 @@ def close(self) -> None: self._close() self._close = None - def isnull( - self: T_DataWithCoords, keep_attrs: bool | None = None - ) -> T_DataWithCoords: + def isnull(self, keep_attrs: bool | None = None) -> Self: """Test each value in the array for whether it is a missing value. Parameters @@ -1212,9 +1253,7 @@ def isnull( keep_attrs=keep_attrs, ) - def notnull( - self: T_DataWithCoords, keep_attrs: bool | None = None - ) -> T_DataWithCoords: + def notnull(self, keep_attrs: bool | None = None) -> Self: """Test each value in the array for whether it is not a missing value. Parameters @@ -1257,7 +1296,7 @@ def notnull( keep_attrs=keep_attrs, ) - def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords: + def isin(self, test_elements: Any) -> Self: """Tests each value in the array for whether it is in test elements. Parameters @@ -1306,7 +1345,7 @@ def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords: ) def astype( - self: T_DataWithCoords, + self, dtype, *, order=None, @@ -1314,7 +1353,7 @@ def astype( subok=None, copy=None, keep_attrs=True, - ) -> T_DataWithCoords: + ) -> Self: """ Copy of the xarray object, with data cast to a specified type. Leaves coordinate dtype unchanged. @@ -1381,7 +1420,7 @@ def astype( dask="allowed", ) - def __enter__(self: T_DataWithCoords) -> T_DataWithCoords: + def __enter__(self) -> Self: return self def __exit__(self, exc_type, exc_value, traceback) -> None: diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 685307fc8c3..f506bc97a2c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -8,7 +8,7 @@ import operator import warnings from collections import Counter -from collections.abc import Hashable, Iterable, Mapping, Sequence, Set +from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, overload import numpy as np @@ -17,6 +17,7 @@ from xarray.core.alignment import align, deep_align from xarray.core.common import zeros_like from xarray.core.duck_array_ops import datetime_to_numeric +from xarray.core.formatting import limit_lines from xarray.core.indexes import Index, filter_indexes_from_coords from xarray.core.merge import merge_attrs, merge_coordinates_without_align from xarray.core.options import OPTIONS, _get_keep_attrs @@ -32,6 +33,8 @@ from xarray.core.dataset import Dataset from xarray.core.types import CombineAttrsOptions, JoinOptions + MissingCoreDimOptions = Literal["raise", "copy", "drop"] + _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") _JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"}) @@ -42,6 +45,7 @@ def _first_of_type(args, kind): for arg in args: if isinstance(arg, kind): return arg + raise ValueError("This should be unreachable.") @@ -159,7 +163,7 @@ def to_gufunc_string(self, exclude_dims=frozenset()): if exclude_dims: exclude_dims = [self.dims_map[dim] for dim in exclude_dims] - counter = Counter() + counter: Counter = Counter() def _enumerate(dim): if dim in exclude_dims: @@ -285,8 +289,14 @@ def apply_dataarray_vfunc( from xarray.core.dataarray import DataArray if len(args) > 1: - args = deep_align( - args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False + args = tuple( + deep_align( + args, + join=join, + copy=False, + exclude=exclude_dims, + raise_on_invalid=False, + ) ) objs = _all_of_type(args, DataArray) @@ -347,7 +357,7 @@ def assert_and_return_exact_match(all_keys): if keys != first_keys: raise ValueError( "exact match required for all data variable names, " - f"but {keys!r} != {first_keys!r}" + f"but {list(keys)} != {list(first_keys)}: {set(keys) ^ set(first_keys)} are not in both." ) return first_keys @@ -376,7 +386,7 @@ def collect_dict_values( ] -def _as_variables_or_variable(arg): +def _as_variables_or_variable(arg) -> Variable | tuple[Variable]: try: return arg.variables except AttributeError: @@ -396,8 +406,39 @@ def _unpack_dict_tuples( return out +def _check_core_dims(signature, variable_args, name): + """ + Check if an arg has all the core dims required by the signature. + + Slightly awkward design, of returning the error message. But we want to + give a detailed error message, which requires inspecting the variable in + the inner loop. + """ + missing = [] + for i, (core_dims, variable_arg) in enumerate( + zip(signature.input_core_dims, variable_args) + ): + # Check whether all the dims are on the variable. Note that we need the + # `hasattr` to check for a dims property, to protect against the case where + # a numpy array is passed in. + if hasattr(variable_arg, "dims") and set(core_dims) - set(variable_arg.dims): + missing += [[i, variable_arg, core_dims]] + if missing: + message = "" + for i, variable_arg, core_dims in missing: + message += f"Missing core dims {set(core_dims) - set(variable_arg.dims)} from arg number {i + 1} on a variable named `{name}`:\n{variable_arg}\n\n" + message += "Either add the core dimension, or if passing a dataset alternatively pass `on_missing_core_dim` as `copy` or `drop`. " + return message + return True + + def apply_dict_of_variables_vfunc( - func, *args, signature: _UFuncSignature, join="inner", fill_value=None + func, + *args, + signature: _UFuncSignature, + join="inner", + fill_value=None, + on_missing_core_dim: MissingCoreDimOptions = "raise", ): """Apply a variable level function over dicts of DataArray, DataArray, Variable and ndarray objects. @@ -408,7 +449,20 @@ def apply_dict_of_variables_vfunc( result_vars = {} for name, variable_args in zip(names, grouped_by_name): - result_vars[name] = func(*variable_args) + core_dim_present = _check_core_dims(signature, variable_args, name) + if core_dim_present is True: + result_vars[name] = func(*variable_args) + else: + if on_missing_core_dim == "raise": + raise ValueError(core_dim_present) + elif on_missing_core_dim == "copy": + result_vars[name] = variable_args[0] + elif on_missing_core_dim == "drop": + pass + else: + raise ValueError( + f"Invalid value for `on_missing_core_dim`: {on_missing_core_dim!r}" + ) if signature.num_outputs > 1: return _unpack_dict_tuples(result_vars, signature.num_outputs) @@ -441,6 +495,7 @@ def apply_dataset_vfunc( fill_value=_NO_FILL_VALUE, exclude_dims=frozenset(), keep_attrs="override", + on_missing_core_dim: MissingCoreDimOptions = "raise", ) -> Dataset | tuple[Dataset, ...]: """Apply a variable level function over Dataset, dict of DataArray, DataArray, Variable and/or ndarray objects. @@ -457,8 +512,14 @@ def apply_dataset_vfunc( objs = _all_of_type(args, Dataset) if len(args) > 1: - args = deep_align( - args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False + args = tuple( + deep_align( + args, + join=join, + copy=False, + exclude=exclude_dims, + raise_on_invalid=False, + ) ) list_of_coords, list_of_indexes = build_output_coords_and_indexes( @@ -467,7 +528,12 @@ def apply_dataset_vfunc( args = tuple(getattr(arg, "data_vars", arg) for arg in args) result_vars = apply_dict_of_variables_vfunc( - func, *args, signature=signature, join=dataset_join, fill_value=fill_value + func, + *args, + signature=signature, + join=dataset_join, + fill_value=fill_value, + on_missing_core_dim=on_missing_core_dim, ) out: Dataset | tuple[Dataset, ...] @@ -517,7 +583,7 @@ def apply_groupby_func(func, *args): assert groupbys, "must have at least one groupby to iterate over" first_groupby = groupbys[0] (grouper,) = first_groupby.groupers - if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]): + if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]): # type: ignore[union-attr] raise ValueError( "apply_ufunc can only perform operations over " "multiple GroupBy objects at once if they are all " @@ -529,6 +595,7 @@ def apply_groupby_func(func, *args): iterators = [] for arg in args: + iterator: Iterator[Any] if isinstance(arg, GroupBy): iterator = (value for _, value in arg) elif hasattr(arg, "dims") and grouped_dim in arg.dims: @@ -543,9 +610,9 @@ def apply_groupby_func(func, *args): iterator = itertools.repeat(arg) iterators.append(iterator) - applied = (func(*zipped_args) for zipped_args in zip(*iterators)) + applied: Iterator = (func(*zipped_args) for zipped_args in zip(*iterators)) applied_example, applied = peek_at(applied) - combine = first_groupby._combine + combine = first_groupby._combine # type: ignore[attr-defined] if isinstance(applied_example, tuple): combined = tuple(combine(output) for output in zip(*applied)) else: @@ -595,17 +662,9 @@ def broadcast_compat_data( return data set_old_dims = set(old_dims) - missing_core_dims = [d for d in core_dims if d not in set_old_dims] - if missing_core_dims: - raise ValueError( - "operand to apply_ufunc has required core dimensions {}, but " - "some of these dimensions are absent on an input variable: {}".format( - list(core_dims), missing_core_dims - ) - ) - set_new_dims = set(new_dims) unexpected_dims = [d for d in old_dims if d not in set_new_dims] + if unexpected_dims: raise ValueError( "operand to apply_ufunc encountered unexpected " @@ -659,6 +718,7 @@ def apply_variable_ufunc( dask_gufunc_kwargs=None, ) -> Variable | tuple[Variable, ...]: """Apply a ndarray level function over Variable and/or ndarray objects.""" + from xarray.core.formatting import short_array_repr from xarray.core.variable import Variable, as_compatible_data dim_sizes = unified_dim_sizes( @@ -766,11 +826,11 @@ def func(*arrays): not isinstance(result_data, tuple) or len(result_data) != signature.num_outputs ): raise ValueError( - "applied function does not have the number of " - "outputs specified in the ufunc signature. " - "Result is not a tuple of {} elements: {!r}".format( - signature.num_outputs, result_data - ) + f"applied function does not have the number of " + f"outputs specified in the ufunc signature. " + f"Received a {type(result_data)} with {len(result_data)} elements. " + f"Expected a tuple of {signature.num_outputs} elements:\n\n" + f"{limit_lines(repr(result_data), limit=10)}" ) objs = _all_of_type(args, Variable) @@ -784,21 +844,22 @@ def func(*arrays): data = as_compatible_data(data) if data.ndim != len(dims): raise ValueError( - "applied function returned data with unexpected " + "applied function returned data with an unexpected " f"number of dimensions. Received {data.ndim} dimension(s) but " - f"expected {len(dims)} dimensions with names: {dims!r}" + f"expected {len(dims)} dimensions with names {dims!r}, from:\n\n" + f"{short_array_repr(data)}" ) var = Variable(dims, data, fastpath=True) for dim, new_size in var.sizes.items(): if dim in dim_sizes and new_size != dim_sizes[dim]: raise ValueError( - "size of dimension {!r} on inputs was unexpectedly " - "changed by applied function from {} to {}. Only " + f"size of dimension '{dim}' on inputs was unexpectedly " + f"changed by applied function from {dim_sizes[dim]} to {new_size}. Only " "dimensions specified in ``exclude_dims`` with " - "xarray.apply_ufunc are allowed to change size.".format( - dim, dim_sizes[dim], new_size - ) + "xarray.apply_ufunc are allowed to change size. " + "The data returned was:\n\n" + f"{short_array_repr(data)}" ) var.attrs = attrs @@ -845,11 +906,12 @@ def apply_ufunc( dataset_fill_value: object = _NO_FILL_VALUE, keep_attrs: bool | str | None = None, kwargs: Mapping | None = None, - dask: str = "forbidden", + dask: Literal["forbidden", "allowed", "parallelized"] = "forbidden", output_dtypes: Sequence | None = None, output_sizes: Mapping[Any, int] | None = None, meta: Any = None, dask_gufunc_kwargs: dict[str, Any] | None = None, + on_missing_core_dim: MissingCoreDimOptions = "raise", ) -> Any: """Apply a vectorized function for unlabeled arrays on xarray objects. @@ -963,6 +1025,8 @@ def apply_ufunc( :py:func:`dask.array.apply_gufunc`. ``meta`` should be given in the ``dask_gufunc_kwargs`` parameter . It will be removed as direct parameter a future version. + on_missing_core_dim : {"raise", "copy", "drop"}, default: "raise" + How to handle missing core dimensions on input variables. Returns ------- @@ -1191,6 +1255,7 @@ def apply_ufunc( dataset_join=dataset_join, fill_value=dataset_fill_value, keep_attrs=keep_attrs, + on_missing_core_dim=on_missing_core_dim, ) # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc elif any(isinstance(a, DataArray) for a in args): @@ -1286,7 +1351,7 @@ def cov( if any(not isinstance(arr, DataArray) for arr in [da_a, da_b]): raise TypeError( "Only xr.DataArray is supported." - "Given {}.".format([type(arr) for arr in [da_a, da_b]]) + f"Given {[type(arr) for arr in [da_a, da_b]]}." ) return _cov_corr(da_a, da_b, dim=dim, ddof=ddof, method="cov") @@ -1364,7 +1429,7 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray: if any(not isinstance(arr, DataArray) for arr in [da_a, da_b]): raise TypeError( "Only xr.DataArray is supported." - "Given {}.".format([type(arr) for arr in [da_a, da_b]]) + f"Given {[type(arr) for arr in [da_a, da_b]]}." ) return _cov_corr(da_a, da_b, dim=dim, method="corr") @@ -1402,14 +1467,14 @@ def _cov_corr( ) / (valid_count) if method == "cov": - return cov # type: ignore[return-value] + return cov else: # compute std + corr da_a_std = da_a.std(dim=dim) da_b_std = da_b.std(dim=dim) corr = cov / (da_a_std * da_b_std) - return corr # type: ignore[return-value] + return corr def cross( @@ -1625,8 +1690,8 @@ def dot( dims: Dims = None, **kwargs: Any, ): - """Generalized dot product for xarray objects. Like np.einsum, but - provides a simpler interface based on array dimensions. + """Generalized dot product for xarray objects. Like ``np.einsum``, but + provides a simpler interface based on array dimension names. Parameters ---------- @@ -1636,13 +1701,24 @@ def dot( Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. If not specified, then all the common dimensions are summed over. **kwargs : dict - Additional keyword arguments passed to numpy.einsum or - dask.array.einsum + Additional keyword arguments passed to ``numpy.einsum`` or + ``dask.array.einsum`` Returns ------- DataArray + See Also + -------- + numpy.einsum + dask.array.einsum + opt_einsum.contract + + Notes + ----- + We recommend installing the optional ``opt_einsum`` package, or alternatively passing ``optimize=True``, + which is passed through to ``np.einsum``, and works for most array backends. + Examples -------- >>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=["a", "b"]) @@ -1707,7 +1783,7 @@ def dot( if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays): raise TypeError( "Only xr.DataArray and xr.Variable are supported." - "Given {}.".format([type(arr) for arr in arrays]) + f"Given {[type(arr) for arr in arrays]}." ) if len(arrays) == 0: @@ -2046,9 +2122,13 @@ def _calc_idxminmax( raise ValueError("Must supply 'dim' argument for multidimensional arrays") if dim not in array.dims: - raise KeyError(f'Dimension "{dim}" not in dimension') + raise KeyError( + f"Dimension {dim!r} not found in array dimensions {array.dims!r}" + ) if dim not in array.coords: - raise KeyError(f'Dimension "{dim}" does not have coordinates') + raise KeyError( + f"Dimension {dim!r} is not one of the coordinates {tuple(array.coords.keys())}" + ) # These are dtypes with NaN values argmin and argmax can handle na_dtypes = "cfO" @@ -2066,7 +2146,8 @@ def _calc_idxminmax( chunkmanager = get_chunked_array_type(array.data) chunks = dict(zip(array.dims, array.chunks)) dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) - res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape)) + data = dask_coord[duck_array_ops.ravel(indx.data)] + res = indx.copy(data=duck_array_ops.reshape(data, indx.shape)) # we need to attach back the dim name res.name = dim else: diff --git a/xarray/core/concat.py b/xarray/core/concat.py index d7aad8c7188..8c558b38848 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Hashable, Iterable -from typing import TYPE_CHECKING, Any, Union, cast, overload +from typing import TYPE_CHECKING, Any, Union, overload import numpy as np import pandas as pd @@ -16,7 +16,7 @@ merge_attrs, merge_collected, ) -from xarray.core.types import T_DataArray, T_Dataset +from xarray.core.types import T_DataArray, T_Dataset, T_Variable from xarray.core.variable import Variable from xarray.core.variable import concat as concat_vars @@ -34,7 +34,7 @@ @overload def concat( objs: Iterable[T_Dataset], - dim: Hashable | T_DataArray | pd.Index, + dim: Hashable | T_Variable | T_DataArray | pd.Index, data_vars: T_DataVars = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", @@ -49,7 +49,7 @@ def concat( @overload def concat( objs: Iterable[T_DataArray], - dim: Hashable | T_DataArray | pd.Index, + dim: Hashable | T_Variable | T_DataArray | pd.Index, data_vars: T_DataVars = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", @@ -80,11 +80,11 @@ def concat( xarray objects to concatenate together. Each object is expected to consist of variables and coordinates with matching shapes except for along the concatenated dimension. - dim : Hashable or DataArray or pandas.Index + dim : Hashable or Variable or DataArray or pandas.Index Name of the dimension to concatenate along. This can either be a new dimension name, in which case it is added along axis=0, or an existing dimension name, in which case the location of the dimension is - unchanged. If dimension is provided as a DataArray or Index, its name + unchanged. If dimension is provided as a Variable, DataArray or Index, its name is used as the dimension to concatenate along and the values are added as a coordinate. data_vars : {"minimal", "different", "all"} or list of Hashable, optional @@ -391,17 +391,20 @@ def process_subset_opt(opt, subset): else: raise ValueError(f"unexpected value for {subset}: {opt}") else: - invalid_vars = [k for k in opt if k not in getattr(datasets[0], subset)] + valid_vars = tuple(getattr(datasets[0], subset)) + invalid_vars = [k for k in opt if k not in valid_vars] if invalid_vars: if subset == "coords": raise ValueError( - "some variables in coords are not coordinates on " - f"the first dataset: {invalid_vars}" + f"the variables {invalid_vars} in coords are not " + f"found in the coordinates of the first dataset {valid_vars}" ) else: + # note: data_vars are not listed in the error message here, + # because there may be lots of them raise ValueError( - "some variables in data_vars are not data variables " - f"on the first dataset: {invalid_vars}" + f"the variables {invalid_vars} in data_vars are not " + f"found in the data variables of the first dataset" ) concat_over.update(opt) @@ -447,7 +450,7 @@ def _parse_datasets( def _dataset_concat( datasets: list[T_Dataset], - dim: str | T_DataArray | pd.Index, + dim: str | T_Variable | T_DataArray | pd.Index, data_vars: T_DataVars, coords: str | list[str], compat: CompatOptions, @@ -501,8 +504,7 @@ def _dataset_concat( # case where concat dimension is a coordinate or data_var but not a dimension if (dim in coord_names or dim in data_names) and dim not in dim_names: - # TODO: Overriding type because .expand_dims has incorrect typing: - datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets] + datasets = [ds.expand_dims(dim) for ds in datasets] # determine which variables to concatenate concat_over, equals, concat_dim_lengths = _calc_concat_over( @@ -674,7 +676,7 @@ def get_indexes(name): def _dataarray_concat( arrays: Iterable[T_DataArray], - dim: str | T_DataArray | pd.Index, + dim: str | T_Variable | T_DataArray | pd.Index, data_vars: T_DataVars, coords: str | list[str], compat: CompatOptions, @@ -705,8 +707,7 @@ def _dataarray_concat( if compat == "identical": raise ValueError("array names not identical") else: - # TODO: Overriding type because .rename has incorrect typing: - arr = cast(T_DataArray, arr.rename(name)) + arr = arr.rename(name) datasets.append(arr._to_temp_dataset()) ds = _dataset_concat( diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 32809a54ddd..0c85b2a2d69 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1,51 +1,60 @@ from __future__ import annotations -import warnings from collections.abc import Hashable, Iterator, Mapping, Sequence from contextlib import contextmanager -from typing import TYPE_CHECKING, Any +from typing import ( + TYPE_CHECKING, + Any, + Generic, + cast, +) import numpy as np import pandas as pd from xarray.core import formatting +from xarray.core.alignment import Aligner from xarray.core.indexes import ( Index, Indexes, + PandasIndex, PandasMultiIndex, assert_no_index_corrupted, + create_default_index_implicit, ) from xarray.core.merge import merge_coordinates_without_align, merge_coords -from xarray.core.utils import Frozen, ReprObject -from xarray.core.variable import Variable, calculate_dimensions +from xarray.core.types import DataVars, Self, T_DataArray, T_Xarray +from xarray.core.utils import ( + Frozen, + ReprObject, + either_dict_or_kwargs, + emit_user_level_warning, +) +from xarray.core.variable import Variable, as_variable, calculate_dimensions if TYPE_CHECKING: from xarray.core.common import DataWithCoords from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import T_DataArray # Used as the key corresponding to a DataArray's variable when converting # arbitrary DataArray objects to datasets _THIS_ARRAY = ReprObject("") -class Coordinates(Mapping[Hashable, "T_DataArray"]): +class AbstractCoordinates(Mapping[Hashable, "T_DataArray"]): _data: DataWithCoords __slots__ = ("_data",) def __getitem__(self, key: Hashable) -> T_DataArray: raise NotImplementedError() - def __setitem__(self, key: Hashable, value: Any) -> None: - self.update({key: value}) - @property def _names(self) -> set[Hashable]: raise NotImplementedError() @property - def dims(self) -> Mapping[Hashable, int] | tuple[Hashable, ...]: + def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: raise NotImplementedError() @property @@ -54,10 +63,22 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]: @property def indexes(self) -> Indexes[pd.Index]: + """Mapping of pandas.Index objects used for label based indexing. + + Raises an error if this Coordinates object has indexes that cannot + be coerced to pandas.Index objects. + + See Also + -------- + Coordinates.xindexes + """ return self._data.indexes @property def xindexes(self) -> Indexes[Index]: + """Mapping of :py:class:`~xarray.indexes.Index` objects + used for label based indexing. + """ return self._data.xindexes @property @@ -67,7 +88,7 @@ def variables(self): def _update_coords(self, coords, indexes): raise NotImplementedError() - def _maybe_drop_multiindex_coords(self, coords): + def _drop_coords(self, coord_names): raise NotImplementedError() def __iter__(self) -> Iterator[Hashable]: @@ -109,7 +130,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index: elif set(ordered_dims) != set(self.dims): raise ValueError( "ordered_dims must match dims, but does not: " - "{} vs {}".format(ordered_dims, self.dims) + f"{ordered_dims} vs {self.dims}" ) if len(ordered_dims) == 0: @@ -125,7 +146,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index: index_lengths = np.fromiter( (len(index) for index in indexes), dtype=np.intp ) - cumprod_lengths = np.cumproduct(index_lengths) + cumprod_lengths = np.cumprod(index_lengths) if cumprod_lengths[-1] == 0: # if any factor is empty, the cartesian product is empty @@ -163,13 +184,279 @@ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index: return pd.MultiIndex(level_list, code_list, names=names) - def update(self, other: Mapping[Any, Any]) -> None: - other_vars = getattr(other, "variables", other) - self._maybe_drop_multiindex_coords(set(other_vars)) - coords, indexes = merge_coords( - [self.variables, other_vars], priority_arg=1, indexes=self.xindexes + +class Coordinates(AbstractCoordinates): + """Dictionary like container for Xarray coordinates (variables + indexes). + + This collection is a mapping of coordinate names to + :py:class:`~xarray.DataArray` objects. + + It can be passed directly to the :py:class:`~xarray.Dataset` and + :py:class:`~xarray.DataArray` constructors via their `coords` argument. This + will add both the coordinates variables and their index. + + Coordinates are either: + + - returned via the :py:attr:`Dataset.coords` and :py:attr:`DataArray.coords` + properties + - built from Pandas or other index objects + (e.g., :py:meth:`Coordinates.from_pandas_multiindex`) + - built directly from coordinate data and Xarray ``Index`` objects (beware that + no consistency check is done on those inputs) + + Parameters + ---------- + coords: dict-like, optional + Mapping where keys are coordinate names and values are objects that + can be converted into a :py:class:`~xarray.Variable` object + (see :py:func:`~xarray.as_variable`). If another + :py:class:`~xarray.Coordinates` object is passed, its indexes + will be added to the new created object. + indexes: dict-like, optional + Mapping of where keys are coordinate names and values are + :py:class:`~xarray.indexes.Index` objects. If None (default), + pandas indexes will be created for each dimension coordinate. + Passing an empty dictionary will skip this default behavior. + + Examples + -------- + Create a dimension coordinate with a default (pandas) index: + + >>> xr.Coordinates({"x": [1, 2]}) + Coordinates: + * x (x) int64 1 2 + + Create a dimension coordinate with no index: + + >>> xr.Coordinates(coords={"x": [1, 2]}, indexes={}) + Coordinates: + x (x) int64 1 2 + + Create a new Coordinates object from existing dataset coordinates + (indexes are passed): + + >>> ds = xr.Dataset(coords={"x": [1, 2]}) + >>> xr.Coordinates(ds.coords) + Coordinates: + * x (x) int64 1 2 + + Create indexed coordinates from a ``pandas.MultiIndex`` object: + + >>> midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]]) + >>> xr.Coordinates.from_pandas_multiindex(midx, "x") + Coordinates: + * x (x) object MultiIndex + * x_level_0 (x) object 'a' 'a' 'b' 'b' + * x_level_1 (x) int64 0 1 0 1 + + Create a new Dataset object by passing a Coordinates object: + + >>> midx_coords = xr.Coordinates.from_pandas_multiindex(midx, "x") + >>> xr.Dataset(coords=midx_coords) + + Dimensions: (x: 4) + Coordinates: + * x (x) object MultiIndex + * x_level_0 (x) object 'a' 'a' 'b' 'b' + * x_level_1 (x) int64 0 1 0 1 + Data variables: + *empty* + + """ + + _data: DataWithCoords + + __slots__ = ("_data",) + + def __init__( + self, + coords: Mapping[Any, Any] | None = None, + indexes: Mapping[Any, Index] | None = None, + ) -> None: + # When coordinates are constructed directly, an internal Dataset is + # created so that it is compatible with the DatasetCoordinates and + # DataArrayCoordinates classes serving as a proxy for the data. + # TODO: refactor DataArray / Dataset so that Coordinates store the data. + from xarray.core.dataset import Dataset + + if coords is None: + coords = {} + + variables: dict[Hashable, Variable] + default_indexes: dict[Hashable, PandasIndex] = {} + coords_obj_indexes: dict[Hashable, Index] = {} + + if isinstance(coords, Coordinates): + if indexes is not None: + raise ValueError( + "passing both a ``Coordinates`` object and a mapping of indexes " + "to ``Coordinates.__init__`` is not allowed " + "(this constructor does not support merging them)" + ) + variables = {k: v.copy() for k, v in coords.variables.items()} + coords_obj_indexes = dict(coords.xindexes) + else: + variables = {} + for name, data in coords.items(): + var = as_variable(data, name=name) + if var.dims == (name,) and indexes is None: + index, index_vars = create_default_index_implicit(var, list(coords)) + default_indexes.update({k: index for k in index_vars}) + variables.update(index_vars) + else: + variables[name] = var + + if indexes is None: + indexes = {} + else: + indexes = dict(indexes) + + indexes.update(default_indexes) + indexes.update(coords_obj_indexes) + + no_coord_index = set(indexes) - set(variables) + if no_coord_index: + raise ValueError( + f"no coordinate variables found for these indexes: {no_coord_index}" + ) + + for k, idx in indexes.items(): + if not isinstance(idx, Index): + raise TypeError(f"'{k}' is not an `xarray.indexes.Index` object") + + # maybe convert to base variable + for k, v in variables.items(): + if k not in indexes: + variables[k] = v.to_base_variable() + + self._data = Dataset._construct_direct( + coord_names=set(variables), variables=variables, indexes=indexes ) - self._update_coords(coords, indexes) + + @classmethod + def _construct_direct( + cls, + coords: dict[Any, Variable], + indexes: dict[Any, Index], + dims: dict[Any, int] | None = None, + ) -> Self: + from xarray.core.dataset import Dataset + + obj = object.__new__(cls) + obj._data = Dataset._construct_direct( + coord_names=set(coords), + variables=coords, + indexes=indexes, + dims=dims, + ) + return obj + + @classmethod + def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: str) -> Self: + """Wrap a pandas multi-index as Xarray coordinates (dimension + levels). + + The returned coordinates can be directly assigned to a + :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the + ``coords`` argument of their constructor. + + Parameters + ---------- + midx : :py:class:`pandas.MultiIndex` + Pandas multi-index object. + dim : str + Dimension name. + + Returns + ------- + coords : Coordinates + A collection of Xarray indexed coordinates created from the multi-index. + + """ + xr_idx = PandasMultiIndex(midx, dim) + + variables = xr_idx.create_variables() + indexes = {k: xr_idx for k in variables} + + return cls(coords=variables, indexes=indexes) + + @property + def _names(self) -> set[Hashable]: + return self._data._coord_names + + @property + def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: + """Mapping from dimension names to lengths or tuple of dimension names.""" + return self._data.dims + + @property + def sizes(self) -> Frozen[Hashable, int]: + """Mapping from dimension names to lengths.""" + return self._data.sizes + + @property + def dtypes(self) -> Frozen[Hashable, np.dtype]: + """Mapping from coordinate names to dtypes. + + Cannot be modified directly. + + See Also + -------- + Dataset.dtypes + """ + return Frozen({n: v.dtype for n, v in self._data.variables.items()}) + + @property + def variables(self) -> Mapping[Hashable, Variable]: + """Low level interface to Coordinates contents as dict of Variable objects. + + This dictionary is frozen to prevent mutation. + """ + return self._data.variables + + def to_dataset(self) -> Dataset: + """Convert these coordinates into a new Dataset.""" + names = [name for name in self._data._variables if name in self._names] + return self._data._copy_listed(names) + + def __getitem__(self, key: Hashable) -> DataArray: + return self._data[key] + + def __delitem__(self, key: Hashable) -> None: + # redirect to DatasetCoordinates.__delitem__ + del self._data.coords[key] + + def equals(self, other: Self) -> bool: + """Two Coordinates objects are equal if they have matching variables, + all of which are equal. + + See Also + -------- + Coordinates.identical + """ + if not isinstance(other, Coordinates): + return False + return self.to_dataset().equals(other.to_dataset()) + + def identical(self, other: Self) -> bool: + """Like equals, but also checks all variable attributes. + + See Also + -------- + Coordinates.equals + """ + if not isinstance(other, Coordinates): + return False + return self.to_dataset().identical(other.to_dataset()) + + def _update_coords( + self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + ) -> None: + # redirect to DatasetCoordinates._update_coords + self._data.coords._update_coords(coords, indexes) + + def _drop_coords(self, coord_names): + # redirect to DatasetCoordinates._drop_coords + self._data.coords._drop_coords(coord_names) def _merge_raw(self, other, reflexive): """For use with binary arithmetic.""" @@ -200,7 +487,7 @@ def _merge_inplace(self, other): yield self._update_coords(variables, indexes) - def merge(self, other: Coordinates | None) -> Dataset: + def merge(self, other: Mapping[Any, Any] | None) -> Dataset: """Merge two sets of coordinates to create a new Dataset The method implements the logic used for joining coordinates in the @@ -214,8 +501,9 @@ def merge(self, other: Coordinates | None) -> Dataset: Parameters ---------- - other : DatasetCoordinates or DataArrayCoordinates - The coordinates from another dataset or data array. + other : dict-like, optional + A :py:class:`Coordinates` object or any mapping that can be turned + into coordinates. Returns ------- @@ -236,13 +524,163 @@ def merge(self, other: Coordinates | None) -> Dataset: variables=coords, coord_names=coord_names, indexes=indexes ) + def __setitem__(self, key: Hashable, value: Any) -> None: + self.update({key: value}) + + def update(self, other: Mapping[Any, Any]) -> None: + """Update this Coordinates variables with other coordinate variables.""" + + if not len(other): + return + + other_coords: Coordinates + + if isinstance(other, Coordinates): + # Coordinates object: just pass it (default indexes won't be created) + other_coords = other + else: + other_coords = create_coords_with_default_indexes( + getattr(other, "variables", other) + ) + + # Discard original indexed coordinates prior to merge allows to: + # - fail early if the new coordinates don't preserve the integrity of existing + # multi-coordinate indexes + # - drop & replace coordinates without alignment (note: we must keep indexed + # coordinates extracted from the DataArray objects passed as values to + # `other` - if any - as those are still used for aligning the old/new coordinates) + coords_to_align = drop_indexed_coords(set(other_coords) & set(other), self) + + coords, indexes = merge_coords( + [coords_to_align, other_coords], + priority_arg=1, + indexes=coords_to_align.xindexes, + ) + + # special case for PandasMultiIndex: updating only its dimension coordinate + # is still allowed but depreciated. + # It is the only case where we need to actually drop coordinates here (multi-index levels) + # TODO: remove when removing PandasMultiIndex's dimension coordinate. + self._drop_coords(self._names - coords_to_align._names) + + self._update_coords(coords, indexes) + + def assign(self, coords: Mapping | None = None, **coords_kwargs: Any) -> Self: + """Assign new coordinates (and indexes) to a Coordinates object, returning + a new object with all the original coordinates in addition to the new ones. + + Parameters + ---------- + coords : :class:`Coordinates` or mapping of hashable to Any + Mapping from coordinate names to the new values. If a ``Coordinates`` + object is passed, its indexes are assigned in the returned object. + Otherwise, a default (pandas) index is created for each dimension + coordinate found in the mapping. + **coords_kwargs + The keyword arguments form of ``coords``. + One of ``coords`` or ``coords_kwargs`` must be provided. + + Returns + ------- + new_coords : Coordinates + A new Coordinates object with the new coordinates (and indexes) + in addition to all the existing coordinates. + + Examples + -------- + >>> coords = xr.Coordinates() + >>> coords + Coordinates: + *empty* + + >>> coords.assign(x=[1, 2]) + Coordinates: + * x (x) int64 1 2 + + >>> midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]]) + >>> coords.assign(xr.Coordinates.from_pandas_multiindex(midx, "y")) + Coordinates: + * y (y) object MultiIndex + * y_level_0 (y) object 'a' 'a' 'b' 'b' + * y_level_1 (y) int64 0 1 0 1 + + """ + coords = either_dict_or_kwargs(coords, coords_kwargs, "assign") + new_coords = self.copy() + new_coords.update(coords) + return new_coords + + def _overwrite_indexes( + self, + indexes: Mapping[Any, Index], + variables: Mapping[Any, Variable] | None = None, + ) -> Self: + results = self.to_dataset()._overwrite_indexes(indexes, variables) + + # TODO: remove cast once we get rid of DatasetCoordinates + # and DataArrayCoordinates (i.e., Dataset and DataArray encapsulate Coordinates) + return cast(Self, results.coords) + + def _reindex_callback( + self, + aligner: Aligner, + dim_pos_indexers: dict[Hashable, Any], + variables: dict[Hashable, Variable], + indexes: dict[Hashable, Index], + fill_value: Any, + exclude_dims: frozenset[Hashable], + exclude_vars: frozenset[Hashable], + ) -> Self: + """Callback called from ``Aligner`` to create a new reindexed Coordinate.""" + aligned = self.to_dataset()._reindex_callback( + aligner, + dim_pos_indexers, + variables, + indexes, + fill_value, + exclude_dims, + exclude_vars, + ) + + # TODO: remove cast once we get rid of DatasetCoordinates + # and DataArrayCoordinates (i.e., Dataset and DataArray encapsulate Coordinates) + return cast(Self, aligned.coords) + + def _ipython_key_completions_(self): + """Provide method for the key-autocompletions in IPython.""" + return self._data._ipython_key_completions_() + + def copy( + self, + deep: bool = False, + memo: dict[int, Any] | None = None, + ) -> Self: + """Return a copy of this Coordinates object.""" + # do not copy indexes (may corrupt multi-coordinate indexes) + # TODO: disable variables deepcopy? it may also be problematic when they + # encapsulate index objects like pd.Index + variables = { + k: v._copy(deep=deep, memo=memo) for k, v in self.variables.items() + } + + # TODO: getting an error with `self._construct_direct`, possibly because of how + # a subclass implements `_construct_direct`. (This was originally the same + # runtime code, but we switched the type definitions in #8216, which + # necessitates the cast.) + return cast( + Self, + Coordinates._construct_direct( + coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes) + ), + ) + class DatasetCoordinates(Coordinates): - """Dictionary like container for Dataset coordinates. + """Dictionary like container for Dataset coordinates (variables + indexes). - Essentially an immutable dictionary with keys given by the array's - dimensions and the values given by the corresponding xarray.Coordinate - objects. + This collection can be passed directly to the :py:class:`~xarray.Dataset` + and :py:class:`~xarray.DataArray` constructors via their `coords` argument. + This will add both the coordinates variables and their index. """ _data: Dataset @@ -257,7 +695,7 @@ def _names(self) -> set[Hashable]: return self._data._coord_names @property - def dims(self) -> Mapping[Hashable, int]: + def dims(self) -> Frozen[Hashable, int]: return self._data.dims @property @@ -318,21 +756,28 @@ def _update_coords( original_indexes.update(indexes) self._data._indexes = original_indexes - def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None: - """Drops variables in coords, and any associated variables as well.""" + def _drop_coords(self, coord_names): + # should drop indexed coordinates only + for name in coord_names: + del self._data._variables[name] + del self._data._indexes[name] + self._data._coord_names.difference_update(coord_names) + + def _drop_indexed_coords(self, coords_to_drop: set[Hashable]) -> None: assert self._data.xindexes is not None - variables, indexes = drop_coords( - coords, self._data._variables, self._data.xindexes - ) - self._data._coord_names.intersection_update(variables) - self._data._variables = variables - self._data._indexes = indexes + new_coords = drop_indexed_coords(coords_to_drop, self) + for name in self._data._coord_names - new_coords._names: + del self._data._variables[name] + self._data._indexes = dict(new_coords.xindexes) + self._data._coord_names.intersection_update(new_coords._names) def __delitem__(self, key: Hashable) -> None: if key in self: del self._data[key] else: - raise KeyError(f"{key!r} is not a coordinate variable.") + raise KeyError( + f"{key!r} is not in coordinate variables {tuple(self.keys())}" + ) def _ipython_key_completions_(self): """Provide method for the key-autocompletions in IPython.""" @@ -343,11 +788,12 @@ def _ipython_key_completions_(self): ] -class DataArrayCoordinates(Coordinates["T_DataArray"]): - """Dictionary like container for DataArray coordinates. +class DataArrayCoordinates(Coordinates, Generic[T_DataArray]): + """Dictionary like container for DataArray coordinates (variables + indexes). - Essentially a dict with keys given by the array's - dimensions and the values given by corresponding DataArray objects. + This collection can be passed directly to the :py:class:`~xarray.Dataset` + and :py:class:`~xarray.DataArray` constructors via their `coords` argument. + This will add both the coordinates variables and their index. """ _data: T_DataArray @@ -398,13 +844,11 @@ def _update_coords( original_indexes.update(indexes) self._data._indexes = original_indexes - def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None: - """Drops variables in coords, and any associated variables as well.""" - variables, indexes = drop_coords( - coords, self._data._coords, self._data.xindexes - ) - self._data._coords = variables - self._data._indexes = indexes + def _drop_coords(self, coord_names): + # should drop indexed coordinates only + for name in coord_names: + del self._data._coords[name] + del self._data._indexes[name] @property def variables(self): @@ -419,7 +863,9 @@ def to_dataset(self) -> Dataset: def __delitem__(self, key: Hashable) -> None: if key not in self: - raise KeyError(f"{key!r} is not a coordinate variable.") + raise KeyError( + f"{key!r} is not in coordinate variables {tuple(self.keys())}" + ) assert_no_index_corrupted(self._data.xindexes, {key}) del self._data._coords[key] @@ -431,40 +877,51 @@ def _ipython_key_completions_(self): return self._data._ipython_key_completions_() -def drop_coords( - coords_to_drop: set[Hashable], variables, indexes: Indexes -) -> tuple[dict, dict]: - """Drop index variables associated with variables in coords_to_drop.""" - # Only warn when we're dropping the dimension with the multi-indexed coordinate - # If asked to drop a subset of the levels in a multi-index, we raise an error - # later but skip the warning here. - new_variables = dict(variables.copy()) - new_indexes = dict(indexes.copy()) - for key in coords_to_drop & set(indexes): - maybe_midx = indexes[key] - idx_coord_names = set(indexes.get_all_coords(key)) - if ( - isinstance(maybe_midx, PandasMultiIndex) - and key == maybe_midx.dim - and (idx_coord_names - coords_to_drop) - ): - warnings.warn( - f"Updating MultiIndexed coordinate {key!r} would corrupt indices for " - f"other variables: {list(maybe_midx.index.names)!r}. " - f"This will raise an error in the future. Use `.drop_vars({idx_coord_names!r})` before " +def drop_indexed_coords( + coords_to_drop: set[Hashable], coords: Coordinates +) -> Coordinates: + """Drop indexed coordinates associated with coordinates in coords_to_drop. + + This will raise an error in case it corrupts any passed index and its + coordinate variables. + + """ + new_variables = dict(coords.variables) + new_indexes = dict(coords.xindexes) + + for idx, idx_coords in coords.xindexes.group_by_index(): + idx_drop_coords = set(idx_coords) & coords_to_drop + + # special case for pandas multi-index: still allow but deprecate + # dropping only its dimension coordinate. + # TODO: remove when removing PandasMultiIndex's dimension coordinate. + if isinstance(idx, PandasMultiIndex) and idx_drop_coords == {idx.dim}: + idx_drop_coords.update(idx.index.names) + emit_user_level_warning( + f"updating coordinate {idx.dim!r} with a PandasMultiIndex would leave " + f"the multi-index level coordinates {list(idx.index.names)!r} in an inconsistent state. " + f"This will raise an error in the future. Use `.drop_vars({list(idx_coords)!r})` before " "assigning new coordinate values.", FutureWarning, - stacklevel=4, ) - for k in idx_coord_names: - del new_variables[k] - del new_indexes[k] - return new_variables, new_indexes + + elif idx_drop_coords and len(idx_drop_coords) != len(idx_coords): + idx_drop_coords_str = ", ".join(f"{k!r}" for k in idx_drop_coords) + idx_coords_str = ", ".join(f"{k!r}" for k in idx_coords) + raise ValueError( + f"cannot drop or update coordinate(s) {idx_drop_coords_str}, which would corrupt " + f"the following index built from coordinates {idx_coords_str}:\n" + f"{idx}" + ) + + for k in idx_drop_coords: + del new_variables[k] + del new_indexes[k] + + return Coordinates._construct_direct(coords=new_variables, indexes=new_indexes) -def assert_coordinate_consistent( - obj: T_DataArray | Dataset, coords: Mapping[Any, Variable] -) -> None: +def assert_coordinate_consistent(obj: T_Xarray, coords: Mapping[Any, Variable]) -> None: """Make sure the dimension coordinate of obj is consistent with coords. obj: DataArray or Dataset @@ -477,3 +934,81 @@ def assert_coordinate_consistent( f"dimension coordinate {k!r} conflicts between " f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}" ) + + +def create_coords_with_default_indexes( + coords: Mapping[Any, Any], data_vars: DataVars | None = None +) -> Coordinates: + """Returns a Coordinates object from a mapping of coordinates (arbitrary objects). + + Create default (pandas) indexes for each of the input dimension coordinates. + Extract coordinates from each input DataArray. + + """ + # Note: data_vars is needed here only because a pd.MultiIndex object + # can be promoted as coordinates. + # TODO: It won't be relevant anymore when this behavior will be dropped + # in favor of the more explicit ``Coordinates.from_pandas_multiindex()``. + + from xarray.core.dataarray import DataArray + + all_variables = dict(coords) + if data_vars is not None: + all_variables.update(data_vars) + + indexes: dict[Hashable, Index] = {} + variables: dict[Hashable, Variable] = {} + + # promote any pandas multi-index in data_vars as coordinates + coords_promoted: dict[Hashable, Any] = {} + pd_mindex_keys: list[Hashable] = [] + + for k, v in all_variables.items(): + if isinstance(v, pd.MultiIndex): + coords_promoted[k] = v + pd_mindex_keys.append(k) + elif k in coords: + coords_promoted[k] = v + + if pd_mindex_keys: + pd_mindex_keys_fmt = ",".join([f"'{k}'" for k in pd_mindex_keys]) + emit_user_level_warning( + f"the `pandas.MultiIndex` object(s) passed as {pd_mindex_keys_fmt} coordinate(s) or " + "data variable(s) will no longer be implicitly promoted and wrapped into " + "multiple indexed coordinates in the future " + "(i.e., one coordinate for each multi-index level + one dimension coordinate). " + "If you want to keep this behavior, you need to first wrap it explicitly using " + "`mindex_coords = xarray.Coordinates.from_pandas_multiindex(mindex_obj, 'dim')` " + "and pass it as coordinates, e.g., `xarray.Dataset(coords=mindex_coords)`, " + "`dataset.assign_coords(mindex_coords)` or `dataarray.assign_coords(mindex_coords)`.", + FutureWarning, + ) + + dataarray_coords: list[DataArrayCoordinates] = [] + + for name, obj in coords_promoted.items(): + if isinstance(obj, DataArray): + dataarray_coords.append(obj.coords) + + variable = as_variable(obj, name=name) + + if variable.dims == (name,): + idx, idx_vars = create_default_index_implicit(variable, all_variables) + indexes.update({k: idx for k in idx_vars}) + variables.update(idx_vars) + all_variables.update(idx_vars) + else: + variables[name] = variable + + new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes) + + # extract and merge coordinates and indexes from input DataArrays + if dataarray_coords: + prioritized = {k: (v, indexes.get(k, None)) for k, v in variables.items()} + variables, indexes = merge_coordinates_without_align( + dataarray_coords + [new_coords], + prioritized=prioritized, + ) + new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes) + + return new_coords diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index bbaf79e23ba..b892cf595b5 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4,7 +4,15 @@ import warnings from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence from os import PathLike -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + NoReturn, + overload, +) import numpy as np import pandas as pd @@ -23,7 +31,12 @@ from xarray.core.arithmetic import DataArrayArithmetic from xarray.core.common import AbstractArray, DataWithCoords, get_chunksizes from xarray.core.computation import unify_chunks -from xarray.core.coordinates import DataArrayCoordinates, assert_coordinate_consistent +from xarray.core.coordinates import ( + Coordinates, + DataArrayCoordinates, + assert_coordinate_consistent, + create_coords_with_default_indexes, +) from xarray.core.dataset import Dataset from xarray.core.formatting import format_item from xarray.core.indexes import ( @@ -34,8 +47,9 @@ isel_indexes, ) from xarray.core.indexing import is_fancy_indexer, map_index_queries -from xarray.core.merge import PANDAS_TYPES, MergeError, _create_indexes_from_coords +from xarray.core.merge import PANDAS_TYPES, MergeError from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.types import DaCompatible, T_DataArray, T_DataArrayOrSet from xarray.core.utils import ( Default, HybridMappingProxy, @@ -52,6 +66,7 @@ ) from xarray.plot.accessor import DataArrayPlotAccessor from xarray.plot.utils import _get_units_from_attrs +from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: from typing import TypeVar, Union @@ -95,8 +110,9 @@ QueryEngineOptions, QueryParserOptions, ReindexMethodOptions, + Self, SideOptions, - T_DataArray, + T_Chunks, T_Xarray, ) from xarray.core.weighted import DataArrayWeighted @@ -104,9 +120,28 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset]) +def _check_coords_dims(shape, coords, dims): + sizes = dict(zip(dims, shape)) + for k, v in coords.items(): + if any(d not in dims for d in v.dims): + raise ValueError( + f"coordinate {k} has dimensions {v.dims}, but these " + "are not a subset of the DataArray " + f"dimensions {dims}" + ) + + for d, s in v.sizes.items(): + if s != sizes[d]: + raise ValueError( + f"conflicting sizes for dimension {d!r}: " + f"length {sizes[d]} on the data but length {s} on " + f"coordinate {k!r}" + ) + + def _infer_coords_and_dims( shape, coords, dims -) -> tuple[dict[Hashable, Variable], tuple[Hashable, ...]]: +) -> tuple[Mapping[Hashable, Any], tuple[Hashable, ...]]: """All the logic for creating a new DataArray""" if ( @@ -144,40 +179,22 @@ def _infer_coords_and_dims( if not isinstance(d, str): raise TypeError(f"dimension {d} is not a string") - new_coords: dict[Hashable, Variable] = {} - - if utils.is_dict_like(coords): - for k, v in coords.items(): - new_coords[k] = as_variable(v, name=k) - elif coords is not None: - for dim, coord in zip(dims, coords): - var = as_variable(coord, name=dim) - var.dims = (dim,) - new_coords[dim] = var.to_index_variable() + new_coords: Mapping[Hashable, Any] - sizes = dict(zip(dims, shape)) - for k, v in new_coords.items(): - if any(d not in dims for d in v.dims): - raise ValueError( - f"coordinate {k} has dimensions {v.dims}, but these " - "are not a subset of the DataArray " - f"dimensions {dims}" - ) - - for d, s in zip(v.dims, v.shape): - if s != sizes[d]: - raise ValueError( - f"conflicting sizes for dimension {d!r}: " - f"length {sizes[d]} on the data but length {s} on " - f"coordinate {k!r}" - ) + if isinstance(coords, Coordinates): + new_coords = coords + else: + new_coords = {} + if utils.is_dict_like(coords): + for k, v in coords.items(): + new_coords[k] = as_variable(v, name=k) + elif coords is not None: + for dim, coord in zip(dims, coords): + var = as_variable(coord, name=dim) + var.dims = (dim,) + new_coords[dim] = var.to_index_variable() - if k in sizes and v.shape != (sizes[k],): - raise ValueError( - f"coordinate {k!r} is a DataArray dimension, but " - f"it has shape {v.shape!r} rather than expected shape {sizes[k]!r} " - "matching the dimension size" - ) + _check_coords_dims(shape, new_coords, dims) return new_coords, dims @@ -200,13 +217,13 @@ def _check_data_shape(data, coords, dims): return data -class _LocIndexer: +class _LocIndexer(Generic[T_DataArray]): __slots__ = ("data_array",) - def __init__(self, data_array: DataArray): + def __init__(self, data_array: T_DataArray): self.data_array = data_array - def __getitem__(self, key) -> DataArray: + def __getitem__(self, key) -> T_DataArray: if not utils.is_dict_like(key): # expand the indexer so we can handle Ellipsis labels = indexing.expanded_indexer(key, self.data_array.ndim) @@ -266,7 +283,7 @@ class DataArray( or pandas object, attempts are made to use this array's metadata to fill in other unspecified arguments. A view of the array's data is used instead of a copy if possible. - coords : sequence or dict of array_like, optional + coords : sequence or dict of array_like or :py:class:`~xarray.Coordinates`, optional Coordinates (tick labels) to use for indexing along each dimension. The following notations are accepted: @@ -286,6 +303,10 @@ class DataArray( - mapping {coord name: (dimension name, array-like)} - mapping {coord name: (tuple of dimension names, array-like)} + Alternatively, a :py:class:`~xarray.Coordinates` object may be used in + order to explicitly pass indexes (e.g., a multi-index or any custom + Xarray index) or to bypass the creation of a default index for any + :term:`Dimension coordinate` included in that object. dims : Hashable or sequence of Hashable, optional Name(s) of the data dimension(s). Must be either a Hashable (only for 1D data) or a sequence of Hashables with length equal @@ -297,6 +318,11 @@ class DataArray( attrs : dict_like or None, optional Attributes to assign to the new instance. By default, an empty attribute dictionary is initialized. + indexes : py:class:`~xarray.Indexes` or dict-like, optional + For internal use only. For passing indexes objects to the + new DataArray, use the ``coords`` argument instead with a + :py:class:`~xarray.Coordinate` object (both coordinate variables + and indexes will be extracted from the latter). Examples -------- @@ -386,7 +412,7 @@ def __init__( name: Hashable | None = None, attrs: Mapping | None = None, # internal parameters - indexes: dict[Hashable, Index] | None = None, + indexes: Mapping[Any, Index] | None = None, fastpath: bool = False, ) -> None: if fastpath: @@ -395,10 +421,11 @@ def __init__( assert attrs is None assert indexes is not None else: - # TODO: (benbovy - explicit indexes) remove - # once it becomes part of the public interface if indexes is not None: - raise ValueError("Providing explicit indexes is not supported yet") + raise ValueError( + "Explicitly passing indexes via the `indexes` argument is not supported " + "when `fastpath=False`. Use the `coords` argument instead." + ) # try to fill in arguments from data if they weren't supplied if coords is None: @@ -422,28 +449,29 @@ def __init__( data = as_compatible_data(data) coords, dims = _infer_coords_and_dims(data.shape, coords, dims) variable = Variable(dims, data, attrs, fastpath=True) - indexes, coords = _create_indexes_from_coords(coords) + + if not isinstance(coords, Coordinates): + coords = create_coords_with_default_indexes(coords) + indexes = dict(coords.xindexes) + coords = {k: v.copy() for k, v in coords.variables.items()} # These fully describe a DataArray self._variable = variable assert isinstance(coords, dict) self._coords = coords self._name = name - - # TODO(shoyer): document this argument, once it becomes part of the - # public interface. - self._indexes = indexes + self._indexes = indexes # type: ignore[assignment] self._close = None @classmethod def _construct_direct( - cls: type[T_DataArray], + cls, variable: Variable, coords: dict[Any, Variable], name: Hashable, indexes: dict[Hashable, Index], - ) -> T_DataArray: + ) -> Self: """Shortcut around __init__ for internal use when we want to skip costly validation """ @@ -456,12 +484,12 @@ def _construct_direct( return obj def _replace( - self: T_DataArray, + self, variable: Variable | None = None, coords=None, name: Hashable | None | Default = _default, indexes=None, - ) -> T_DataArray: + ) -> Self: if variable is None: variable = self.variable if coords is None: @@ -473,10 +501,10 @@ def _replace( return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True) def _replace_maybe_drop_dims( - self: T_DataArray, + self, variable: Variable, name: Hashable | None | Default = _default, - ) -> T_DataArray: + ) -> Self: if variable.dims == self.dims and variable.shape == self.shape: coords = self._coords.copy() indexes = self._indexes @@ -498,18 +526,18 @@ def _replace_maybe_drop_dims( return self._replace(variable, coords, name, indexes=indexes) def _overwrite_indexes( - self: T_DataArray, + self, indexes: Mapping[Any, Index], - coords: Mapping[Any, Variable] | None = None, + variables: Mapping[Any, Variable] | None = None, drop_coords: list[Hashable] | None = None, rename_dims: Mapping[Any, Any] | None = None, - ) -> T_DataArray: + ) -> Self: """Maybe replace indexes and their corresponding coordinates.""" if not indexes: return self - if coords is None: - coords = {} + if variables is None: + variables = {} if drop_coords is None: drop_coords = [] @@ -518,7 +546,7 @@ def _overwrite_indexes( new_indexes = dict(self._indexes) for name in indexes: - new_coords[name] = coords[name] + new_coords[name] = variables[name] new_indexes[name] = indexes[name] for name in drop_coords: @@ -536,8 +564,8 @@ def _to_temp_dataset(self) -> Dataset: return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False) def _from_temp_dataset( - self: T_DataArray, dataset: Dataset, name: Hashable | None | Default = _default - ) -> T_DataArray: + self, dataset: Dataset, name: Hashable | None | Default = _default + ) -> Self: variable = dataset._variables.pop(_THIS_ARRAY) coords = dataset._variables indexes = dataset._indexes @@ -749,7 +777,7 @@ def to_numpy(self) -> np.ndarray: """ return self.variable.to_numpy() - def as_numpy(self: T_DataArray) -> T_DataArray: + def as_numpy(self) -> Self: """ Coerces wrapped data and coordinates into numpy arrays, returning a DataArray. @@ -804,7 +832,7 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: key = indexing.expanded_indexer(key, self.ndim) return dict(zip(self.dims, key)) - def _getitem_coord(self: T_DataArray, key: Any) -> T_DataArray: + def _getitem_coord(self, key: Any) -> Self: from xarray.core.dataset import _get_virtual_variable try: @@ -815,7 +843,7 @@ def _getitem_coord(self: T_DataArray, key: Any) -> T_DataArray: return self._replace_maybe_drop_dims(var, name=key) - def __getitem__(self: T_DataArray, key: Any) -> T_DataArray: + def __getitem__(self, key: Any) -> Self: if isinstance(key, str): return self._getitem_coord(key) else: @@ -832,6 +860,7 @@ def __setitem__(self, key: Any, value: Any) -> None: obj = self[key] if isinstance(value, DataArray): assert_coordinate_consistent(value, obj.coords.variables) + value = value.variable # DataArray key -> Variable key key = { k: v.variable if isinstance(v, DataArray) else v @@ -884,10 +913,16 @@ def encoding(self) -> dict[Any, Any]: def encoding(self, value: Mapping[Any, Any]) -> None: self.variable.encoding = dict(value) - def reset_encoding(self: T_DataArray) -> T_DataArray: + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: """Return a new DataArray without encoding on the array or any attached coords.""" - ds = self._to_temp_dataset().reset_encoding() + ds = self._to_temp_dataset().drop_encoding() return self._from_temp_dataset(ds) @property @@ -906,36 +941,47 @@ def indexes(self) -> Indexes: @property def xindexes(self) -> Indexes: - """Mapping of xarray Index objects used for label based indexing.""" + """Mapping of :py:class:`~xarray.indexes.Index` objects + used for label based indexing. + """ return Indexes(self._indexes, {k: self._coords[k] for k in self._indexes}) @property def coords(self) -> DataArrayCoordinates: - """Dictionary-like container of coordinate arrays.""" + """Mapping of :py:class:`~xarray.DataArray` objects corresponding to + coordinate variables. + + See Also + -------- + Coordinates + """ return DataArrayCoordinates(self) @overload def reset_coords( - self: T_DataArray, + self, names: Dims = None, + *, drop: Literal[False] = False, ) -> Dataset: ... @overload def reset_coords( - self: T_DataArray, + self, names: Dims = None, *, drop: Literal[True], - ) -> T_DataArray: + ) -> Self: ... + @_deprecate_positional_args("v2023.10.0") def reset_coords( - self: T_DataArray, + self, names: Dims = None, + *, drop: bool = False, - ) -> T_DataArray | Dataset: + ) -> Self | Dataset: """Given names of coordinates, reset them to become variables. Parameters @@ -1047,15 +1093,15 @@ def __dask_postpersist__(self): func, args = self._to_temp_dataset().__dask_postpersist__() return self._dask_finalize, (self.name, func) + args - @staticmethod - def _dask_finalize(results, name, func, *args, **kwargs) -> DataArray: + @classmethod + def _dask_finalize(cls, results, name, func, *args, **kwargs) -> Self: ds = func(results, *args, **kwargs) variable = ds._variables.pop(_THIS_ARRAY) coords = ds._variables indexes = ds._indexes - return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) + return cls(variable, coords, name=name, indexes=indexes, fastpath=True) - def load(self: T_DataArray, **kwargs) -> T_DataArray: + def load(self, **kwargs) -> Self: """Manually trigger loading of this array's data from disk or a remote source into memory and return this array. @@ -1079,7 +1125,7 @@ def load(self: T_DataArray, **kwargs) -> T_DataArray: self._coords = new._coords return self - def compute(self: T_DataArray, **kwargs) -> T_DataArray: + def compute(self, **kwargs) -> Self: """Manually trigger loading of this array's data from disk or a remote source into memory and return a new array. The original is left unaltered. @@ -1101,7 +1147,7 @@ def compute(self: T_DataArray, **kwargs) -> T_DataArray: new = self.copy(deep=False) return new.load(**kwargs) - def persist(self: T_DataArray, **kwargs) -> T_DataArray: + def persist(self, **kwargs) -> Self: """Trigger computation in constituent dask arrays This keeps them as dask arrays but encourages them to keep data in @@ -1120,7 +1166,7 @@ def persist(self: T_DataArray, **kwargs) -> T_DataArray: ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) - def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: + def copy(self, deep: bool = True, data: Any = None) -> Self: """Returns a copy of this array. If `deep=True`, a deep copy is made of the data array. @@ -1191,11 +1237,11 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: return self._copy(deep=deep, data=data) def _copy( - self: T_DataArray, + self, deep: bool = True, data: Any = None, memo: dict[int, Any] | None = None, - ) -> T_DataArray: + ) -> Self: variable = self.variable._copy(deep=deep, data=data, memo=memo) indexes, index_vars = self.xindexes.copy_indexes(deep=deep) @@ -1208,12 +1254,10 @@ def _copy( return self._replace(variable, coords, indexes=indexes) - def __copy__(self: T_DataArray) -> T_DataArray: + def __copy__(self) -> Self: return self._copy(deep=False) - def __deepcopy__( - self: T_DataArray, memo: dict[int, Any] | None = None - ) -> T_DataArray: + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: return self._copy(deep=True, memo=memo) # mutable objects should not be Hashable @@ -1253,15 +1297,11 @@ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: all_variables = [self.variable] + [c.variable for c in self.coords.values()] return get_chunksizes(all_variables) + @_deprecate_positional_args("v2023.10.0") def chunk( - self: T_DataArray, - chunks: ( - int - | Literal["auto"] - | tuple[int, ...] - | tuple[tuple[int, ...], ...] - | Mapping[Any, None | int | tuple[int, ...]] - ) = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + self, + chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + *, name_prefix: str = "xarray-", token: str | None = None, lock: bool = False, @@ -1269,7 +1309,7 @@ def chunk( chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, **chunks_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Coerce this array's data into a dask arrays with the given chunks. If this variable is a non-dask array, it will be converted to dask @@ -1329,8 +1369,13 @@ def chunk( if isinstance(chunks, (float, str, int)): # ignoring type; unclear why it won't accept a Literal into the value. - chunks = dict.fromkeys(self.dims, chunks) # type: ignore + chunks = dict.fromkeys(self.dims, chunks) elif isinstance(chunks, (tuple, list)): + utils.emit_user_level_warning( + "Supplying chunks as dimension-order tuples is deprecated. " + "It will raise an error in the future. Instead use a dict with dimension names as keys.", + category=DeprecationWarning, + ) chunks = dict(zip(self.dims, chunks)) else: chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") @@ -1347,12 +1392,12 @@ def chunk( return self._from_temp_dataset(ds) def isel( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, drop: bool = False, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by selecting indexes along the specified dimension(s). @@ -1438,13 +1483,13 @@ def isel( return self._replace(variable=variable, coords=coords, indexes=indexes) def sel( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance=None, drop: bool = False, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by selecting index labels along the specified dimension(s). @@ -1557,10 +1602,10 @@ def sel( return self._from_temp_dataset(ds) def head( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by the the first `n` values along the specified dimension(s). Default `n` = 5 @@ -1600,10 +1645,10 @@ def head( return self._from_temp_dataset(ds) def tail( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by the the last `n` values along the specified dimension(s). Default `n` = 5 @@ -1647,10 +1692,10 @@ def tail( return self._from_temp_dataset(ds) def thin( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by each `n` value along the specified dimension(s). @@ -1696,11 +1741,13 @@ def thin( ds = self._to_temp_dataset().thin(indexers, **indexers_kwargs) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def broadcast_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_DataArrayOrSet, + *, exclude: Iterable[Hashable] | None = None, - ) -> T_DataArray: + ) -> Self: """Broadcast this DataArray against another Dataset or DataArray. This is equivalent to xr.broadcast(other, self)[1] @@ -1770,12 +1817,10 @@ def broadcast_like( dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) - return _broadcast_helper( - cast("T_DataArray", args[1]), exclude, dims_map, common_coords - ) + return _broadcast_helper(args[1], exclude, dims_map, common_coords) def _reindex_callback( - self: T_DataArray, + self, aligner: alignment.Aligner, dim_pos_indexers: dict[Hashable, Any], variables: dict[Hashable, Variable], @@ -1783,7 +1828,7 @@ def _reindex_callback( fill_value: Any, exclude_dims: frozenset[Hashable], exclude_vars: frozenset[Hashable], - ) -> T_DataArray: + ) -> Self: """Callback called from ``Aligner`` to create a new reindexed DataArray.""" if isinstance(fill_value, dict): @@ -1803,18 +1848,26 @@ def _reindex_callback( exclude_dims, exclude_vars, ) - return self._from_temp_dataset(reindexed) + da = self._from_temp_dataset(reindexed) + da.encoding = self.encoding + + return da + + @_deprecate_positional_args("v2023.10.0") def reindex_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_DataArrayOrSet, + *, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value=dtypes.NA, - ) -> T_DataArray: - """Conform this object onto the indexes of another object, filling in - missing values with ``fill_value``. The default fill value is NaN. + ) -> Self: + """ + Conform this object onto the indexes of another object, for indexes which the + objects share. Missing values are filled with ``fill_value``. The default fill + value is NaN. Parameters ---------- @@ -1902,20 +1955,14 @@ def reindex_like( * x (x) int64 40 30 20 10 * y (y) int64 90 80 70 - Reindexing with the other array having coordinates which the source array doesn't have: + Reindexing with the other array having additional coordinates: - >>> data = np.arange(12).reshape(4, 3) - >>> da1 = xr.DataArray( - ... data=data, - ... dims=["x", "y"], - ... coords={"x": [10, 20, 30, 40], "y": [70, 80, 90]}, - ... ) - >>> da2 = xr.DataArray( + >>> da3 = xr.DataArray( ... data=data, ... dims=["x", "y"], ... coords={"x": [20, 10, 29, 39], "y": [70, 80, 90]}, ... ) - >>> da1.reindex_like(da2) + >>> da1.reindex_like(da3) array([[ 3., 4., 5.], [ 0., 1., 2.], @@ -1927,7 +1974,7 @@ def reindex_like( Filling missing values with the previous valid index with respect to the coordinates' value: - >>> da1.reindex_like(da2, method="ffill") + >>> da1.reindex_like(da3, method="ffill") array([[3, 4, 5], [0, 1, 2], @@ -1939,7 +1986,7 @@ def reindex_like( Filling missing values while tolerating specified error for inexact matches: - >>> da1.reindex_like(da2, method="ffill", tolerance=5) + >>> da1.reindex_like(da3, method="ffill", tolerance=5) array([[ 3., 4., 5.], [ 0., 1., 2.], @@ -1951,7 +1998,7 @@ def reindex_like( Filling missing values with manually specified values: - >>> da1.reindex_like(da2, fill_value=19) + >>> da1.reindex_like(da3, fill_value=19) array([[ 3, 4, 5], [ 0, 1, 2], @@ -1961,9 +2008,28 @@ def reindex_like( * x (x) int64 20 10 29 39 * y (y) int64 70 80 90 + Note that unlike ``broadcast_like``, ``reindex_like`` doesn't create new dimensions: + + >>> da1.sel(x=20) + + array([3, 4, 5]) + Coordinates: + x int64 20 + * y (y) int64 70 80 90 + + ...so ``b`` in not added here: + + >>> da1.sel(x=20).reindex_like(da1) + + array([3, 4, 5]) + Coordinates: + x int64 20 + * y (y) int64 70 80 90 + See Also -------- DataArray.reindex + DataArray.broadcast_like align """ return alignment.reindex_like( @@ -1975,15 +2041,17 @@ def reindex_like( fill_value=fill_value, ) + @_deprecate_positional_args("v2023.10.0") def reindex( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, + *, method: ReindexMethodOptions = None, tolerance: float | Iterable[float] | None = None, copy: bool = True, fill_value=dtypes.NA, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Conform this object onto the indexes of another object, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -2067,13 +2135,13 @@ def reindex( ) def interp( - self: T_DataArray, + self, coords: Mapping[Any, Any] | None = None, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, **coords_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Interpolate a DataArray onto new coordinates Performs univariate or multivariate interpolation of a DataArray onto @@ -2198,8 +2266,7 @@ def interp( """ if self.dtype.kind not in "uifc": raise TypeError( - "interp only works for a numeric type array. " - "Given {}.".format(self.dtype) + "interp only works for a numeric type array. " f"Given {self.dtype}." ) ds = self._to_temp_dataset().interp( coords, @@ -2211,12 +2278,12 @@ def interp( return self._from_temp_dataset(ds) def interp_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_Xarray, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, - ) -> T_DataArray: + ) -> Self: """Interpolate this object onto the coordinates of another object, filling out of range values with NaN. @@ -2326,21 +2393,18 @@ def interp_like( """ if self.dtype.kind not in "uifc": raise TypeError( - "interp only works for a numeric type array. " - "Given {}.".format(self.dtype) + "interp only works for a numeric type array. " f"Given {self.dtype}." ) ds = self._to_temp_dataset().interp_like( other, method=method, kwargs=kwargs, assume_sorted=assume_sorted ) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def rename( self, new_name_or_name_dict: Hashable | Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> DataArray: + ) -> Self: """Returns a new DataArray with renamed coordinates, dimensions or a new name. Parameters @@ -2381,10 +2445,10 @@ def rename( return self._replace(name=new_name_or_name_dict) def swap_dims( - self: T_DataArray, + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs, - ) -> T_DataArray: + ) -> Self: """Returns a new DataArray with swapped dimensions. Parameters @@ -2439,14 +2503,12 @@ def swap_dims( ds = self._to_temp_dataset().swap_dims(dims_dict) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def expand_dims( self, dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None, axis: None | int | Sequence[int] = None, **dim_kwargs: Any, - ) -> DataArray: + ) -> Self: """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a view into the underlying array, not a copy. @@ -2529,20 +2591,18 @@ def expand_dims( raise ValueError("dims should not contain duplicate values.") dim = dict.fromkeys(dim, 1) elif dim is not None and not isinstance(dim, Mapping): - dim = {cast(Hashable, dim): 1} + dim = {dim: 1} dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims") ds = self._to_temp_dataset().expand_dims(dim, axis) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def set_index( self, indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None, append: bool = False, **indexes_kwargs: Hashable | Sequence[Hashable], - ) -> DataArray: + ) -> Self: """Set DataArray (multi-)indexes using one or more existing coordinates. @@ -2600,13 +2660,11 @@ def set_index( ds = self._to_temp_dataset().set_index(indexes, append=append, **indexes_kwargs) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def reset_index( self, dims_or_levels: Hashable | Sequence[Hashable], drop: bool = False, - ) -> DataArray: + ) -> Self: """Reset the specified index(es) or multi-index level(s). This legacy method is specific to pandas (multi-)indexes and @@ -2640,11 +2698,11 @@ def reset_index( return self._from_temp_dataset(ds) def set_xindex( - self: T_DataArray, + self, coord_names: str | Sequence[Hashable], index_cls: type[Index] | None = None, **options, - ) -> T_DataArray: + ) -> Self: """Set a new, Xarray-compatible index from one or more existing coordinate(s). @@ -2669,10 +2727,10 @@ def set_xindex( return self._from_temp_dataset(ds) def reorder_levels( - self: T_DataArray, + self, dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None, **dim_order_kwargs: Sequence[int | Hashable], - ) -> T_DataArray: + ) -> Self: """Rearrange index levels using input order. Parameters @@ -2695,12 +2753,12 @@ def reorder_levels( return self._from_temp_dataset(ds) def stack( - self: T_DataArray, + self, dimensions: Mapping[Any, Sequence[Hashable]] | None = None, create_index: bool | None = True, index_cls: type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable], - ) -> T_DataArray: + ) -> Self: """ Stack any number of existing dimensions into a single new dimension. @@ -2767,14 +2825,14 @@ def stack( ) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved + @_deprecate_positional_args("v2023.10.0") def unstack( self, dim: Dims = None, + *, fill_value: Any = dtypes.NA, sparse: bool = False, - ) -> DataArray: + ) -> Self: """ Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. @@ -2829,7 +2887,7 @@ def unstack( -------- DataArray.stack """ - ds = self._to_temp_dataset().unstack(dim, fill_value, sparse) + ds = self._to_temp_dataset().unstack(dim, fill_value=fill_value, sparse=sparse) return self._from_temp_dataset(ds) def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Dataset: @@ -2868,9 +2926,9 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data b (x) int64 0 3 >>> stacked = data.to_stacked_array("z", ["x"]) >>> stacked.indexes["z"] - MultiIndex([('a', 0.0), - ('a', 1.0), - ('a', 2.0), + MultiIndex([('a', 0), + ('a', 1), + ('a', 2), ('b', nan)], name='z') >>> roundtripped = stacked.to_unstacked_dataset(dim="z") @@ -2898,11 +2956,11 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data return Dataset(data_dict) def transpose( - self: T_DataArray, + self, *dims: Hashable, transpose_coords: bool = True, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> T_DataArray: + ) -> Self: """Return a new DataArray object with transposed dimensions. Parameters @@ -2948,17 +3006,15 @@ def transpose( return self._replace(variable) @property - def T(self: T_DataArray) -> T_DataArray: + def T(self) -> Self: return self.transpose() - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def drop_vars( self, names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> DataArray: + ) -> Self: """Returns an array with dropped variables. Parameters @@ -3019,11 +3075,11 @@ def drop_vars( return self._from_temp_dataset(ds) def drop_indexes( - self: T_DataArray, + self, coord_names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_DataArray: + ) -> Self: """Drop the indexes assigned to the given coordinates. Parameters @@ -3044,13 +3100,13 @@ def drop_indexes( return self._from_temp_dataset(ds) def drop( - self: T_DataArray, + self, labels: Mapping[Any, Any] | None = None, dim: Hashable | None = None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_DataArray: + ) -> Self: """Backward compatible method based on `drop_vars` and `drop_sel` Using either `drop_vars` or `drop_sel` is encouraged @@ -3064,12 +3120,12 @@ def drop( return self._from_temp_dataset(ds) def drop_sel( - self: T_DataArray, + self, labels: Mapping[Any, Any] | None = None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_DataArray: + ) -> Self: """Drop index labels from this DataArray. Parameters @@ -3132,8 +3188,8 @@ def drop_sel( return self._from_temp_dataset(ds) def drop_isel( - self: T_DataArray, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs - ) -> T_DataArray: + self, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs + ) -> Self: """Drop index positions from this DataArray. Parameters @@ -3182,12 +3238,14 @@ def drop_isel( dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs) return self._from_temp_dataset(dataset) + @_deprecate_positional_args("v2023.10.0") def dropna( - self: T_DataArray, + self, dim: Hashable, + *, how: Literal["any", "all"] = "any", thresh: int | None = None, - ) -> T_DataArray: + ) -> Self: """Returns a new array with dropped labels for missing values along the provided dimension. @@ -3258,7 +3316,7 @@ def dropna( ds = self._to_temp_dataset().dropna(dim, how=how, thresh=thresh) return self._from_temp_dataset(ds) - def fillna(self: T_DataArray, value: Any) -> T_DataArray: + def fillna(self, value: Any) -> Self: """Fill missing values in this object. This operation follows the normal broadcasting and alignment rules that @@ -3321,7 +3379,7 @@ def fillna(self: T_DataArray, value: Any) -> T_DataArray: return out def interpolate_na( - self: T_DataArray, + self, dim: Hashable | None = None, method: InterpOptions = "linear", limit: int | None = None, @@ -3337,7 +3395,7 @@ def interpolate_na( ) = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Fill in NaNs by interpolating according to different methods. Parameters @@ -3345,7 +3403,7 @@ def interpolate_na( dim : Hashable or None, optional Specifies the dimension along which to interpolate. method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ - "barycentric", "krog", "pchip", "spline", "akima"}, default: "linear" + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" String indicating which method to use for interpolation: - 'linear': linear interpolation. Additional keyword @@ -3354,7 +3412,7 @@ def interpolate_na( are passed to :py:func:`scipy.interpolate.interp1d`. If ``method='polynomial'``, the ``order`` keyword argument must also be provided. - - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. use_coordinate : bool or str, default: True @@ -3444,9 +3502,7 @@ def interpolate_na( **kwargs, ) - def ffill( - self: T_DataArray, dim: Hashable, limit: int | None = None - ) -> T_DataArray: + def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -3530,9 +3586,7 @@ def ffill( return ffill(self, dim, limit=limit) - def bfill( - self: T_DataArray, dim: Hashable, limit: int | None = None - ) -> T_DataArray: + def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward *Requires bottleneck.* @@ -3616,7 +3670,7 @@ def bfill( return bfill(self, dim, limit=limit) - def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray: + def combine_first(self, other: Self) -> Self: """Combine two DataArray objects, with union of coordinates. This operation follows the normal broadcasting and alignment rules of @@ -3635,7 +3689,7 @@ def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray: return ops.fillna(self, other, join="outer") def reduce( - self: T_DataArray, + self, func: Callable[..., Any], dim: Dims = None, *, @@ -3643,7 +3697,7 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Reduce this array by applying `func` along some dimension(s). Parameters @@ -3681,7 +3735,7 @@ def reduce( var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs) return self._replace_maybe_drop_dims(var) - def to_pandas(self) -> DataArray | pd.Series | pd.DataFrame: + def to_pandas(self) -> Self | pd.Series | pd.DataFrame: """Convert this array into a pandas object with the same shape. The type of the returned object depends on the number of DataArray @@ -3998,6 +4052,7 @@ def to_zarr( mode: Literal["w", "w-", "a", "r+", None] = None, synchronizer=None, group: str | None = None, + *, encoding: Mapping | None = None, compute: Literal[True] = True, consolidated: bool | None = None, @@ -4038,6 +4093,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, @@ -4235,7 +4291,7 @@ def to_dict( return d @classmethod - def from_dict(cls: type[T_DataArray], d: Mapping[str, Any]) -> T_DataArray: + def from_dict(cls, d: Mapping[str, Any]) -> Self: """Convert a dictionary into an xarray.DataArray Parameters @@ -4289,7 +4345,7 @@ def from_dict(cls: type[T_DataArray], d: Mapping[str, Any]) -> T_DataArray: except KeyError as e: raise ValueError( "cannot convert dict when coords are missing the key " - "'{dims_data}'".format(dims_data=str(e.args[0])) + f"'{str(e.args[0])}'" ) try: data = d["data"] @@ -4327,7 +4383,7 @@ def from_series(cls, series: pd.Series, sparse: bool = False) -> DataArray: temp_name = "__temporary_name" df = pd.DataFrame({temp_name: series}) ds = Dataset.from_dataframe(df, sparse=sparse) - result = cast(DataArray, ds[temp_name]) + result = ds[temp_name] result.name = series.name return result @@ -4352,7 +4408,7 @@ def to_cdms2(self) -> cdms2_Variable: return to_cdms2(self) @classmethod - def from_cdms2(cls, variable: cdms2_Variable) -> DataArray: + def from_cdms2(cls, variable: cdms2_Variable) -> Self: """Convert a cdms2.Variable into an xarray.DataArray .. deprecated:: 2023.06.0 @@ -4379,13 +4435,13 @@ def to_iris(self) -> iris_Cube: return to_iris(self) @classmethod - def from_iris(cls, cube: iris_Cube) -> DataArray: + def from_iris(cls, cube: iris_Cube) -> Self: """Convert a iris.cube.Cube into an xarray.DataArray""" from xarray.convert import from_iris return from_iris(cube) - def _all_compat(self: T_DataArray, other: T_DataArray, compat_str: str) -> bool: + def _all_compat(self, other: Self, compat_str: str) -> bool: """Helper function for equals, broadcast_equals, and identical""" def compat(x, y): @@ -4395,7 +4451,7 @@ def compat(x, y): self, other ) - def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool: + def broadcast_equals(self, other: Self) -> bool: """Two DataArrays are broadcast equal if they are equal after broadcasting them against each other such that they have the same dimensions. @@ -4429,7 +4485,7 @@ def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool: [2, 2]]) Dimensions without coordinates: X, Y - .equals returns True if two DataArrays have the same values, dimensions, and coordinates. .broadcast_equals returns True if the results of broadcasting two DataArrays against eachother have the same values, dimensions, and coordinates. + .equals returns True if two DataArrays have the same values, dimensions, and coordinates. .broadcast_equals returns True if the results of broadcasting two DataArrays against each other have the same values, dimensions, and coordinates. >>> a.equals(b) False @@ -4444,7 +4500,7 @@ def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool: except (TypeError, AttributeError): return False - def equals(self: T_DataArray, other: T_DataArray) -> bool: + def equals(self, other: Self) -> bool: """True if two DataArrays have the same dimensions, coordinates and values; otherwise False. @@ -4506,7 +4562,7 @@ def equals(self: T_DataArray, other: T_DataArray) -> bool: except (TypeError, AttributeError): return False - def identical(self: T_DataArray, other: T_DataArray) -> bool: + def identical(self, other: Self) -> bool: """Like equals, but also checks the array name and attributes, and attributes on all coordinates. @@ -4573,19 +4629,19 @@ def _result_name(self, other: Any = None) -> Hashable | None: else: return None - def __array_wrap__(self: T_DataArray, obj, context=None) -> T_DataArray: + def __array_wrap__(self, obj, context=None) -> Self: new_var = self.variable.__array_wrap__(obj, context) return self._replace(new_var) - def __matmul__(self: T_DataArray, obj: T_DataArray) -> T_DataArray: + def __matmul__(self, obj: T_Xarray) -> T_Xarray: return self.dot(obj) - def __rmatmul__(self: T_DataArray, other: T_DataArray) -> T_DataArray: + def __rmatmul__(self, other: T_Xarray) -> T_Xarray: # currently somewhat duplicative, as only other DataArrays are # compatible with matmul return computation.dot(other, self) - def _unary_op(self: T_DataArray, f: Callable, *args, **kwargs) -> T_DataArray: + def _unary_op(self, f: Callable, *args, **kwargs) -> Self: keep_attrs = kwargs.pop("keep_attrs", None) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) @@ -4601,32 +4657,29 @@ def _unary_op(self: T_DataArray, f: Callable, *args, **kwargs) -> T_DataArray: return da def _binary_op( - self: T_DataArray, - other: Any, - f: Callable, - reflexive: bool = False, - ) -> T_DataArray: + self, other: DaCompatible, f: Callable, reflexive: bool = False + ) -> Self: from xarray.core.groupby import GroupBy if isinstance(other, (Dataset, GroupBy)): return NotImplemented if isinstance(other, DataArray): align_type = OPTIONS["arithmetic_join"] - self, other = align(self, other, join=align_type, copy=False) # type: ignore - other_variable = getattr(other, "variable", other) + self, other = align(self, other, join=align_type, copy=False) + other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other) other_coords = getattr(other, "coords", None) variable = ( - f(self.variable, other_variable) + f(self.variable, other_variable_or_arraylike) if not reflexive - else f(other_variable, self.variable) + else f(other_variable_or_arraylike, self.variable) ) coords, indexes = self.coords._merge_raw(other_coords, reflexive) name = self._result_name(other) return self._replace(variable, coords, name, indexes=indexes) - def _inplace_binary_op(self: T_DataArray, other: Any, f: Callable) -> T_DataArray: + def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self: from xarray.core.groupby import GroupBy if isinstance(other, GroupBy): @@ -4685,12 +4738,14 @@ def _title_for_slice(self, truncate: int = 50) -> str: return title + @_deprecate_positional_args("v2023.10.0") def diff( - self: T_DataArray, + self, dim: Hashable, n: int = 1, + *, label: Literal["upper", "lower"] = "upper", - ) -> T_DataArray: + ) -> Self: """Calculate the n-th order discrete difference along given axis. Parameters @@ -4736,11 +4791,11 @@ def diff( return self._from_temp_dataset(ds) def shift( - self: T_DataArray, + self, shifts: Mapping[Any, int] | None = None, fill_value: Any = dtypes.NA, **shifts_kwargs: int, - ) -> T_DataArray: + ) -> Self: """Shift this DataArray by an offset along one or more dimensions. Only the data is moved; coordinates stay in place. This is consistent @@ -4786,11 +4841,11 @@ def shift( return self._replace(variable=variable) def roll( - self: T_DataArray, + self, shifts: Mapping[Hashable, int] | None = None, roll_coords: bool = False, **shifts_kwargs: int, - ) -> T_DataArray: + ) -> Self: """Roll this array by an offset along one or more dimensions. Unlike shift, roll treats the given dimensions as periodic, so will not @@ -4835,7 +4890,7 @@ def roll( return self._from_temp_dataset(ds) @property - def real(self: T_DataArray) -> T_DataArray: + def real(self) -> Self: """ The real part of the array. @@ -4846,7 +4901,7 @@ def real(self: T_DataArray) -> T_DataArray: return self._replace(self.variable.real) @property - def imag(self: T_DataArray) -> T_DataArray: + def imag(self) -> Self: """ The imaginary part of the array. @@ -4857,10 +4912,10 @@ def imag(self: T_DataArray) -> T_DataArray: return self._replace(self.variable.imag) def dot( - self: T_DataArray, - other: T_DataArray, + self, + other: T_Xarray, dims: Dims = None, - ) -> T_DataArray: + ) -> T_Xarray: """Perform dot product of two DataArrays along their shared dims. Equivalent to taking taking tensordot over all shared dims. @@ -4910,13 +4965,14 @@ def dot( return computation.dot(self, other, dims=dims) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def sortby( self, - variables: Hashable | DataArray | Sequence[Hashable | DataArray], + variables: Hashable + | DataArray + | Sequence[Hashable | DataArray] + | Callable[[Self], Hashable | DataArray | Sequence[Hashable | DataArray]], ascending: bool = True, - ) -> DataArray: + ) -> Self: """Sort object by labels or values (along an axis). Sorts the dataarray, either along specified dimensions, @@ -4935,9 +4991,10 @@ def sortby( Parameters ---------- - variables : Hashable, DataArray, or sequence of Hashable or DataArray - 1D DataArray objects or name(s) of 1D variable(s) in - coords whose values are used to sort this array. + variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are + used to sort this array. If a callable, the callable is passed this object, + and the result is used as the value for cond. ascending : bool, default: True Whether to sort by ascending or descending order. @@ -4957,34 +5014,47 @@ def sortby( Examples -------- >>> da = xr.DataArray( - ... np.random.rand(5), + ... np.arange(5, 0, -1), ... coords=[pd.date_range("1/1/2000", periods=5)], ... dims="time", ... ) >>> da - array([0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ]) + array([5, 4, 3, 2, 1]) Coordinates: * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05 >>> da.sortby(da) - array([0.4236548 , 0.54488318, 0.5488135 , 0.60276338, 0.71518937]) + array([1, 2, 3, 4, 5]) Coordinates: - * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-02 + * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01 + + >>> da.sortby(lambda x: x) + + array([1, 2, 3, 4, 5]) + Coordinates: + * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01 """ + # We need to convert the callable here rather than pass it through to the + # dataset method, since otherwise the dataset method would try to call the + # callable with the dataset as the object + if callable(variables): + variables = variables(self) ds = self._to_temp_dataset().sortby(variables, ascending=ascending) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def quantile( - self: T_DataArray, + self, q: ArrayLike, dim: Dims = None, + *, method: QuantileMethods = "linear", keep_attrs: bool | None = None, skipna: bool | None = None, interpolation: QuantileMethods | None = None, - ) -> T_DataArray: + ) -> Self: """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements. @@ -5000,15 +5070,15 @@ def quantile( desired quantile lies between two data points. The options sorted by their R type as summarized in the H&F paper [1]_ are: - 1. "inverted_cdf" (*) - 2. "averaged_inverted_cdf" (*) - 3. "closest_observation" (*) - 4. "interpolated_inverted_cdf" (*) - 5. "hazen" (*) - 6. "weibull" (*) + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" 7. "linear" (default) - 8. "median_unbiased" (*) - 9. "normal_unbiased" (*) + 8. "median_unbiased" + 9. "normal_unbiased" The first three methods are discontiuous. The following discontinuous variations of the default "linear" (7.) option are also available: @@ -5022,8 +5092,6 @@ def quantile( was previously called "interpolation", renamed in accordance with numpy version 1.22.0. - (*) These methods require numpy version 1.22 or newer. - keep_attrs : bool or None, optional If True, the dataset's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -5096,12 +5164,14 @@ def quantile( ) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def rank( - self: T_DataArray, + self, dim: Hashable, + *, pct: bool = False, keep_attrs: bool | None = None, - ) -> T_DataArray: + ) -> Self: """Ranks the data. Equal values are assigned a rank that is the average of the ranks that @@ -5141,11 +5211,11 @@ def rank( return self._from_temp_dataset(ds) def differentiate( - self: T_DataArray, + self, coord: Hashable, edge_order: Literal[1, 2] = 1, datetime_unit: DatetimeUnitOptions = None, - ) -> T_DataArray: + ) -> Self: """ Differentiate the array with the second order accurate central differences. @@ -5159,9 +5229,10 @@ def differentiate( The coordinate to be used to compute the gradient. edge_order : {1, 2}, default: 1 N-th order accurate differences at the boundaries. - datetime_unit : {"Y", "M", "W", "D", "h", "m", "s", "ms", \ + datetime_unit : {"W", "D", "h", "m", "s", "ms", \ "us", "ns", "ps", "fs", "as", None}, optional - Unit to compute gradient. Only valid for datetime coordinate. + Unit to compute gradient. Only valid for datetime coordinate. "Y" and "M" are not available as + datetime_unit. Returns ------- @@ -5202,13 +5273,11 @@ def differentiate( ds = self._to_temp_dataset().differentiate(coord, edge_order, datetime_unit) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def integrate( self, coord: Hashable | Sequence[Hashable] = None, datetime_unit: DatetimeUnitOptions = None, - ) -> DataArray: + ) -> Self: """Integrate along the given coordinate using the trapezoidal rule. .. note:: @@ -5258,13 +5327,11 @@ def integrate( ds = self._to_temp_dataset().integrate(coord, datetime_unit) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def cumulative_integrate( self, coord: Hashable | Sequence[Hashable] = None, datetime_unit: DatetimeUnitOptions = None, - ) -> DataArray: + ) -> Self: """Integrate cumulatively along the given coordinate using the trapezoidal rule. .. note:: @@ -5322,7 +5389,7 @@ def cumulative_integrate( ds = self._to_temp_dataset().cumulative_integrate(coord, datetime_unit) return self._from_temp_dataset(ds) - def unify_chunks(self) -> DataArray: + def unify_chunks(self) -> Self: """Unify chunk size along all chunked dimensions of this DataArray. Returns @@ -5507,7 +5574,7 @@ def polyfit( ) def pad( - self: T_DataArray, + self, pad_width: Mapping[Any, int | tuple[int, int]] | None = None, mode: PadModeOptions = "constant", stat_length: int @@ -5522,7 +5589,7 @@ def pad( reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, **pad_width_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Pad this array along one or more dimensions. .. warning:: @@ -5674,13 +5741,15 @@ def pad( ) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def idxmin( self, dim: Hashable | None = None, + *, skipna: bool | None = None, fill_value: Any = dtypes.NA, keep_attrs: bool | None = None, - ) -> DataArray: + ) -> Self: """Return the coordinate label of the minimum value along a dimension. Returns a new `DataArray` named after the dimension with the values of @@ -5739,8 +5808,8 @@ def idxmin( >>> array = xr.DataArray( ... [ ... [2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], ... ], ... dims=["y", "x"], ... coords={"y": [-1, 0, 1], "x": np.arange(5.0) ** 2}, @@ -5770,13 +5839,15 @@ def idxmin( keep_attrs=keep_attrs, ) + @_deprecate_positional_args("v2023.10.0") def idxmax( self, dim: Hashable = None, + *, skipna: bool | None = None, fill_value: Any = dtypes.NA, keep_attrs: bool | None = None, - ) -> DataArray: + ) -> Self: """Return the coordinate label of the maximum value along a dimension. Returns a new `DataArray` named after the dimension with the values of @@ -5835,8 +5906,8 @@ def idxmax( >>> array = xr.DataArray( ... [ ... [2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], ... ], ... dims=["y", "x"], ... coords={"y": [-1, 0, 1], "x": np.arange(5.0) ** 2}, @@ -5866,15 +5937,15 @@ def idxmax( keep_attrs=keep_attrs, ) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved + @_deprecate_positional_args("v2023.10.0") def argmin( self, dim: Dims = None, + *, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, - ) -> DataArray | dict[Hashable, DataArray]: + ) -> Self | dict[Hashable, Self]: """Index or indices of the minimum of the DataArray over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of DataArrays, @@ -5968,15 +6039,15 @@ def argmin( else: return self._replace_maybe_drop_dims(result) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved + @_deprecate_positional_args("v2023.10.0") def argmax( self, dim: Dims = None, + *, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, - ) -> DataArray | dict[Hashable, DataArray]: + ) -> Self | dict[Hashable, Self]: """Index or indices of the maximum of the DataArray over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of DataArrays, @@ -6229,7 +6300,7 @@ def curvefit( >>> def exp_decay(t, time_constant, amplitude): ... return np.exp(-t / time_constant) * amplitude ... - >>> t = np.linspace(0, 10, 11) + >>> t = np.arange(11) >>> da = xr.DataArray( ... np.stack( ... [ @@ -6254,7 +6325,7 @@ def curvefit( 0.00910995]]) Coordinates: * x (x) int64 0 1 2 - * time (time) float64 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 + * time (time) int64 0 1 2 3 4 5 6 7 8 9 10 Fit the exponential decay function to the data along the ``time`` dimension: @@ -6317,11 +6388,13 @@ def curvefit( kwargs=kwargs, ) + @_deprecate_positional_args("v2023.10.0") def drop_duplicates( - self: T_DataArray, + self, dim: Hashable | Iterable[Hashable], + *, keep: Literal["first", "last", False] = "first", - ) -> T_DataArray: + ) -> Self: """Returns a new DataArray with duplicate dimension values removed. Parameters @@ -6403,7 +6476,7 @@ def convert_calendar( align_on: str | None = None, missing: Any | None = None, use_cftime: bool | None = None, - ) -> DataArray: + ) -> Self: """Convert the DataArray to another calendar. Only converts the individual timestamps, does not modify any data except @@ -6523,7 +6596,7 @@ def interp_calendar( self, target: pd.DatetimeIndex | CFTimeIndex | DataArray, dim: str = "time", - ) -> DataArray: + ) -> Self: """Interpolates the DataArray to another calendar based on decimal year measure. Each timestamp in `source` and `target` are first converted to their decimal diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7bd92ea32a0..248f516b61b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -24,6 +24,13 @@ from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload import numpy as np + +# remove once numpy 2.0 is the oldest supported version +try: + from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import RankWarning + import pandas as pd from xarray.coding.calendar_ops import convert_calendar, interp_calendar @@ -50,7 +57,12 @@ get_chunksizes, ) from xarray.core.computation import unify_chunks -from xarray.core.coordinates import DatasetCoordinates, assert_coordinate_consistent +from xarray.core.coordinates import ( + Coordinates, + DatasetCoordinates, + assert_coordinate_consistent, + create_coords_with_default_indexes, +) from xarray.core.daskmanager import DaskManager from xarray.core.duck_array_ops import datetime_to_numeric from xarray.core.indexes import ( @@ -70,21 +82,25 @@ dataset_merge_method, dataset_update_method, merge_coordinates_without_align, - merge_data_and_coords, + merge_core, ) from xarray.core.missing import get_clean_interp_index from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.parallelcompat import ( - get_chunked_array_type, - guess_chunkmanager, -) +from xarray.core.parallelcompat import get_chunked_array_type, guess_chunkmanager from xarray.core.pycompat import ( array_type, is_chunked_array, is_duck_array, is_duck_dask_array, ) -from xarray.core.types import QuantileMethods, T_Dataset +from xarray.core.types import ( + QuantileMethods, + Self, + T_ChunkDim, + T_Chunks, + T_DataArrayOrSet, + T_Dataset, +) from xarray.core.utils import ( Default, Frozen, @@ -107,16 +123,16 @@ calculate_dimensions, ) from xarray.plot.accessor import DatasetPlotAccessor +from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: from numpy.typing import ArrayLike from xarray.backends import AbstractDataStore, ZarrStore from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes - from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.groupby import DatasetGroupBy - from xarray.core.merge import CoercibleMapping + from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult from xarray.core.parallelcompat import ChunkManagerEntrypoint from xarray.core.resample import DatasetResample from xarray.core.rolling import DatasetCoarsen, DatasetRolling @@ -125,9 +141,11 @@ CoarsenBoundaryOptions, CombineAttrsOptions, CompatOptions, + DataVars, DatetimeLike, DatetimeUnitOptions, Dims, + DsCompatible, ErrorOptions, ErrorOptionsWithWarn, InterpOptions, @@ -208,11 +226,6 @@ def _get_virtual_variable( return ref_name, var_name, virtual_var -def _assert_empty(args: tuple, msg: str = "%s") -> None: - if args: - raise ValueError(msg % args) - - def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint): """ Return map from each dim to chunk sizes, accounting for backend's preferred chunks. @@ -296,7 +309,8 @@ def _maybe_chunk( # when rechunking by different amounts, make sure dask names change # by providing chunks as an input to tokenize. # subtle bugs result otherwise. see GH3350 - token2 = tokenize(name, token if token else var._data, chunks) + # we use str() for speed, and use the name for the final array name on the next line + token2 = tokenize(token if token else var._data, str(chunks)) name2 = f"{name_prefix}{name}-{token2}" from_array_kwargs = utils.consolidate_dask_from_array_kwargs( @@ -400,6 +414,26 @@ def _initialize_feasible(lb, ub): return param_defaults, bounds_defaults +def merge_data_and_coords(data_vars: DataVars, coords) -> _MergeResult: + """Used in Dataset.__init__.""" + if isinstance(coords, Coordinates): + coords = coords.copy() + else: + coords = create_coords_with_default_indexes(coords, data_vars) + + # exclude coords from alignment (all variables in a Coordinates object should + # already be aligned together) and use coordinates' indexes to align data_vars + return merge_core( + [data_vars, coords], + compat="broadcast_equals", + join="outer", + explicit_coords=tuple(coords), + indexes=coords.xindexes, + priority_arg=1, + skip_align_args=[1], + ) + + class DataVariables(Mapping[Any, "DataArray"]): __slots__ = ("_dataset",) @@ -414,14 +448,16 @@ def __iter__(self) -> Iterator[Hashable]: ) def __len__(self) -> int: - return len(self._dataset._variables) - len(self._dataset._coord_names) + length = len(self._dataset._variables) - len(self._dataset._coord_names) + assert length >= 0, "something is wrong with Dataset._coord_names" + return length def __contains__(self, key: Hashable) -> bool: return key in self._dataset._variables and key not in self._dataset._coord_names def __getitem__(self, key: Hashable) -> DataArray: if key not in self._dataset._coord_names: - return cast("DataArray", self._dataset[key]) + return self._dataset[key] raise KeyError(key) def __repr__(self) -> str: @@ -491,8 +527,11 @@ class Dataset( Dataset implements the mapping interface with keys given by variable names and values given by DataArray objects for each variable name. - One dimensional variables with name equal to their dimension are - index coordinates used for label based indexing. + By default, pandas indexes are created for one dimensional variables with + name equal to their dimension (i.e., :term:`Dimension coordinate`) so those + variables can be readily used as coordinates for label based indexing. When a + :py:class:`~xarray.Coordinates` object is passed to ``coords``, any existing + index(es) built from those coordinates will be added to the Dataset. To load data from a file or file-like object, use the `open_dataset` function. @@ -513,22 +552,21 @@ class Dataset( - mapping {var name: (dimension name, array-like)} - mapping {var name: (tuple of dimension names, array-like)} - mapping {dimension name: array-like} - (it will be automatically moved to coords, see below) + (if array-like is not a scalar it will be automatically moved to coords, + see below) Each dimension must have the same length in all variables in which it appears. - coords : dict-like, optional - Another mapping in similar form as the `data_vars` argument, - except the each item is saved on the dataset as a "coordinate". + coords : :py:class:`~xarray.Coordinates` or dict-like, optional + A :py:class:`~xarray.Coordinates` object or another mapping in + similar form as the `data_vars` argument, except that each item + is saved on the dataset as a "coordinate". These variables have an associated meaning: they describe constant/fixed/independent quantities, unlike the varying/measured/dependent quantities that belong in - `variables`. Coordinates values may be given by 1-dimensional - arrays or scalars, in which case `dims` do not need to be - supplied: 1D arrays will be assumed to give index values along - the dimension with the same name. + `variables`. - The following notations are accepted: + The following notations are accepted for arbitrary mappings: - mapping {coord name: DataArray} - mapping {coord name: Variable} @@ -538,8 +576,16 @@ class Dataset( (the dimension name is implicitly set to be the same as the coord name) - The last notation implies that the coord name is the same as - the dimension name. + The last notation implies either that the coordinate value is a scalar + or that it is a 1-dimensional array and the coord name is the same as + the dimension name (i.e., a :term:`Dimension coordinate`). In the latter + case, the 1-dimensional array will be assumed to give index values + along the dimension with the same name. + + Alternatively, a :py:class:`~xarray.Coordinates` object may be used in + order to explicitly pass indexes (e.g., a multi-index or any custom + Xarray index) or to bypass the creation of a default index for any + :term:`Dimension coordinate` included in that object. attrs : dict-like, optional Global attributes to save on this dataset. @@ -602,6 +648,7 @@ class Dataset( precipitation float64 8.326 Attributes: description: Weather related data. + """ _attrs: dict[Hashable, Any] | None @@ -629,12 +676,10 @@ def __init__( self, # could make a VariableArgs to use more generally, and refine these # categories - data_vars: Mapping[Any, Any] | None = None, + data_vars: DataVars | None = None, coords: Mapping[Any, Any] | None = None, attrs: Mapping[Any, Any] | None = None, ) -> None: - # TODO(shoyer): expose indexes as a public argument in __init__ - if data_vars is None: data_vars = {} if coords is None: @@ -650,7 +695,7 @@ def __init__( coords = coords._variables variables, coord_names, dims, indexes, _ = merge_data_and_coords( - data_vars, coords, compat="broadcast_equals" + data_vars, coords ) self._attrs = dict(attrs) if attrs is not None else None @@ -661,8 +706,13 @@ def __init__( self._dims = dims self._indexes = indexes + # TODO: dirty workaround for mypy 1.5 error with inherited DatasetOpsMixin vs. Mapping + # related to https://github.com/python/mypy/issues/9319? + def __eq__(self, other: DsCompatible) -> Self: # type: ignore[override] + return super().__eq__(other) + @classmethod - def load_store(cls: type[T_Dataset], store, decoder=None) -> T_Dataset: + def load_store(cls, store, decoder=None) -> Self: """Create a new dataset from the contents of a backends.*DataStore object """ @@ -706,10 +756,16 @@ def encoding(self) -> dict[Any, Any]: def encoding(self, value: Mapping[Any, Any]) -> None: self._encoding = dict(value) - def reset_encoding(self: T_Dataset) -> T_Dataset: + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: """Return a new Dataset without encoding on the dataset or any of its variables/coords.""" - variables = {k: v.reset_encoding() for k, v in self.variables.items()} + variables = {k: v.drop_encoding() for k, v in self.variables.items()} return self._replace(variables=variables, encoding={}) @property @@ -762,7 +818,7 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]: } ) - def load(self: T_Dataset, **kwargs) -> T_Dataset: + def load(self, **kwargs) -> Self: """Manually trigger loading and/or computation of this dataset's data from disk or a remote source into memory and return this dataset. Unlike compute, the original dataset is modified and returned. @@ -862,7 +918,7 @@ def __dask_postcompute__(self): def __dask_postpersist__(self): return self._dask_postpersist, () - def _dask_postcompute(self: T_Dataset, results: Iterable[Variable]) -> T_Dataset: + def _dask_postcompute(self, results: Iterable[Variable]) -> Self: import dask variables = {} @@ -885,8 +941,8 @@ def _dask_postcompute(self: T_Dataset, results: Iterable[Variable]) -> T_Dataset ) def _dask_postpersist( - self: T_Dataset, dsk: Mapping, *, rename: Mapping[str, str] | None = None - ) -> T_Dataset: + self, dsk: Mapping, *, rename: Mapping[str, str] | None = None + ) -> Self: from dask import is_dask_collection from dask.highlevelgraph import HighLevelGraph from dask.optimization import cull @@ -935,7 +991,7 @@ def _dask_postpersist( self._close, ) - def compute(self: T_Dataset, **kwargs) -> T_Dataset: + def compute(self, **kwargs) -> Self: """Manually trigger loading and/or computation of this dataset's data from disk or a remote source into memory and return a new dataset. Unlike load, the original dataset is left unaltered. @@ -957,7 +1013,7 @@ def compute(self: T_Dataset, **kwargs) -> T_Dataset: new = self.copy(deep=False) return new.load(**kwargs) - def _persist_inplace(self: T_Dataset, **kwargs) -> T_Dataset: + def _persist_inplace(self, **kwargs) -> Self: """Persist all Dask arrays in memory""" # access .data to coerce everything to numpy or dask arrays lazy_data = { @@ -974,7 +1030,7 @@ def _persist_inplace(self: T_Dataset, **kwargs) -> T_Dataset: return self - def persist(self: T_Dataset, **kwargs) -> T_Dataset: + def persist(self, **kwargs) -> Self: """Trigger computation, keeping data as dask arrays This operation can be used to trigger computation on underlying dask @@ -997,7 +1053,7 @@ def persist(self: T_Dataset, **kwargs) -> T_Dataset: @classmethod def _construct_direct( - cls: type[T_Dataset], + cls, variables: dict[Any, Variable], coord_names: set[Hashable], dims: dict[Any, int] | None = None, @@ -1005,7 +1061,7 @@ def _construct_direct( indexes: dict[Any, Index] | None = None, encoding: dict | None = None, close: Callable[[], None] | None = None, - ) -> T_Dataset: + ) -> Self: """Shortcut around __init__ for internal use when we want to skip costly validation """ @@ -1024,7 +1080,7 @@ def _construct_direct( return obj def _replace( - self: T_Dataset, + self, variables: dict[Hashable, Variable] | None = None, coord_names: set[Hashable] | None = None, dims: dict[Any, int] | None = None, @@ -1032,7 +1088,7 @@ def _replace( indexes: dict[Hashable, Index] | None = None, encoding: dict | None | Default = _default, inplace: bool = False, - ) -> T_Dataset: + ) -> Self: """Fastpath constructor for internal use. Returns an object with optionally with replaced attributes. @@ -1074,13 +1130,13 @@ def _replace( return obj def _replace_with_new_dims( - self: T_Dataset, + self, variables: dict[Hashable, Variable], coord_names: set | None = None, attrs: dict[Hashable, Any] | None | Default = _default, indexes: dict[Hashable, Index] | None = None, inplace: bool = False, - ) -> T_Dataset: + ) -> Self: """Replace variables with recalculated dimensions.""" dims = calculate_dimensions(variables) return self._replace( @@ -1088,13 +1144,13 @@ def _replace_with_new_dims( ) def _replace_vars_and_dims( - self: T_Dataset, + self, variables: dict[Hashable, Variable], coord_names: set | None = None, dims: dict[Hashable, int] | None = None, attrs: dict[Hashable, Any] | None | Default = _default, inplace: bool = False, - ) -> T_Dataset: + ) -> Self: """Deprecated version of _replace_with_new_dims(). Unlike _replace_with_new_dims(), this method always recalculates @@ -1107,13 +1163,13 @@ def _replace_vars_and_dims( ) def _overwrite_indexes( - self: T_Dataset, + self, indexes: Mapping[Hashable, Index], variables: Mapping[Hashable, Variable] | None = None, drop_variables: list[Hashable] | None = None, drop_indexes: list[Hashable] | None = None, rename_dims: Mapping[Hashable, Hashable] | None = None, - ) -> T_Dataset: + ) -> Self: """Maybe replace indexes. This function may do a lot more depending on index query @@ -1180,9 +1236,7 @@ def _overwrite_indexes( else: return replaced - def copy( - self: T_Dataset, deep: bool = False, data: Mapping[Any, ArrayLike] | None = None - ) -> T_Dataset: + def copy(self, deep: bool = False, data: DataVars | None = None) -> Self: """Returns a copy of this dataset. If `deep=True`, a deep copy is made of each of the component variables. @@ -1282,11 +1336,11 @@ def copy( return self._copy(deep=deep, data=data) def _copy( - self: T_Dataset, + self, deep: bool = False, - data: Mapping[Any, ArrayLike] | None = None, + data: DataVars | None = None, memo: dict[int, Any] | None = None, - ) -> T_Dataset: + ) -> Self: if data is None: data = {} elif not utils.is_dict_like(data): @@ -1299,13 +1353,13 @@ def _copy( if keys_not_in_vars: raise ValueError( "Data must only contain variables in original " - "dataset. Extra variables: {}".format(keys_not_in_vars) + f"dataset. Extra variables: {keys_not_in_vars}" ) keys_missing_from_data = var_keys - data_keys if keys_missing_from_data: raise ValueError( "Data must contain all variables in original " - "dataset. Data is missing {}".format(keys_missing_from_data) + f"dataset. Data is missing {keys_missing_from_data}" ) indexes, index_vars = self.xindexes.copy_indexes(deep=deep) @@ -1324,13 +1378,13 @@ def _copy( return self._replace(variables, indexes=indexes, attrs=attrs, encoding=encoding) - def __copy__(self: T_Dataset) -> T_Dataset: + def __copy__(self) -> Self: return self._copy(deep=False) - def __deepcopy__(self: T_Dataset, memo: dict[int, Any] | None = None) -> T_Dataset: + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: return self._copy(deep=True, memo=memo) - def as_numpy(self: T_Dataset) -> T_Dataset: + def as_numpy(self) -> Self: """ Coerces wrapped data and coordinates into numpy arrays, returning a Dataset. @@ -1342,7 +1396,7 @@ def as_numpy(self: T_Dataset) -> T_Dataset: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - def _copy_listed(self: T_Dataset, names: Iterable[Hashable]) -> T_Dataset: + def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. """ @@ -1436,13 +1490,20 @@ def __bool__(self) -> bool: def __iter__(self) -> Iterator[Hashable]: return iter(self.data_vars) - def __array__(self, dtype=None): - raise TypeError( - "cannot directly convert an xarray.Dataset into a " - "numpy array. Instead, create an xarray.DataArray " - "first, either with indexing on the Dataset or by " - "invoking the `to_array()` method." - ) + if TYPE_CHECKING: + # needed because __getattr__ is returning Any and otherwise + # this class counts as part of the SupportsArray Protocol + __array__ = None # type: ignore[var-annotated,unused-ignore] + + else: + + def __array__(self, dtype=None): + raise TypeError( + "cannot directly convert an xarray.Dataset into a " + "numpy array. Instead, create an xarray.DataArray " + "first, either with indexing on the Dataset or by " + "invoking the `to_array()` method." + ) @property def nbytes(self) -> int: @@ -1455,7 +1516,7 @@ def nbytes(self) -> int: return sum(v.nbytes for v in self.variables.values()) @property - def loc(self: T_Dataset) -> _LocIndexer[T_Dataset]: + def loc(self) -> _LocIndexer[Self]: """Attribute for location based indexing. Only supports __getitem__, and only when the key is a dict of the form {dim: labels}. """ @@ -1467,12 +1528,12 @@ def __getitem__(self, key: Hashable) -> DataArray: # Mapping is Iterable @overload - def __getitem__(self: T_Dataset, key: Iterable[Hashable]) -> T_Dataset: + def __getitem__(self, key: Iterable[Hashable]) -> Self: ... def __getitem__( - self: T_Dataset, key: Mapping[Any, Any] | Hashable | Iterable[Hashable] - ) -> T_Dataset | DataArray: + self, key: Mapping[Any, Any] | Hashable | Iterable[Hashable] + ) -> Self | DataArray: """Access variables or coordinates of this dataset as a :py:class:`~xarray.DataArray` or a subset of variables or a indexed dataset. @@ -1637,7 +1698,7 @@ def __delitem__(self, key: Hashable) -> None: # https://github.com/python/mypy/issues/4266 __hash__ = None # type: ignore[assignment] - def _all_compat(self, other: Dataset, compat_str: str) -> bool: + def _all_compat(self, other: Self, compat_str: str) -> bool: """Helper function for equals and identical""" # some stores (e.g., scipy) do not seem to preserve order, so don't @@ -1649,7 +1710,7 @@ def compat(x: Variable, y: Variable) -> bool: self._variables, other._variables, compat=compat ) - def broadcast_equals(self, other: Dataset) -> bool: + def broadcast_equals(self, other: Self) -> bool: """Two Datasets are broadcast equal if they are equal after broadcasting all variables against each other. @@ -1657,17 +1718,66 @@ def broadcast_equals(self, other: Dataset) -> bool: the other dataset can still be broadcast equal if the the non-scalar variable is a constant. + Examples + -------- + + # 2D array with shape (1, 3) + + >>> data = np.array([[1, 2, 3]]) + >>> a = xr.Dataset( + ... {"variable_name": (("space", "time"), data)}, + ... coords={"space": [0], "time": [0, 1, 2]}, + ... ) + >>> a + + Dimensions: (space: 1, time: 3) + Coordinates: + * space (space) int64 0 + * time (time) int64 0 1 2 + Data variables: + variable_name (space, time) int64 1 2 3 + + # 2D array with shape (3, 1) + + >>> data = np.array([[1], [2], [3]]) + >>> b = xr.Dataset( + ... {"variable_name": (("time", "space"), data)}, + ... coords={"time": [0, 1, 2], "space": [0]}, + ... ) + >>> b + + Dimensions: (time: 3, space: 1) + Coordinates: + * time (time) int64 0 1 2 + * space (space) int64 0 + Data variables: + variable_name (time, space) int64 1 2 3 + + .equals returns True if two Datasets have the same values, dimensions, and coordinates. .broadcast_equals returns True if the + results of broadcasting two Datasets against each other have the same values, dimensions, and coordinates. + + >>> a.equals(b) + False + + >>> a.broadcast_equals(b) + True + + >>> a2, b2 = xr.broadcast(a, b) + >>> a2.equals(b2) + True + See Also -------- Dataset.equals Dataset.identical + Dataset.broadcast """ try: return self._all_compat(other, "broadcast_equals") except (TypeError, AttributeError): return False - def equals(self, other: Dataset) -> bool: + def equals(self, other: Self) -> bool: """Two Datasets are equal if they have matching variables and coordinates, all of which are equal. @@ -1677,6 +1787,67 @@ def equals(self, other: Dataset) -> bool: This method is necessary because `v1 == v2` for ``Dataset`` does element-wise comparisons (like numpy.ndarrays). + Examples + -------- + + # 2D array with shape (1, 3) + + >>> data = np.array([[1, 2, 3]]) + >>> dataset1 = xr.Dataset( + ... {"variable_name": (("space", "time"), data)}, + ... coords={"space": [0], "time": [0, 1, 2]}, + ... ) + >>> dataset1 + + Dimensions: (space: 1, time: 3) + Coordinates: + * space (space) int64 0 + * time (time) int64 0 1 2 + Data variables: + variable_name (space, time) int64 1 2 3 + + # 2D array with shape (3, 1) + + >>> data = np.array([[1], [2], [3]]) + >>> dataset2 = xr.Dataset( + ... {"variable_name": (("time", "space"), data)}, + ... coords={"time": [0, 1, 2], "space": [0]}, + ... ) + >>> dataset2 + + Dimensions: (time: 3, space: 1) + Coordinates: + * time (time) int64 0 1 2 + * space (space) int64 0 + Data variables: + variable_name (time, space) int64 1 2 3 + >>> dataset1.equals(dataset2) + False + + >>> dataset1.broadcast_equals(dataset2) + True + + .equals returns True if two Datasets have the same values, dimensions, and coordinates. .broadcast_equals returns True if the + results of broadcasting two Datasets against each other have the same values, dimensions, and coordinates. + + Similar for missing values too: + + >>> ds1 = xr.Dataset( + ... { + ... "temperature": (["x", "y"], [[1, np.nan], [3, 4]]), + ... }, + ... coords={"x": [0, 1], "y": [0, 1]}, + ... ) + + >>> ds2 = xr.Dataset( + ... { + ... "temperature": (["x", "y"], [[1, np.nan], [3, 4]]), + ... }, + ... coords={"x": [0, 1], "y": [0, 1]}, + ... ) + >>> ds1.equals(ds2) + True + See Also -------- Dataset.broadcast_equals @@ -1687,10 +1858,70 @@ def equals(self, other: Dataset) -> bool: except (TypeError, AttributeError): return False - def identical(self, other: Dataset) -> bool: + def identical(self, other: Self) -> bool: """Like equals, but also checks all dataset attributes and the attributes on all variables and coordinates. + Example + ------- + + >>> a = xr.Dataset( + ... {"Width": ("X", [1, 2, 3])}, + ... coords={"X": [1, 2, 3]}, + ... attrs={"units": "m"}, + ... ) + >>> b = xr.Dataset( + ... {"Width": ("X", [1, 2, 3])}, + ... coords={"X": [1, 2, 3]}, + ... attrs={"units": "m"}, + ... ) + >>> c = xr.Dataset( + ... {"Width": ("X", [1, 2, 3])}, + ... coords={"X": [1, 2, 3]}, + ... attrs={"units": "ft"}, + ... ) + >>> a + + Dimensions: (X: 3) + Coordinates: + * X (X) int64 1 2 3 + Data variables: + Width (X) int64 1 2 3 + Attributes: + units: m + + >>> b + + Dimensions: (X: 3) + Coordinates: + * X (X) int64 1 2 3 + Data variables: + Width (X) int64 1 2 3 + Attributes: + units: m + + >>> c + + Dimensions: (X: 3) + Coordinates: + * X (X) int64 1 2 3 + Data variables: + Width (X) int64 1 2 3 + Attributes: + units: ft + + >>> a.equals(b) + True + + >>> a.identical(b) + True + + >>> a.equals(c) + True + + >>> a.identical(c) + False + See Also -------- Dataset.broadcast_equals @@ -1719,13 +1950,19 @@ def indexes(self) -> Indexes[pd.Index]: @property def xindexes(self) -> Indexes[Index]: - """Mapping of xarray Index objects used for label based indexing.""" + """Mapping of :py:class:`~xarray.indexes.Index` objects + used for label based indexing. + """ return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) @property def coords(self) -> DatasetCoordinates: - """Dictionary of xarray.DataArray objects corresponding to coordinate - variables + """Mapping of :py:class:`~xarray.DataArray` objects corresponding to + coordinate variables. + + See Also + -------- + Coordinates """ return DatasetCoordinates(self) @@ -1734,7 +1971,7 @@ def data_vars(self) -> DataVariables: """Dictionary of DataArray objects corresponding to data variables""" return DataVariables(self) - def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Dataset: + def set_coords(self, names: Hashable | Iterable[Hashable]) -> Self: """Given names of one or more variables, set them as coordinates Parameters @@ -1792,10 +2029,10 @@ def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Datas return obj def reset_coords( - self: T_Dataset, + self, names: Dims = None, drop: bool = False, - ) -> T_Dataset: + ) -> Self: """Given names of coordinates, reset them to become variables Parameters @@ -2003,7 +2240,9 @@ def to_netcdf( Nested dictionary with variable names as keys and dictionaries of variable specific encodings as values, e.g., ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1, - "zlib": True}, ...}`` + "zlib": True}, ...}``. + If ``encoding`` is specified the original encoding of the variables of + the dataset is ignored. The `h5netcdf` engine supports both the NetCDF4-style compression encoding parameters ``{"zlib": True, "complevel": 9}`` and the h5py @@ -2062,6 +2301,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: Literal[True] = True, consolidated: bool | None = None, append_dim: Hashable | None = None, @@ -2069,6 +2309,7 @@ def to_zarr( safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, + write_empty_chunks: bool | None = None, chunkmanager_store_kwargs: dict[str, Any] | None = None, ) -> ZarrStore: ... @@ -2091,6 +2332,7 @@ def to_zarr( safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, + write_empty_chunks: bool | None = None, chunkmanager_store_kwargs: dict[str, Any] | None = None, ) -> Delayed: ... @@ -2103,6 +2345,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, @@ -2110,6 +2353,7 @@ def to_zarr( safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, + write_empty_chunks: bool | None = None, chunkmanager_store_kwargs: dict[str, Any] | None = None, ) -> ZarrStore | Delayed: """Write dataset contents to a zarr group. @@ -2199,6 +2443,17 @@ def to_zarr( The desired zarr spec version to target (currently 2 or 3). The default of None will attempt to determine the zarr version from ``store`` when possible, otherwise defaulting to 2. + write_empty_chunks : bool or None, optional + If True, all chunks will be stored regardless of their + contents. If False, each chunk is compared to the array's fill value + prior to storing. If a chunk is uniformly equal to the fill value, then + that chunk is not be stored, and the store entry for that chunk's key + is deleted. This setting enables sparser storage, as only chunks with + non-fill-value data are stored, at the expense of overhead associated + with checking the data of each chunk. If None (default) fall back to + specification(s) in ``encoding`` or Zarr defaults. A ``ValueError`` + will be raised if the value of this (if not None) differs with + ``encoding``. chunkmanager_store_kwargs : dict, optional Additional keyword arguments passed on to the `ChunkManager.store` method used to store chunked arrays. For example for a dask array additional kwargs will be passed eventually to @@ -2248,6 +2503,7 @@ def to_zarr( region=region, safe_chunks=safe_chunks, zarr_version=zarr_version, + write_empty_chunks=write_empty_chunks, chunkmanager_store_kwargs=chunkmanager_store_kwargs, ) @@ -2329,18 +2585,16 @@ def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]: return get_chunksizes(self.variables.values()) def chunk( - self: T_Dataset, - chunks: ( - int | Literal["auto"] | Mapping[Any, None | int | str | tuple[int, ...]] - ) = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + self, + chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str | None = None, lock: bool = False, inline_array: bool = False, chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, - **chunks_kwargs: None | int | str | tuple[int, ...], - ) -> T_Dataset: + **chunks_kwargs: T_ChunkDim, + ) -> Self: """Coerce all arrays in this dataset into dask arrays with the given chunks. @@ -2390,23 +2644,29 @@ def chunk( xarray.unify_chunks dask.array.from_array """ - if chunks is None and chunks_kwargs is None: + if chunks is None and not chunks_kwargs: warnings.warn( "None value for 'chunks' is deprecated. " "It will raise an error in the future. Use instead '{}'", - category=FutureWarning, + category=DeprecationWarning, ) chunks = {} - - if isinstance(chunks, (Number, str, int)): - chunks = dict.fromkeys(self.dims, chunks) + chunks_mapping: Mapping[Any, Any] + if not isinstance(chunks, Mapping) and chunks is not None: + if isinstance(chunks, (tuple, list)): + utils.emit_user_level_warning( + "Supplying chunks as dimension-order tuples is deprecated. " + "It will raise an error in the future. Instead use a dict with dimensions as keys.", + category=DeprecationWarning, + ) + chunks_mapping = dict.fromkeys(self.dims, chunks) else: - chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") + chunks_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") - bad_dims = chunks.keys() - self.dims.keys() + bad_dims = chunks_mapping.keys() - self.dims.keys() if bad_dims: raise ValueError( - f"some chunks keys are not dimensions on this object: {bad_dims}" + f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.dims)}" ) chunkmanager = guess_chunkmanager(chunked_array_type) @@ -2417,7 +2677,7 @@ def chunk( k: _maybe_chunk( k, v, - chunks, + chunks_mapping, token, lock, name_prefix, @@ -2469,7 +2729,7 @@ def _validate_indexers( if v.ndim > 1: raise IndexError( "Unlabeled multi-dimensional array cannot be " - "used for indexing: {}".format(k) + f"used for indexing: {k}" ) yield k, v @@ -2509,9 +2769,9 @@ def _get_indexers_coords_and_indexes(self, indexers): if v.dtype.kind == "b": if v.ndim != 1: # we only support 1-d boolean array raise ValueError( - "{:d}d-boolean array is used for indexing along " - "dimension {!r}, but only 1d boolean arrays are " - "supported.".format(v.ndim, k) + f"{v.ndim:d}d-boolean array is used for indexing along " + f"dimension {k!r}, but only 1d boolean arrays are " + "supported." ) # Make sure in case of boolean DataArray, its # coordinate also should be indexed. @@ -2534,12 +2794,12 @@ def _get_indexers_coords_and_indexes(self, indexers): return attached_coords, attached_indexes def isel( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, drop: bool = False, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed along the specified dimension(s). @@ -2682,12 +2942,12 @@ def isel( ) def _isel_fancy( - self: T_Dataset, + self, indexers: Mapping[Any, Any], *, drop: bool, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> T_Dataset: + ) -> Self: valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) variables: dict[Hashable, Variable] = {} @@ -2723,13 +2983,13 @@ def _isel_fancy( return self._replace_with_new_dims(variables, coord_names, indexes=indexes) def sel( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, drop: bool = False, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed by tick labels along the specified dimension(s). @@ -2809,10 +3069,10 @@ def sel( return result._overwrite_indexes(*query_results.as_tuple()[1:]) def head( - self: T_Dataset, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with the first `n` values of each array for the specified dimension(s). @@ -2899,10 +3159,10 @@ def head( return self.isel(indexers_slices) def tail( - self: T_Dataset, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with the last `n` values of each array for the specified dimension(s). @@ -2990,10 +3250,10 @@ def tail( return self.isel(indexers_slices) def thin( - self: T_Dataset, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed along every `n`-th value for the specified dimension(s) @@ -3075,10 +3335,10 @@ def thin( return self.isel(indexers_slices) def broadcast_like( - self: T_Dataset, - other: Dataset | DataArray, + self, + other: T_DataArrayOrSet, exclude: Iterable[Hashable] | None = None, - ) -> T_Dataset: + ) -> Self: """Broadcast this DataArray against another Dataset or DataArray. This is equivalent to xr.broadcast(other, self)[1] @@ -3098,9 +3358,7 @@ def broadcast_like( dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) - return _broadcast_helper( - cast("T_Dataset", args[1]), exclude, dims_map, common_coords - ) + return _broadcast_helper(args[1], exclude, dims_map, common_coords) def _reindex_callback( self, @@ -3111,7 +3369,7 @@ def _reindex_callback( fill_value: Any, exclude_dims: frozenset[Hashable], exclude_vars: frozenset[Hashable], - ) -> Dataset: + ) -> Self: """Callback called from ``Aligner`` to create a new reindexed Dataset.""" new_variables = variables.copy() @@ -3159,18 +3417,22 @@ def _reindex_callback( new_variables, new_coord_names, indexes=new_indexes ) + reindexed.encoding = self.encoding + return reindexed def reindex_like( - self: T_Dataset, - other: Dataset | DataArray, + self, + other: T_Xarray, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = xrdtypes.NA, - ) -> T_Dataset: - """Conform this object onto the indexes of another object, filling in - missing values with ``fill_value``. The default fill value is NaN. + ) -> Self: + """ + Conform this object onto the indexes of another object, for indexes which the + objects share. Missing values are filled with ``fill_value``. The default fill + value is NaN. Parameters ---------- @@ -3216,7 +3478,9 @@ def reindex_like( See Also -------- Dataset.reindex + DataArray.reindex_like align + """ return alignment.reindex_like( self, @@ -3228,14 +3492,14 @@ def reindex_like( ) def reindex( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = xrdtypes.NA, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Conform this object onto a new set of indexes, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -3444,7 +3708,7 @@ def reindex( ) def _reindex( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, @@ -3452,7 +3716,7 @@ def _reindex( fill_value: Any = xrdtypes.NA, sparse: bool = False, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """ Same as reindex but supports sparse option. """ @@ -3468,14 +3732,14 @@ def _reindex( ) def interp( - self: T_Dataset, + self, coords: Mapping[Any, Any] | None = None, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, method_non_numeric: str = "nearest", **coords_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Interpolate a Dataset onto new coordinates Performs univariate or multivariate interpolation of a Dataset onto @@ -3495,7 +3759,7 @@ def interp( If DataArrays are passed as new coordinates, their dimensions are used for the broadcasting. Missing values are skipped. method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ - "barycentric", "krog", "pchip", "spline", "akima"}, default: "linear" + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" String indicating which method to use for interpolation: - 'linear': linear interpolation. Additional keyword @@ -3504,7 +3768,7 @@ def interp( are passed to :py:func:`scipy.interpolate.interp1d`. If ``method='polynomial'``, the ``order`` keyword argument must also be provided. - - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. assume_sorted : bool, default: False @@ -3651,7 +3915,7 @@ def _validate_interp_indexer(x, new_x): "coordinate, the coordinates to " "interpolate to must be either datetime " "strings or datetimes. " - "Instead got\n{}".format(new_x) + f"Instead got\n{new_x}" ) return x, new_x @@ -3748,12 +4012,12 @@ def _validate_interp_indexer(x, new_x): def interp_like( self, - other: Dataset | DataArray, + other: T_Xarray, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, method_non_numeric: str = "nearest", - ) -> Dataset: + ) -> Self: """Interpolate this object onto the coordinates of another object, filling the out of range values with NaN. @@ -3771,7 +4035,7 @@ def interp_like( names to an 1d array-like, which provides coordinates upon which to index the variables in this dataset. Missing values are skipped. method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ - "barycentric", "krog", "pchip", "spline", "akima"}, default: "linear" + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" String indicating which method to use for interpolation: - 'linear': linear interpolation. Additional keyword @@ -3780,7 +4044,7 @@ def interp_like( are passed to :py:func:`scipy.interpolate.interp1d`. If ``method='polynomial'``, the ``order`` keyword argument must also be provided. - - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. assume_sorted : bool, default: False @@ -3903,10 +4167,10 @@ def _rename_all( return variables, coord_names, dims, indexes def _rename( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Also used internally by DataArray so that the warning (if any) is raised at the right stack level. """ @@ -3921,6 +4185,9 @@ def _rename( create_dim_coord = False new_k = name_dict[k] + if k == new_k: + continue # Same name, nothing to do + if k in self.dims and new_k in self._coord_names: coord_dims = self._variables[name_dict[k]].dims if coord_dims == (k,): @@ -3945,10 +4212,10 @@ def _rename( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def rename( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed variables, coordinates and dimensions. Parameters @@ -3975,10 +4242,10 @@ def rename( return self._rename(name_dict=name_dict, **names) def rename_dims( - self: T_Dataset, + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed dimensions only. Parameters @@ -4007,8 +4274,8 @@ def rename_dims( for k, v in dims_dict.items(): if k not in self.dims: raise ValueError( - f"cannot rename {k!r} because it is not a " - "dimension in this dataset" + f"cannot rename {k!r} because it is not found " + f"in the dimensions of this dataset {tuple(self.dims)}" ) if v in self.dims or v in self: raise ValueError( @@ -4022,10 +4289,10 @@ def rename_dims( return self._replace(variables, coord_names, dims=sizes, indexes=indexes) def rename_vars( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed variables including coordinates Parameters @@ -4062,8 +4329,8 @@ def rename_vars( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def swap_dims( - self: T_Dataset, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs - ) -> T_Dataset: + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs + ) -> Self: """Returns a new object with swapped dimensions. Parameters @@ -4130,7 +4397,7 @@ def swap_dims( if k not in self.dims: raise ValueError( f"cannot swap from dimension {k!r} because it is " - "not an existing dimension" + f"not one of the dimensions of this dataset {tuple(self.dims)}" ) if v in self.variables and self.variables[v].dims != (k,): raise ValueError( @@ -4166,14 +4433,12 @@ def swap_dims( return self._replace_with_new_dims(variables, coord_names, indexes=indexes) - # change type of self and return to T_Dataset once - # https://github.com/python/mypy/issues/12846 is resolved def expand_dims( self, dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None, axis: None | int | Sequence[int] = None, **dim_kwargs: Any, - ) -> Dataset: + ) -> Self: """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a view into the underlying array, not a copy. @@ -4207,6 +4472,64 @@ def expand_dims( expanded : Dataset This object, but with additional dimension(s). + Examples + -------- + >>> dataset = xr.Dataset({"temperature": ([], 25.0)}) + >>> dataset + + Dimensions: () + Data variables: + temperature float64 25.0 + + # Expand the dataset with a new dimension called "time" + + >>> dataset.expand_dims(dim="time") + + Dimensions: (time: 1) + Dimensions without coordinates: time + Data variables: + temperature (time) float64 25.0 + + # 1D data + + >>> temperature_1d = xr.DataArray([25.0, 26.5, 24.8], dims="x") + >>> dataset_1d = xr.Dataset({"temperature": temperature_1d}) + >>> dataset_1d + + Dimensions: (x: 3) + Dimensions without coordinates: x + Data variables: + temperature (x) float64 25.0 26.5 24.8 + + # Expand the dataset with a new dimension called "time" using axis argument + + >>> dataset_1d.expand_dims(dim="time", axis=0) + + Dimensions: (time: 1, x: 3) + Dimensions without coordinates: time, x + Data variables: + temperature (time, x) float64 25.0 26.5 24.8 + + # 2D data + + >>> temperature_2d = xr.DataArray(np.random.rand(3, 4), dims=("y", "x")) + >>> dataset_2d = xr.Dataset({"temperature": temperature_2d}) + >>> dataset_2d + + Dimensions: (y: 3, x: 4) + Dimensions without coordinates: y, x + Data variables: + temperature (y, x) float64 0.5488 0.7152 0.6028 ... 0.3834 0.7917 0.5289 + + # Expand the dataset with a new dimension called "time" using axis argument + + >>> dataset_2d.expand_dims(dim="time", axis=2) + + Dimensions: (y: 3, x: 4, time: 1) + Dimensions without coordinates: y, x, time + Data variables: + temperature (y, x, time) float64 0.5488 0.7152 0.6028 ... 0.7917 0.5289 + See Also -------- DataArray.expand_dims @@ -4243,8 +4566,7 @@ def expand_dims( raise ValueError(f"Dimension {d} already exists.") if d in self._variables and not utils.is_scalar(self._variables[d]): raise ValueError( - "{dim} already exists as coordinate or" - " variable name.".format(dim=d) + f"{d} already exists as coordinate or" " variable name." ) variables: dict[Hashable, Variable] = {} @@ -4267,8 +4589,7 @@ def expand_dims( pass # Do nothing if the dimensions value is just an int else: raise TypeError( - "The value of new dimension {k} must be " - "an iterable or an int".format(k=k) + f"The value of new dimension {k} must be " "an iterable or an int" ) for k, v in self._variables.items(): @@ -4307,14 +4628,12 @@ def expand_dims( variables, coord_names=coord_names, indexes=indexes ) - # change type of self and return to T_Dataset once - # https://github.com/python/mypy/issues/12846 is resolved def set_index( self, indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None, append: bool = False, **indexes_kwargs: Hashable | Sequence[Hashable], - ) -> Dataset: + ) -> Self: """Set Dataset (multi-)indexes using one or more existing coordinates or variables. @@ -4412,7 +4731,9 @@ def set_index( if len(var_names) == 1 and (not append or dim not in self._indexes): var_name = var_names[0] var = self._variables[var_name] - if var.dims != (dim,): + # an error with a better message will be raised for scalar variables + # when creating the PandasIndex + if var.ndim > 0 and var.dims != (dim,): raise ValueError( f"dimension mismatch: try setting an index for dimension {dim!r} with " f"variable {var_name!r} that has dimensions {var.dims}" @@ -4472,11 +4793,13 @@ def set_index( variables, coord_names=coord_names, indexes=indexes_ ) + @_deprecate_positional_args("v2023.10.0") def reset_index( - self: T_Dataset, + self, dims_or_levels: Hashable | Sequence[Hashable], + *, drop: bool = False, - ) -> T_Dataset: + ) -> Self: """Reset the specified index(es) or multi-index level(s). This legacy method is specific to pandas (multi-)indexes and @@ -4584,11 +4907,11 @@ def drop_or_convert(var_names): ) def set_xindex( - self: T_Dataset, + self, coord_names: str | Sequence[Hashable], index_cls: type[Index] | None = None, **options, - ) -> T_Dataset: + ) -> Self: """Set a new, Xarray-compatible index from one or more existing coordinate(s). @@ -4696,10 +5019,10 @@ def set_xindex( ) def reorder_levels( - self: T_Dataset, + self, dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None, **dim_order_kwargs: Sequence[int | Hashable], - ) -> T_Dataset: + ) -> Self: """Rearrange index levels using input order. Parameters @@ -4800,12 +5123,12 @@ def _get_stack_index( return stack_index, stack_coords def _stack_once( - self: T_Dataset, + self, dims: Sequence[Hashable | ellipsis], new_dim: Hashable, index_cls: type[Index], create_index: bool | None = True, - ) -> T_Dataset: + ) -> Self: if dims == ...: raise ValueError("Please use [...] for dims, rather than just ...") if ... in dims: @@ -4859,12 +5182,12 @@ def _stack_once( ) def stack( - self: T_Dataset, + self, dimensions: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None, create_index: bool | None = True, index_cls: type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable | ellipsis], - ) -> T_Dataset: + ) -> Self: """ Stack any number of existing dimensions into a single new dimension. @@ -4986,34 +5309,31 @@ def to_stacked_array( stacking_dims = tuple(dim for dim in self.dims if dim not in sample_dims) - for variable in self: - dims = self[variable].dims - dims_include_sample_dims = set(sample_dims) <= set(dims) - if not dims_include_sample_dims: + for key, da in self.data_vars.items(): + missing_sample_dims = set(sample_dims) - set(da.dims) + if missing_sample_dims: raise ValueError( - "All variables in the dataset must contain the " - "dimensions {}.".format(dims) + "Variables in the dataset must contain all ``sample_dims`` " + f"({sample_dims!r}) but '{key}' misses {sorted(map(str, missing_sample_dims))}" ) - def ensure_stackable(val): - assign_coords = {variable_dim: val.name} - for dim in stacking_dims: - if dim not in val.dims: - assign_coords[dim] = None + def stack_dataarray(da): + # add missing dims/ coords and the name of the variable + + missing_stack_coords = {variable_dim: da.name} + for dim in set(stacking_dims) - set(da.dims): + missing_stack_coords[dim] = None - expand_dims = set(stacking_dims).difference(set(val.dims)) - expand_dims.add(variable_dim) - # must be list for .expand_dims - expand_dims = list(expand_dims) + missing_stack_dims = list(missing_stack_coords) return ( - val.assign_coords(**assign_coords) - .expand_dims(expand_dims) + da.assign_coords(**missing_stack_coords) + .expand_dims(missing_stack_dims) .stack({new_dim: (variable_dim,) + stacking_dims}) ) # concatenate the arrays - stackable_vars = [ensure_stackable(self[key]) for key in self.data_vars] + stackable_vars = [stack_dataarray(da) for da in self.data_vars.values()] data_array = concat(stackable_vars, dim=new_dim) if name is not None: @@ -5022,12 +5342,12 @@ def ensure_stackable(val): return data_array def _unstack_once( - self: T_Dataset, + self, dim: Hashable, index_and_vars: tuple[Index, dict[Hashable, Variable]], fill_value, sparse: bool = False, - ) -> T_Dataset: + ) -> Self: index, index_vars = index_and_vars variables: dict[Hashable, Variable] = {} indexes = {k: v for k, v in self._indexes.items() if k != dim} @@ -5062,12 +5382,12 @@ def _unstack_once( ) def _unstack_full_reindex( - self: T_Dataset, + self, dim: Hashable, index_and_vars: tuple[Index, dict[Hashable, Variable]], fill_value, sparse: bool, - ) -> T_Dataset: + ) -> Self: index, index_vars = index_and_vars variables: dict[Hashable, Variable] = {} indexes = {k: v for k, v in self._indexes.items() if k != dim} @@ -5112,12 +5432,14 @@ def _unstack_full_reindex( variables, coord_names=coord_names, indexes=indexes ) + @_deprecate_positional_args("v2023.10.0") def unstack( - self: T_Dataset, + self, dim: Dims = None, + *, fill_value: Any = xrdtypes.NA, sparse: bool = False, - ) -> T_Dataset: + ) -> Self: """ Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. @@ -5154,10 +5476,10 @@ def unstack( else: dims = list(dim) - missing_dims = [d for d in dims if d not in self.dims] + missing_dims = set(dims) - set(self.dims) if missing_dims: raise ValueError( - f"Dataset does not contain the dimensions: {missing_dims}" + f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}" ) # each specified dimension must have exactly one multi-index @@ -5214,7 +5536,7 @@ def unstack( result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse) return result - def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset: + def update(self, other: CoercibleMapping) -> Self: """Update this dataset's variables with those from another dataset. Just like :py:meth:`dict.update` this is a in-place operation. @@ -5254,14 +5576,14 @@ def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset: return self._replace(inplace=True, **merge_result._asdict()) def merge( - self: T_Dataset, + self, other: CoercibleMapping | DataArray, overwrite_vars: Hashable | Iterable[Hashable] = frozenset(), compat: CompatOptions = "no_conflicts", join: JoinOptions = "outer", fill_value: Any = xrdtypes.NA, combine_attrs: CombineAttrsOptions = "override", - ) -> T_Dataset: + ) -> Self: """Merge the arrays of two datasets into a single dataset. This method generally does not allow for overriding data, with the @@ -5365,11 +5687,11 @@ def _assert_all_in_dataset( ) def drop_vars( - self: T_Dataset, + self, names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop variables from this dataset. Parameters @@ -5381,10 +5703,100 @@ def drop_vars( passed are not in the dataset. If 'ignore', any given names that are in the dataset are dropped and no error is raised. + Examples + -------- + + >>> dataset = xr.Dataset( + ... { + ... "temperature": ( + ... ["time", "latitude", "longitude"], + ... [[[25.5, 26.3], [27.1, 28.0]]], + ... ), + ... "humidity": ( + ... ["time", "latitude", "longitude"], + ... [[[65.0, 63.8], [58.2, 59.6]]], + ... ), + ... "wind_speed": ( + ... ["time", "latitude", "longitude"], + ... [[[10.2, 8.5], [12.1, 9.8]]], + ... ), + ... }, + ... coords={ + ... "time": pd.date_range("2023-07-01", periods=1), + ... "latitude": [40.0, 40.2], + ... "longitude": [-75.0, -74.8], + ... }, + ... ) + >>> dataset + + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 2023-07-01 + * latitude (latitude) float64 40.0 40.2 + * longitude (longitude) float64 -75.0 -74.8 + Data variables: + temperature (time, latitude, longitude) float64 25.5 26.3 27.1 28.0 + humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6 + wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8 + + # Drop the 'humidity' variable + + >>> dataset.drop_vars(["humidity"]) + + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 2023-07-01 + * latitude (latitude) float64 40.0 40.2 + * longitude (longitude) float64 -75.0 -74.8 + Data variables: + temperature (time, latitude, longitude) float64 25.5 26.3 27.1 28.0 + wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8 + + # Drop the 'humidity', 'temperature' variables + + >>> dataset.drop_vars(["humidity", "temperature"]) + + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 2023-07-01 + * latitude (latitude) float64 40.0 40.2 + * longitude (longitude) float64 -75.0 -74.8 + Data variables: + wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8 + + # Attempt to drop non-existent variable with errors="ignore" + + >>> dataset.drop_vars(["pressure"], errors="ignore") + + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 2023-07-01 + * latitude (latitude) float64 40.0 40.2 + * longitude (longitude) float64 -75.0 -74.8 + Data variables: + temperature (time, latitude, longitude) float64 25.5 26.3 27.1 28.0 + humidity (time, latitude, longitude) float64 65.0 63.8 58.2 59.6 + wind_speed (time, latitude, longitude) float64 10.2 8.5 12.1 9.8 + + # Attempt to drop non-existent variable with errors="raise" + + >>> dataset.drop_vars(["pressure"], errors="raise") + Traceback (most recent call last): + ValueError: These variables cannot be found in this dataset: ['pressure'] + + Raises + ------ + ValueError + Raised if you attempt to drop a variable which is not present, and the kwarg ``errors='raise'``. + Returns ------- dropped : Dataset + See Also + -------- + DataArray.drop_vars + """ # the Iterable check is required for mypy if is_scalar(names) or not isinstance(names, Iterable): @@ -5421,11 +5833,11 @@ def drop_vars( ) def drop_indexes( - self: T_Dataset, + self, coord_names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop the indexes assigned to the given coordinates. Parameters @@ -5452,7 +5864,10 @@ def drop_indexes( if errors == "raise": invalid_coords = coord_names - self._coord_names if invalid_coords: - raise ValueError(f"those coordinates don't exist: {invalid_coords}") + raise ValueError( + f"The coordinates {tuple(invalid_coords)} are not found in the " + f"dataset coordinates {tuple(self.coords.keys())}" + ) unindexed_coords = set(coord_names) - set(self._indexes) if unindexed_coords: @@ -5474,13 +5889,13 @@ def drop_indexes( return self._replace(variables=variables, indexes=indexes) def drop( - self: T_Dataset, + self, labels=None, dim=None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_Dataset: + ) -> Self: """Backward compatible method based on `drop_vars` and `drop_sel` Using either `drop_vars` or `drop_sel` is encouraged @@ -5530,8 +5945,8 @@ def drop( return self.drop_sel(labels, errors=errors) def drop_sel( - self: T_Dataset, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs - ) -> T_Dataset: + self, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs + ) -> Self: """Drop index labels from this dataset. Parameters @@ -5600,7 +6015,7 @@ def drop_sel( ds = ds.loc[{dim: new_index}] return ds - def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset: + def drop_isel(self, indexers=None, **indexers_kwargs) -> Self: """Drop index positions from this Dataset. Parameters @@ -5666,11 +6081,11 @@ def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset: return ds def drop_dims( - self: T_Dataset, + self, drop_dims: str | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop dimensions and associated variables from this dataset. Parameters @@ -5700,17 +6115,17 @@ def drop_dims( missing_dims = drop_dims - set(self.dims) if missing_dims: raise ValueError( - f"Dataset does not contain the dimensions: {missing_dims}" + f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}" ) drop_vars = {k for k, v in self._variables.items() if set(v.dims) & drop_dims} return self.drop_vars(drop_vars) def transpose( - self: T_Dataset, + self, *dims: Hashable, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> T_Dataset: + ) -> Self: """Return a new Dataset object with all array dimensions transposed. Although the order of dimensions on each array will change, the dataset @@ -5762,13 +6177,15 @@ def transpose( ds._variables[name] = var.transpose(*var_dims) return ds + @_deprecate_positional_args("v2023.10.0") def dropna( - self: T_Dataset, + self, dim: Hashable, + *, how: Literal["any", "all"] = "any", thresh: int | None = None, subset: Iterable[Hashable] | None = None, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with dropped labels for missing values along the provided dimension. @@ -5860,7 +6277,9 @@ def dropna( # depending on the order of the supplied axes. if dim not in self.dims: - raise ValueError(f"{dim} must be a single dataset dimension") + raise ValueError( + f"Dimension {dim!r} not found in data dimensions {tuple(self.dims)}" + ) if subset is None: subset = iter(self.data_vars) @@ -5872,7 +6291,7 @@ def dropna( array = self._variables[k] if dim in array.dims: dims = [d for d in array.dims if d != dim] - count += np.asarray(array.count(dims)) # type: ignore[attr-defined] + count += np.asarray(array.count(dims)) size += math.prod([self.dims[d] for d in dims]) if thresh is not None: @@ -5888,7 +6307,7 @@ def dropna( return self.isel({dim: mask}) - def fillna(self: T_Dataset, value: Any) -> T_Dataset: + def fillna(self, value: Any) -> Self: """Fill missing values in this object. This operation follows the normal broadcasting and alignment rules that @@ -5969,7 +6388,7 @@ def fillna(self: T_Dataset, value: Any) -> T_Dataset: return out def interpolate_na( - self: T_Dataset, + self, dim: Hashable | None = None, method: InterpOptions = "linear", limit: int | None = None, @@ -5978,7 +6397,7 @@ def interpolate_na( int | float | str | pd.Timedelta | np.timedelta64 | datetime.timedelta ) = None, **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Fill in NaNs by interpolating according to different methods. Parameters @@ -5986,7 +6405,7 @@ def interpolate_na( dim : Hashable or None, optional Specifies the dimension along which to interpolate. method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ - "barycentric", "krog", "pchip", "spline", "akima"}, default: "linear" + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" String indicating which method to use for interpolation: - 'linear': linear interpolation. Additional keyword @@ -5995,7 +6414,7 @@ def interpolate_na( are passed to :py:func:`scipy.interpolate.interp1d`. If ``method='polynomial'``, the ``order`` keyword argument must also be provided. - - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. use_coordinate : bool or Hashable, default: True @@ -6040,6 +6459,11 @@ def interpolate_na( interpolated: Dataset Filled in Dataset. + Warning + -------- + When passing fill_value as a keyword argument with method="linear", it does not use + ``numpy.interp`` but it uses ``scipy.interpolate.interp1d``, which provides the fill_value parameter. + See Also -------- numpy.interp @@ -6103,7 +6527,7 @@ def interpolate_na( ) return new - def ffill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset: + def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -6167,7 +6591,7 @@ def ffill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) return new - def bfill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset: + def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward *Requires bottleneck.* @@ -6232,7 +6656,7 @@ def bfill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit) return new - def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset: + def combine_first(self, other: Self) -> Self: """Combine two Datasets, default to data_vars of self. The new coordinates follow the normal broadcasting and alignment rules @@ -6252,7 +6676,7 @@ def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset: return out def reduce( - self: T_Dataset, + self, func: Callable, dim: Dims = None, *, @@ -6260,7 +6684,7 @@ def reduce( keepdims: bool = False, numeric_only: bool = False, **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Reduce this dataset by applying `func` along some dimension(s). Parameters @@ -6336,10 +6760,10 @@ def reduce( else: dims = set(dim) - missing_dimensions = [d for d in dims if d not in self.dims] + missing_dimensions = tuple(d for d in dims if d not in self.dims) if missing_dimensions: raise ValueError( - f"Dataset does not contain the dimensions: {missing_dimensions}" + f"Dimensions {missing_dimensions} not found in data dimensions {tuple(self.dims)}" ) if keep_attrs is None: @@ -6385,12 +6809,12 @@ def reduce( ) def map( - self: T_Dataset, + self, func: Callable, keep_attrs: bool | None = None, args: Iterable[Any] = (), **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Apply a function to each data variable in this dataset Parameters @@ -6445,12 +6869,12 @@ def map( return type(self)(variables, attrs=attrs) def apply( - self: T_Dataset, + self, func: Callable, keep_attrs: bool | None = None, args: Iterable[Any] = (), **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """ Backward compatible implementation of ``map`` @@ -6466,10 +6890,10 @@ def apply( return self.map(func, keep_attrs, args, **kwargs) def assign( - self: T_Dataset, + self, variables: Mapping[Any, Any] | None = None, **variables_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Assign new data variables to a Dataset, returning a new object with all the original variables in addition to the new ones. @@ -6498,6 +6922,10 @@ def assign( possible, but you cannot reference other variables created within the same ``assign`` call. + The new assigned variables that replace existing coordinates in the + original dataset are still listed as coordinates in the returned + Dataset. + See Also -------- pandas.DataFrame.assign @@ -6553,11 +6981,23 @@ def assign( """ variables = either_dict_or_kwargs(variables, variables_kwargs, "assign") data = self.copy() + # do all calculations first... results: CoercibleMapping = data._calc_assign_results(variables) - data.coords._maybe_drop_multiindex_coords(set(results.keys())) + + # split data variables to add/replace vs. coordinates to replace + results_data_vars: dict[Hashable, CoercibleValue] = {} + results_coords: dict[Hashable, CoercibleValue] = {} + for k, v in results.items(): + if k in data._coord_names: + results_coords[k] = v + else: + results_data_vars[k] = v + # ... and then assign - data.update(results) + data.coords.update(results_coords) + data.update(results_data_vars) + return data def to_array( @@ -6619,8 +7059,8 @@ def _normalize_dim_order( dim_order = list(self.dims) elif set(dim_order) != set(self.dims): raise ValueError( - "dim_order {} does not match the set of dimensions of this " - "Dataset: {}".format(dim_order, list(self.dims)) + f"dim_order {dim_order} does not match the set of dimensions of this " + f"Dataset: {list(self.dims)}" ) ordered_dims = {k: self.dims[k] for k in dim_order} @@ -6758,9 +7198,7 @@ def _set_numpy_data_from_dataframe( self[name] = (dims, data) @classmethod - def from_dataframe( - cls: type[T_Dataset], dataframe: pd.DataFrame, sparse: bool = False - ) -> T_Dataset: + def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: """Convert a pandas.DataFrame into an xarray.Dataset Each column will be converted into an independent variable in the @@ -6974,7 +7412,7 @@ def to_dict( return d @classmethod - def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset: + def from_dict(cls, d: Mapping[Any, Any]) -> Self: """Convert a dictionary into an xarray.Dataset. Parameters @@ -7051,8 +7489,7 @@ def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset: } except KeyError as e: raise ValueError( - "cannot convert dict without the key " - "'{dims_data}'".format(dims_data=str(e.args[0])) + "cannot convert dict without the key " f"'{str(e.args[0])}'" ) obj = cls(variable_dict) @@ -7065,7 +7502,7 @@ def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset: return obj - def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset: + def _unary_op(self, f, *args, **kwargs) -> Self: variables = {} keep_attrs = kwargs.pop("keep_attrs", None) if keep_attrs is None: @@ -7076,7 +7513,7 @@ def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset: else: variables[k] = f(v, *args, **kwargs) if keep_attrs: - variables[k].attrs = v._attrs + variables[k]._attrs = v._attrs attrs = self._attrs if keep_attrs else None return self._replace_with_new_dims(variables, attrs=attrs) @@ -7088,7 +7525,7 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: return NotImplemented align_type = OPTIONS["arithmetic_join"] if join is None else join if isinstance(other, (DataArray, Dataset)): - self, other = align(self, other, join=align_type, copy=False) # type: ignore[assignment] + self, other = align(self, other, join=align_type, copy=False) g = f if not reflexive else lambda x, y: f(y, x) ds = self._calculate_binary_op(g, other, join=align_type) keep_attrs = _get_keep_attrs(default=False) @@ -7096,7 +7533,7 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: ds.attrs = self.attrs return ds - def _inplace_binary_op(self: T_Dataset, other, f) -> T_Dataset: + def _inplace_binary_op(self, other, f) -> Self: from xarray.core.dataarray import DataArray from xarray.core.groupby import GroupBy @@ -7170,12 +7607,14 @@ def _copy_attrs_from(self, other): if v in self.variables: self.variables[v].attrs = other.variables[v].attrs + @_deprecate_positional_args("v2023.10.0") def diff( - self: T_Dataset, + self, dim: Hashable, n: int = 1, + *, label: Literal["upper", "lower"] = "upper", - ) -> T_Dataset: + ) -> Self: """Calculate the n-th order discrete difference along given axis. Parameters @@ -7258,11 +7697,11 @@ def diff( return difference def shift( - self: T_Dataset, + self, shifts: Mapping[Any, int] | None = None, fill_value: Any = xrdtypes.NA, **shifts_kwargs: int, - ) -> T_Dataset: + ) -> Self: """Shift this dataset by an offset along one or more dimensions. Only data variables are moved; coordinates stay in place. This is @@ -7306,9 +7745,11 @@ def shift( foo (x) object nan nan 'a' 'b' 'c' """ shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "shift") - invalid = [k for k in shifts if k not in self.dims] + invalid = tuple(k for k in shifts if k not in self.dims) if invalid: - raise ValueError(f"dimensions {invalid!r} do not exist") + raise ValueError( + f"Dimensions {invalid} not found in data dimensions {tuple(self.dims)}" + ) variables = {} for name, var in self.variables.items(): @@ -7327,11 +7768,11 @@ def shift( return self._replace(variables) def roll( - self: T_Dataset, + self, shifts: Mapping[Any, int] | None = None, roll_coords: bool = False, **shifts_kwargs: int, - ) -> T_Dataset: + ) -> Self: """Roll this dataset by an offset along one or more dimensions. Unlike shift, roll treats the given dimensions as periodic, so will not @@ -7385,7 +7826,9 @@ def roll( shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "roll") invalid = [k for k in shifts if k not in self.dims] if invalid: - raise ValueError(f"dimensions {invalid!r} do not exist") + raise ValueError( + f"Dimensions {invalid} not found in data dimensions {tuple(self.dims)}" + ) unrolled_vars: tuple[Hashable, ...] @@ -7411,10 +7854,13 @@ def roll( return self._replace(variables, indexes=indexes) def sortby( - self: T_Dataset, - variables: Hashable | DataArray | list[Hashable | DataArray], + self, + variables: Hashable + | DataArray + | Sequence[Hashable | DataArray] + | Callable[[Self], Hashable | DataArray | list[Hashable | DataArray]], ascending: bool = True, - ) -> T_Dataset: + ) -> Self: """ Sort object by labels or values (along an axis). @@ -7434,9 +7880,10 @@ def sortby( Parameters ---------- - variables : Hashable, DataArray, or list of hashable or DataArray - 1D DataArray objects or name(s) of 1D variable(s) in - coords/data_vars whose values are used to sort the dataset. + kariables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are + used to sort this array. If a callable, the callable is passed this object, + and the result is used as the value for cond. ascending : bool, default: True Whether to sort by ascending or descending order. @@ -7462,8 +7909,7 @@ def sortby( ... }, ... coords={"x": ["b", "a"], "y": [1, 0]}, ... ) - >>> ds = ds.sortby("x") - >>> ds + >>> ds.sortby("x") Dimensions: (x: 2, y: 2) Coordinates: @@ -7472,17 +7918,28 @@ def sortby( Data variables: A (x, y) int64 3 4 1 2 B (x, y) int64 7 8 5 6 + >>> ds.sortby(lambda x: -x["y"]) + + Dimensions: (x: 2, y: 2) + Coordinates: + * x (x) T_Dataset: + ) -> Self: """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements for each variable @@ -7522,15 +7981,15 @@ def quantile( desired quantile lies between two data points. The options sorted by their R type as summarized in the H&F paper [1]_ are: - 1. "inverted_cdf" (*) - 2. "averaged_inverted_cdf" (*) - 3. "closest_observation" (*) - 4. "interpolated_inverted_cdf" (*) - 5. "hazen" (*) - 6. "weibull" (*) + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" 7. "linear" (default) - 8. "median_unbiased" (*) - 9. "normal_unbiased" (*) + 8. "median_unbiased" + 9. "normal_unbiased" The first three methods are discontiuous. The following discontinuous variations of the default "linear" (7.) option are also available: @@ -7544,8 +8003,6 @@ def quantile( was previously called "interpolation", renamed in accordance with numpy version 1.22.0. - (*) These methods require numpy version 1.22 or newer. - keep_attrs : bool, optional If True, the dataset's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -7636,10 +8093,11 @@ def quantile( else: dims = set(dim) - _assert_empty( - tuple(d for d in dims if d not in self.dims), - "Dataset does not contain the dimensions: %s", - ) + invalid_dims = set(dims) - set(self.dims) + if invalid_dims: + raise ValueError( + f"Dimensions {tuple(invalid_dims)} not found in data dimensions {tuple(self.dims)}" + ) q = np.asarray(q, dtype=np.float64) @@ -7675,12 +8133,14 @@ def quantile( ) return new.assign_coords(quantile=q) + @_deprecate_positional_args("v2023.10.0") def rank( - self: T_Dataset, + self, dim: Hashable, + *, pct: bool = False, keep_attrs: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Ranks the data. Equal values are assigned a rank that is the average of the ranks that @@ -7715,7 +8175,9 @@ def rank( ) if dim not in self.dims: - raise ValueError(f"Dataset does not contain the dimension: {dim}") + raise ValueError( + f"Dimension {dim!r} not found in data dimensions {tuple(self.dims)}" + ) variables = {} for name, var in self.variables.items(): @@ -7732,11 +8194,11 @@ def rank( return self._replace(variables, coord_names, attrs=attrs) def differentiate( - self: T_Dataset, + self, coord: Hashable, edge_order: Literal[1, 2] = 1, datetime_unit: DatetimeUnitOptions | None = None, - ) -> T_Dataset: + ) -> Self: """ Differentiate with the second order accurate central differences. @@ -7765,13 +8227,16 @@ def differentiate( from xarray.core.variable import Variable if coord not in self.variables and coord not in self.dims: - raise ValueError(f"Coordinate {coord} does not exist.") + variables_and_dims = tuple(set(self.variables.keys()).union(self.dims)) + raise ValueError( + f"Coordinate {coord!r} not found in variables or dimensions {variables_and_dims}." + ) coord_var = self[coord].variable if coord_var.ndim != 1: raise ValueError( - "Coordinate {} must be 1 dimensional but is {}" - " dimensional".format(coord, coord_var.ndim) + f"Coordinate {coord} must be 1 dimensional but is {coord_var.ndim}" + " dimensional" ) dim = coord_var.dims[0] @@ -7801,10 +8266,10 @@ def differentiate( return self._replace(variables) def integrate( - self: T_Dataset, + self, coord: Hashable | Sequence[Hashable], datetime_unit: DatetimeUnitOptions = None, - ) -> T_Dataset: + ) -> Self: """Integrate along the given coordinate using the trapezoidal rule. .. note:: @@ -7867,13 +8332,16 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): from xarray.core.variable import Variable if coord not in self.variables and coord not in self.dims: - raise ValueError(f"Coordinate {coord} does not exist.") + variables_and_dims = tuple(set(self.variables.keys()).union(self.dims)) + raise ValueError( + f"Coordinate {coord!r} not found in variables or dimensions {variables_and_dims}." + ) coord_var = self[coord].variable if coord_var.ndim != 1: raise ValueError( - "Coordinate {} must be 1 dimensional but is {}" - " dimensional".format(coord, coord_var.ndim) + f"Coordinate {coord} must be 1 dimensional but is {coord_var.ndim}" + " dimensional" ) dim = coord_var.dims[0] @@ -7917,10 +8385,10 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): ) def cumulative_integrate( - self: T_Dataset, + self, coord: Hashable | Sequence[Hashable], datetime_unit: DatetimeUnitOptions = None, - ) -> T_Dataset: + ) -> Self: """Integrate along the given coordinate using the trapezoidal rule. .. note:: @@ -7992,7 +8460,7 @@ def cumulative_integrate( return result @property - def real(self: T_Dataset) -> T_Dataset: + def real(self) -> Self: """ The real part of each data variable. @@ -8003,7 +8471,7 @@ def real(self: T_Dataset) -> T_Dataset: return self.map(lambda x: x.real, keep_attrs=True) @property - def imag(self: T_Dataset) -> T_Dataset: + def imag(self) -> Self: """ The imaginary part of each data variable. @@ -8015,7 +8483,7 @@ def imag(self: T_Dataset) -> T_Dataset: plot = utils.UncachedAccessor(DatasetPlotAccessor) - def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset: + def filter_by_attrs(self, **kwargs) -> Self: """Returns a ``Dataset`` with variables that match specific conditions. Can pass in ``key=value`` or ``key=callable``. A Dataset is returned @@ -8110,7 +8578,7 @@ def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset: selection.append(var_name) return self[selection] - def unify_chunks(self: T_Dataset) -> T_Dataset: + def unify_chunks(self) -> Self: """Unify chunk size along all chunked dimensions of this Dataset. Returns @@ -8232,7 +8700,7 @@ def map_blocks( return map_blocks(func, self, args, kwargs, template) def polyfit( - self: T_Dataset, + self, dim: Hashable, deg: int, skipna: bool | None = None, @@ -8240,7 +8708,7 @@ def polyfit( w: Hashable | Any = None, full: bool = False, cov: bool | Literal["unscaled"] = False, - ) -> T_Dataset: + ) -> Self: """ Least squares polynomial fit. @@ -8376,9 +8844,9 @@ def polyfit( with warnings.catch_warnings(): if full: # Copy np.polyfit behavior - warnings.simplefilter("ignore", np.RankWarning) + warnings.simplefilter("ignore", RankWarning) else: # Raise only once per variable - warnings.simplefilter("once", np.RankWarning) + warnings.simplefilter("once", RankWarning) coeffs, residuals = duck_array_ops.least_squares( lhs, rhs.data, rcond=rcond, skipna=skipna_da @@ -8428,7 +8896,7 @@ def polyfit( return type(self)(data_vars=variables, attrs=self.attrs.copy()) def pad( - self: T_Dataset, + self, pad_width: Mapping[Any, int | tuple[int, int]] | None = None, mode: PadModeOptions = "constant", stat_length: int @@ -8442,7 +8910,7 @@ def pad( reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, **pad_width_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Pad this dataset along one or more dimensions. .. warning:: @@ -8613,13 +9081,15 @@ def pad( attrs = self._attrs if keep_attrs else None return self._replace_with_new_dims(variables, indexes=indexes, attrs=attrs) + @_deprecate_positional_args("v2023.10.0") def idxmin( - self: T_Dataset, + self, dim: Hashable | None = None, + *, skipna: bool | None = None, fill_value: Any = xrdtypes.NA, keep_attrs: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Return the coordinate label of the minimum value along a dimension. Returns a new `Dataset` named after the dimension with the values of @@ -8668,8 +9138,8 @@ def idxmin( >>> array2 = xr.DataArray( ... [ ... [2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], ... ], ... dims=["y", "x"], ... coords={"y": [-1, 0, 1], "x": ["a", "b", "c", "d", "e"]}, @@ -8710,13 +9180,15 @@ def idxmin( ) ) + @_deprecate_positional_args("v2023.10.0") def idxmax( - self: T_Dataset, + self, dim: Hashable | None = None, + *, skipna: bool | None = None, fill_value: Any = xrdtypes.NA, keep_attrs: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Return the coordinate label of the maximum value along a dimension. Returns a new `Dataset` named after the dimension with the values of @@ -8765,8 +9237,8 @@ def idxmax( >>> array2 = xr.DataArray( ... [ ... [2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], ... ], ... dims=["y", "x"], ... coords={"y": [-1, 0, 1], "x": ["a", "b", "c", "d", "e"]}, @@ -8807,7 +9279,7 @@ def idxmax( ) ) - def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: + def argmin(self, dim: Hashable | None = None, **kwargs) -> Self: """Indices of the minima of the member variables. If there are multiple minima, the indices of the first one found will be @@ -8910,7 +9382,7 @@ def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: "Dataset.argmin() with a sequence or ... for dim" ) - def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: + def argmax(self, dim: Hashable | None = None, **kwargs) -> Self: """Indices of the maxima of the member variables. If there are multiple maxima, the indices of the first one found will be @@ -9004,13 +9476,13 @@ def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: ) def query( - self: T_Dataset, + self, queries: Mapping[Any, Any] | None = None, parser: QueryParserOptions = "pandas", engine: QueryEngineOptions = None, missing_dims: ErrorOptionsWithWarn = "raise", **queries_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Return a new dataset with each array indexed along the specified dimension(s), where the indexers are given as strings containing Python expressions to be evaluated against the data variables in the @@ -9100,7 +9572,7 @@ def query( return self.isel(indexers, missing_dims=missing_dims) def curvefit( - self: T_Dataset, + self, coords: str | DataArray | Iterable[str | DataArray], func: Callable[..., Any], reduce_dims: Dims = None, @@ -9110,7 +9582,7 @@ def curvefit( param_names: Sequence[str] | None = None, errors: ErrorOptions = "raise", kwargs: dict[str, Any] | None = None, - ) -> T_Dataset: + ) -> Self: """ Curve fitting optimization for arbitrary functions. @@ -9333,11 +9805,13 @@ def _wrapper(Y, *args, **kwargs): return result + @_deprecate_positional_args("v2023.10.0") def drop_duplicates( - self: T_Dataset, + self, dim: Hashable | Iterable[Hashable], + *, keep: Literal["first", "last", False] = "first", - ) -> T_Dataset: + ) -> Self: """Returns a new Dataset with duplicate dimension values removed. Parameters @@ -9369,19 +9843,21 @@ def drop_duplicates( missing_dims = set(dims) - set(self.dims) if missing_dims: - raise ValueError(f"'{missing_dims}' not found in dimensions") + raise ValueError( + f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}" + ) indexes = {dim: ~self.get_index(dim).duplicated(keep=keep) for dim in dims} return self.isel(indexes) def convert_calendar( - self: T_Dataset, + self, calendar: CFCalendar, dim: Hashable = "time", align_on: Literal["date", "year", None] = None, missing: Any | None = None, use_cftime: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Convert the Dataset to another calendar. Only converts the individual timestamps, does not modify any data except @@ -9498,10 +9974,10 @@ def convert_calendar( ) def interp_calendar( - self: T_Dataset, + self, target: pd.DatetimeIndex | CFTimeIndex | DataArray, dim: Hashable = "time", - ) -> T_Dataset: + ) -> Self: """Interpolates the Dataset to another calendar based on decimal year measure. Each timestamp in `source` and `target` are first converted to their decimal diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 7ac342e3d52..ccf84146819 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +from typing import Any import numpy as np @@ -40,11 +41,11 @@ def __eq__(self, other): PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ( (np.number, np.character), # numpy promotes to character (np.bool_, np.character), # numpy promotes to character - (np.bytes_, np.unicode_), # numpy promotes to unicode + (np.bytes_, np.str_), # numpy promotes to unicode ) -def maybe_promote(dtype): +def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: """Simpler equivalent of pandas.core.common._maybe_promote Parameters @@ -57,27 +58,33 @@ def maybe_promote(dtype): fill_value : Valid missing value for the promoted dtype. """ # N.B. these casting rules should match pandas + dtype_: np.typing.DTypeLike + fill_value: Any if np.issubdtype(dtype, np.floating): + dtype_ = dtype fill_value = np.nan elif np.issubdtype(dtype, np.timedelta64): # See https://github.com/numpy/numpy/issues/10685 # np.timedelta64 is a subclass of np.integer # Check np.timedelta64 before np.integer fill_value = np.timedelta64("NaT") + dtype_ = dtype elif np.issubdtype(dtype, np.integer): - dtype = np.float32 if dtype.itemsize <= 2 else np.float64 + dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64 fill_value = np.nan elif np.issubdtype(dtype, np.complexfloating): + dtype_ = dtype fill_value = np.nan + np.nan * 1j elif np.issubdtype(dtype, np.datetime64): + dtype_ = dtype fill_value = np.datetime64("NaT") else: - dtype = object + dtype_ = object fill_value = np.nan - dtype = np.dtype(dtype) - fill_value = dtype.type(fill_value) - return dtype, fill_value + dtype_out = np.dtype(dtype_) + fill_value = dtype_out.type(fill_value) + return dtype_out, fill_value NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 4f245e59f73..b9f7db9737f 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -18,7 +18,6 @@ from numpy import any as array_any # noqa from numpy import ( # noqa around, # noqa - einsum, gradient, isclose, isin, @@ -48,6 +47,17 @@ def get_array_namespace(x): return np +def einsum(*args, **kwargs): + from xarray.core.options import OPTIONS + + if OPTIONS["use_opt_einsum"] and module_available("opt_einsum"): + import opt_einsum + + return opt_einsum.contract(*args, **kwargs) + else: + return np.einsum(*args, **kwargs) + + def _dask_or_eager_func( name, eager_module=np, @@ -184,6 +194,9 @@ def cumulative_trapezoid(y, x, axis): def astype(data, dtype, **kwargs): if hasattr(data, "__array_namespace__"): xp = get_array_namespace(data) + if xp == np: + # numpy currently doesn't have a astype: + return data.astype(dtype, **kwargs) return xp.astype(data, dtype, **kwargs) return data.astype(dtype, **kwargs) @@ -337,6 +350,10 @@ def reshape(array, shape): return xp.reshape(array, shape) +def ravel(array): + return reshape(array, (-1,)) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: @@ -363,7 +380,7 @@ def f(values, axis=None, skipna=None, **kwargs): values = asarray(values) if coerce_strings and values.dtype.kind in "SU": - values = values.astype(object) + values = astype(values, object) func = None if skipna or (skipna is None and values.dtype.kind in "cfO"): diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 06f84c3eee1..561b8d3cc0d 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -10,17 +10,21 @@ from datetime import datetime, timedelta from itertools import chain, zip_longest from reprlib import recursive_repr +from typing import TYPE_CHECKING import numpy as np import pandas as pd from pandas.errors import OutOfBoundsDatetime -from xarray.core.duck_array_ops import array_equiv -from xarray.core.indexing import ExplicitlyIndexed, MemoryCachedArray +from xarray.core.duck_array_ops import array_equiv, astype +from xarray.core.indexing import MemoryCachedArray from xarray.core.options import OPTIONS, _get_boolean_with_default -from xarray.core.pycompat import array_type +from xarray.core.pycompat import array_type, to_duck_array, to_numpy from xarray.core.utils import is_duck_array +if TYPE_CHECKING: + from xarray.core.coordinates import AbstractCoordinates + def pretty_print(x, numchars: int): """Given an object `x`, call `str(x)` and format the returned string so @@ -64,6 +68,8 @@ def first_n_items(array, n_desired): # might not be a numpy.ndarray. Moreover, access to elements of the array # could be very expensive (e.g. if it's only available over DAP), so go out # of our way to get them in a single call to __getitem__ using only slices. + from xarray.core.variable import Variable + if n_desired < 1: raise ValueError("must request at least one item") @@ -74,7 +80,14 @@ def first_n_items(array, n_desired): if n_desired < array.size: indexer = _get_indexer_at_least_n_items(array.shape, n_desired, from_end=False) array = array[indexer] - return np.asarray(array).flat[:n_desired] + + # We pass variable objects in to handle indexing + # with indexer above. It would not work with our + # lazy indexing classes at the moment, so we cannot + # pass Variable._data + if isinstance(array, Variable): + array = array._data + return np.ravel(to_duck_array(array))[:n_desired] def last_n_items(array, n_desired): @@ -83,13 +96,22 @@ def last_n_items(array, n_desired): # might not be a numpy.ndarray. Moreover, access to elements of the array # could be very expensive (e.g. if it's only available over DAP), so go out # of our way to get them in a single call to __getitem__ using only slices. + from xarray.core.variable import Variable + if (n_desired == 0) or (array.size == 0): return [] if n_desired < array.size: indexer = _get_indexer_at_least_n_items(array.shape, n_desired, from_end=True) array = array[indexer] - return np.asarray(array).flat[-n_desired:] + + # We pass variable objects in to handle indexing + # with indexer above. It would not work with our + # lazy indexing classes at the moment, so we cannot + # pass Variable._data + if isinstance(array, Variable): + array = array._data + return np.ravel(to_duck_array(array))[-n_desired:] def last_item(array): @@ -99,7 +121,8 @@ def last_item(array): return [] indexer = (slice(-1, None),) * array.ndim - return np.ravel(np.asarray(array[indexer])).tolist() + # to_numpy since dask doesn't support tolist + return np.ravel(to_numpy(array[indexer])).tolist() def calc_max_rows_first(max_rows: int) -> int: @@ -156,6 +179,8 @@ def format_item(x, timedelta_format=None, quote_strings=True): if isinstance(x, (np.timedelta64, timedelta)): return format_timedelta(x, timedelta_format=timedelta_format) elif isinstance(x, (str, bytes)): + if hasattr(x, "dtype"): + x = x.item() return repr(x) if quote_strings else x elif hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating): return f"{x.item():.4}" @@ -165,10 +190,10 @@ def format_item(x, timedelta_format=None, quote_strings=True): def format_items(x): """Returns a succinct summaries of all items in a sequence as strings""" - x = np.asarray(x) + x = to_duck_array(x) timedelta_format = "datetime" if np.issubdtype(x.dtype, np.timedelta64): - x = np.asarray(x, dtype="timedelta64[ns]") + x = astype(x, dtype="timedelta64[ns]") day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]") time_needed = x[~pd.isnull(x)] != day_part day_needed = day_part != np.timedelta64(0, "ns") @@ -398,7 +423,7 @@ def _mapping_repr( ) -def coords_repr(coords, col_width=None, max_rows=None): +def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None): if col_width is None: col_width = _calculate_col_width(coords) return _mapping_repr( @@ -412,7 +437,7 @@ def coords_repr(coords, col_width=None, max_rows=None): ) -def inline_index_repr(index, max_width=None): +def inline_index_repr(index: pd.Index, max_width=None): if hasattr(index, "_repr_inline_"): repr_ = index._repr_inline_(max_width=max_width) else: @@ -424,21 +449,37 @@ def inline_index_repr(index, max_width=None): def summarize_index( - name: Hashable, index, col_width: int, max_width: int | None = None -): + names: tuple[Hashable, ...], + index, + col_width: int, + max_width: int | None = None, +) -> str: if max_width is None: max_width = OPTIONS["display_width"] - preformatted = pretty_print(f" {name} ", col_width) + def prefixes(length: int) -> list[str]: + if length in (0, 1): + return [" "] + + return ["┌"] + ["│"] * max(length - 2, 0) + ["└"] + + preformatted = [ + pretty_print(f" {prefix} {name}", col_width) + for prefix, name in zip(prefixes(len(names)), names) + ] - index_width = max_width - len(preformatted) + head, *tail = preformatted + index_width = max_width - len(head) repr_ = inline_index_repr(index, max_width=index_width) - return preformatted + repr_ + return "\n".join([head + repr_] + [line.rstrip() for line in tail]) -def nondefault_indexes(indexes): +def filter_nondefault_indexes(indexes, filter_indexes: bool): from xarray.core.indexes import PandasIndex, PandasMultiIndex + if not filter_indexes: + return indexes + default_indexes = (PandasIndex, PandasMultiIndex) return { @@ -448,7 +489,9 @@ def nondefault_indexes(indexes): } -def indexes_repr(indexes, col_width=None, max_rows=None): +def indexes_repr(indexes, max_rows: int | None = None) -> str: + col_width = _calculate_col_width(chain.from_iterable(indexes)) + return _mapping_repr( indexes, "Indexes", @@ -560,12 +603,9 @@ def limit_lines(string: str, *, limit: int): def short_array_repr(array): from xarray.core.common import AbstractArray - if isinstance(array, ExplicitlyIndexed): - array = array.get_duck_array() - elif isinstance(array, AbstractArray): + if isinstance(array, AbstractArray): array = array.data - if not is_duck_array(array): - array = np.asarray(array) + array = to_duck_array(array) # default to lower precision so a full (abbreviated) line can fit on # one line with the default display_width @@ -592,13 +632,19 @@ def short_data_repr(array): return short_array_repr(array) elif is_duck_array(internal_data): return limit_lines(repr(array.data), limit=40) - elif array._in_memory: + elif getattr(array, "_in_memory", None): return short_array_repr(array) else: # internal xarray array type return f"[{array.size} values with dtype={array.dtype}]" +def _get_indexes_dict(indexes): + return { + tuple(index_vars.keys()): idx for idx, index_vars in indexes.group_by_index() + } + + @recursive_repr("") def array_repr(arr): from xarray.core.variable import Variable @@ -643,15 +689,13 @@ def array_repr(arr): display_default_indexes = _get_boolean_with_default( "display_default_indexes", False ) - if display_default_indexes: - xindexes = arr.xindexes - else: - xindexes = nondefault_indexes(arr.xindexes) + + xindexes = filter_nondefault_indexes( + _get_indexes_dict(arr.xindexes), not display_default_indexes + ) if xindexes: - summary.append( - indexes_repr(xindexes, col_width=col_width, max_rows=max_rows) - ) + summary.append(indexes_repr(xindexes, max_rows=max_rows)) if arr.attrs: summary.append(attrs_repr(arr.attrs, max_rows=max_rows)) @@ -682,12 +726,11 @@ def dataset_repr(ds): display_default_indexes = _get_boolean_with_default( "display_default_indexes", False ) - if display_default_indexes: - xindexes = ds.xindexes - else: - xindexes = nondefault_indexes(ds.xindexes) + xindexes = filter_nondefault_indexes( + _get_indexes_dict(ds.xindexes), not display_default_indexes + ) if xindexes: - summary.append(indexes_repr(xindexes, col_width=col_width, max_rows=max_rows)) + summary.append(indexes_repr(xindexes, max_rows=max_rows)) if ds.attrs: summary.append(attrs_repr(ds.attrs, max_rows=max_rows)) diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 60bb901c31a..3627554cf57 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -28,7 +28,7 @@ def _load_static_files(): ] -def short_data_repr_html(array): +def short_data_repr_html(array) -> str: """Format "data" for DataArray and Variable.""" internal_data = getattr(array, "variable", array)._data if hasattr(internal_data, "_repr_html_"): @@ -37,7 +37,7 @@ def short_data_repr_html(array): return f"
{text}
" -def format_dims(dims, dims_with_index): +def format_dims(dims, dims_with_index) -> str: if not dims: return "" @@ -53,7 +53,7 @@ def format_dims(dims, dims_with_index): return f"
    {dims_li}
" -def summarize_attrs(attrs): +def summarize_attrs(attrs) -> str: attrs_dl = "".join( f"
{escape(str(k))} :
" f"
{escape(str(v))}
" for k, v in attrs.items() @@ -62,17 +62,17 @@ def summarize_attrs(attrs): return f"
{attrs_dl}
" -def _icon(icon_name): +def _icon(icon_name) -> str: # icon_name should be defined in xarray/static/html/icon-svg-inline.html return ( - "" - "" + f"" + f"" "" - "".format(icon_name) + "" ) -def summarize_variable(name, var, is_index=False, dtype=None): +def summarize_variable(name, var, is_index=False, dtype=None) -> str: variable = var.variable if hasattr(var, "variable") else var cssclass_idx = " class='xr-has-index'" if is_index else "" @@ -109,7 +109,7 @@ def summarize_variable(name, var, is_index=False, dtype=None): ) -def summarize_coords(variables): +def summarize_coords(variables) -> str: li_items = [] for k, v in variables.items(): li_content = summarize_variable(k, v, is_index=k in variables.xindexes) @@ -120,7 +120,7 @@ def summarize_coords(variables): return f"
    {vars_li}
" -def summarize_vars(variables): +def summarize_vars(variables) -> str: vars_li = "".join( f"
  • {summarize_variable(k, v)}
  • " for k, v in variables.items() @@ -129,14 +129,14 @@ def summarize_vars(variables): return f"
      {vars_li}
    " -def short_index_repr_html(index): +def short_index_repr_html(index) -> str: if hasattr(index, "_repr_html_"): return index._repr_html_() return f"
    {escape(repr(index))}
    " -def summarize_index(coord_names, index): +def summarize_index(coord_names, index) -> str: name = "
    ".join([escape(str(n)) for n in coord_names]) index_id = f"index-{uuid.uuid4()}" @@ -155,7 +155,7 @@ def summarize_index(coord_names, index): ) -def summarize_indexes(indexes): +def summarize_indexes(indexes) -> str: indexes_li = "".join( f"
  • {summarize_index(v, i)}
  • " for v, i in indexes.items() @@ -165,7 +165,7 @@ def summarize_indexes(indexes): def collapsible_section( name, inline_details="", details="", n_items=None, enabled=True, collapsed=False -): +) -> str: # "unique" id to expand/collapse the section data_id = "section-" + str(uuid.uuid4()) @@ -187,7 +187,7 @@ def collapsible_section( def _mapping_section( mapping, name, details_func, max_items_collapse, expand_option_name, enabled=True -): +) -> str: n_items = len(mapping) expanded = _get_boolean_with_default( expand_option_name, n_items < max_items_collapse @@ -203,7 +203,7 @@ def _mapping_section( ) -def dim_section(obj): +def dim_section(obj) -> str: dim_list = format_dims(obj.dims, obj.xindexes.dims) return collapsible_section( @@ -211,7 +211,7 @@ def dim_section(obj): ) -def array_section(obj): +def array_section(obj) -> str: # "unique" id to expand/collapse the section data_id = "section-" + str(uuid.uuid4()) collapsed = ( @@ -296,7 +296,7 @@ def _obj_repr(obj, header_components, sections): ) -def array_repr(arr): +def array_repr(arr) -> str: dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape)) if hasattr(arr, "xindexes"): indexed_dims = arr.xindexes.dims @@ -326,7 +326,7 @@ def array_repr(arr): return _obj_repr(arr, header_components, sections) -def dataset_repr(ds): +def dataset_repr(ds) -> str: obj_type = f"xarray.{type(ds).__name__}" header_components = [f"
    {escape(obj_type)}
    "] diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 2a4465dc325..788e1efa80b 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -43,6 +43,7 @@ peek_at, ) from xarray.core.variable import IndexVariable, Variable +from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: from numpy.typing import ArrayLike @@ -699,7 +700,7 @@ class GroupBy(Generic[T_Xarray]): _groups: dict[GroupKey, GroupIndex] | None _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None - _sizes: Frozen[Hashable, int] | None + _sizes: Mapping[Hashable, int] | None def __init__( self, @@ -746,7 +747,7 @@ def __init__( self._sizes = None @property - def sizes(self) -> Frozen[Hashable, int]: + def sizes(self) -> Mapping[Hashable, int]: """Ordered mapping from dimension names to lengths. Immutable. @@ -888,6 +889,21 @@ def _binary_op(self, other, f, reflexive=False): group = group.where(~mask, drop=True) codes = codes.where(~mask, drop=True).astype(int) + # if other is dask-backed, that's a hint that the + # "expanded" dataset is too big to hold in memory. + # this can be the case when `other` was read from disk + # and contains our lazy indexing classes + # We need to check for dask-backed Datasets + # so utils.is_duck_dask_array does not work for this check + if obj.chunks and not other.chunks: + # TODO: What about datasets with some dask vars, and others not? + # This handles dims other than `name`` + chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims} + # a chunk size of 1 seems reasonable since we expect individual elements of + # other to be repeated multiple times across the reduced dimension(s) + chunks[name] = 1 + other = other.chunk(chunks) + # codes are defined for coord, so we align `other` with `coord` # before indexing other, _ = align(other, coord, join="right", copy=False) @@ -973,7 +989,7 @@ def _flox_reduce( if kwargs["func"] not in ["sum", "prod"]: raise TypeError("Received an unexpected keyword argument 'min_count'") elif kwargs["min_count"] is None: - # set explicitly to avoid unncessarily accumulating count + # set explicitly to avoid unnecessarily accumulating count kwargs["min_count"] = 0 # weird backcompat @@ -1077,10 +1093,12 @@ def fillna(self, value: Any) -> T_Xarray: """ return ops.fillna(self, value) + @_deprecate_positional_args("v2023.10.0") def quantile( self, q: ArrayLike, dim: Dims = None, + *, method: QuantileMethods = "linear", keep_attrs: bool | None = None, skipna: bool | None = None, @@ -1102,15 +1120,15 @@ def quantile( desired quantile lies between two data points. The options sorted by their R type as summarized in the H&F paper [1]_ are: - 1. "inverted_cdf" (*) - 2. "averaged_inverted_cdf" (*) - 3. "closest_observation" (*) - 4. "interpolated_inverted_cdf" (*) - 5. "hazen" (*) - 6. "weibull" (*) + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" 7. "linear" (default) - 8. "median_unbiased" (*) - 9. "normal_unbiased" (*) + 8. "median_unbiased" + 9. "normal_unbiased" The first three methods are discontiuous. The following discontinuous variations of the default "linear" (7.) option are also available: @@ -1120,9 +1138,8 @@ def quantile( * "midpoint" * "nearest" - See :py:func:`numpy.quantile` or [1]_ for details. Methods marked with - an asterisk require numpy version 1.22 or newer. The "method" argument was - previously called "interpolation", renamed in accordance with numpy + See :py:func:`numpy.quantile` or [1]_ for details. The "method" argument + was previously called "interpolation", renamed in accordance with numpy version 1.22.0. keep_attrs : bool or None, default: None If True, the dataarray's attributes (`attrs`) will be copied from diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 33b9b7bcff9..1697762f7ae 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -24,7 +24,7 @@ ) if TYPE_CHECKING: - from xarray.core.types import ErrorOptions, JoinOptions, T_Index + from xarray.core.types import ErrorOptions, JoinOptions, Self from xarray.core.variable import Variable @@ -60,11 +60,11 @@ class Index: @classmethod def from_variables( - cls: type[T_Index], + cls, variables: Mapping[Any, Variable], *, options: Mapping[str, Any], - ) -> T_Index: + ) -> Self: """Create a new index object from one or more coordinate variables. This factory method must be implemented in all subclasses of Index. @@ -88,11 +88,11 @@ def from_variables( @classmethod def concat( - cls: type[T_Index], - indexes: Sequence[T_Index], + cls, + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> T_Index: + ) -> Self: """Create a new index by concatenating one or more indexes of the same type. @@ -120,9 +120,7 @@ def concat( raise NotImplementedError() @classmethod - def stack( - cls: type[T_Index], variables: Mapping[Any, Variable], dim: Hashable - ) -> T_Index: + def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Self: """Create a new index by stacking coordinate variables into a single new dimension. @@ -208,8 +206,8 @@ def to_pandas_index(self) -> pd.Index: raise TypeError(f"{self!r} cannot be cast to a pandas.Index object") def isel( - self: T_Index, indexers: Mapping[Any, int | slice | np.ndarray | Variable] - ) -> T_Index | None: + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Self | None: """Maybe returns a new index from the current index itself indexed by positional indexers. @@ -264,7 +262,7 @@ def sel(self, labels: dict[Any, Any]) -> IndexSelResult: """ raise NotImplementedError(f"{self!r} doesn't support label-based selection") - def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index: + def join(self, other: Self, how: JoinOptions = "inner") -> Self: """Return a new index from the combination of this index with another index of the same type. @@ -286,7 +284,7 @@ def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index: f"{self!r} doesn't support alignment with inner/outer join method" ) - def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]: + def reindex_like(self, other: Self) -> dict[Hashable, Any]: """Query the index with another index of the same type. Implementation is optional but required in order to support alignment. @@ -304,10 +302,10 @@ def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]: """ raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") - def equals(self: T_Index, other: T_Index) -> bool: + def equals(self, other: Self) -> bool: """Compare this index with another index of the same type. - Implemenation is optional but required in order to support alignment. + Implementation is optional but required in order to support alignment. Parameters ---------- @@ -321,7 +319,7 @@ def equals(self: T_Index, other: T_Index) -> bool: """ raise NotImplementedError() - def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None: + def roll(self, shifts: Mapping[Any, int]) -> Self | None: """Roll this index by an offset along one or more dimensions. This method can be re-implemented in subclasses of Index, e.g., when the @@ -347,10 +345,10 @@ def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None: return None def rename( - self: T_Index, + self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable], - ) -> T_Index: + ) -> Self: """Maybe update the index with new coordinate and dimension names. This method should be re-implemented in subclasses of Index if it has @@ -377,7 +375,7 @@ def rename( """ return self - def copy(self: T_Index, deep: bool = True) -> T_Index: + def copy(self, deep: bool = True) -> Self: """Return a (deep) copy of this index. Implementation in subclasses of Index is optional. The base class @@ -396,15 +394,13 @@ def copy(self: T_Index, deep: bool = True) -> T_Index: """ return self._copy(deep=deep) - def __copy__(self: T_Index) -> T_Index: + def __copy__(self) -> Self: return self.copy(deep=False) def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Index: return self._copy(deep=True, memo=memo) - def _copy( - self: T_Index, deep: bool = True, memo: dict[int, Any] | None = None - ) -> T_Index: + def _copy(self, deep: bool = True, memo: dict[int, Any] | None = None) -> Self: cls = self.__class__ copied = cls.__new__(cls) if deep: @@ -414,7 +410,7 @@ def _copy( copied.__dict__.update(self.__dict__) return copied - def __getitem__(self: T_Index, indexer: Any) -> T_Index: + def __getitem__(self, indexer: Any) -> Self: raise NotImplementedError() def _repr_inline_(self, max_width): @@ -616,7 +612,14 @@ def from_variables( name, var = next(iter(variables.items())) - if var.ndim != 1: + if var.ndim == 0: + raise ValueError( + f"cannot set a PandasIndex from the scalar variable {name!r}, " + "only 1-dimensional variables are supported. " + f"Note: you might want to use `obj.expand_dims({name!r})` to create a " + f"new dimension and turn {name!r} as an indexed dimension coordinate." + ) + elif var.ndim != 1: raise ValueError( "PandasIndex only accepts a 1-dimensional variable, " f"variable {name!r} has {var.ndim} dimensions" @@ -667,10 +670,10 @@ def _concat_indexes(indexes, dim, positions=None) -> pd.Index: @classmethod def concat( cls, - indexes: Sequence[PandasIndex], + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> PandasIndex: + ) -> Self: new_pd_index = cls._concat_indexes(indexes, dim, positions) if not indexes: @@ -793,7 +796,11 @@ def equals(self, other: Index): return False return self.index.equals(other.index) and self.dim == other.dim - def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasIndex: + def join( + self, + other: Self, + how: str = "inner", + ) -> Self: if how == "outer": index = self.index.union(other.index) else: @@ -804,7 +811,7 @@ def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasInd return type(self)(index, self.dim, coord_dtype=coord_dtype) def reindex_like( - self, other: PandasIndex, method=None, tolerance=None + self, other: Self, method=None, tolerance=None ) -> dict[Hashable, Any]: if not self.index.is_unique: raise ValueError( @@ -956,12 +963,12 @@ def from_variables( return obj @classmethod - def concat( # type: ignore[override] + def concat( cls, - indexes: Sequence[PandasMultiIndex], + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> PandasMultiIndex: + ) -> Self: new_pd_index = cls._concat_indexes(indexes, dim, positions) if not indexes: @@ -1196,12 +1203,12 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: coord_name, label = next(iter(labels.items())) if is_dict_like(label): - invalid_levels = [ + invalid_levels = tuple( name for name in label if name not in self.index.names - ] + ) if invalid_levels: raise ValueError( - f"invalid multi-index level names {invalid_levels}" + f"multi-index level names {invalid_levels} not found in indexes {tuple(self.index.names)}" ) return self.sel(label) @@ -1381,19 +1388,22 @@ def create_default_index_implicit( class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): - """Immutable proxy for Dataset or DataArrary indexes. + """Immutable proxy for Dataset or DataArray indexes. - Keys are coordinate names and values may correspond to either pandas or - xarray indexes. + It is a mapping where keys are coordinate names and values are either pandas + or xarray indexes. - Also provides some utility methods. + It also contains the indexed coordinate variables and provides some utility + methods. """ + _index_type: type[Index] | type[pd.Index] _indexes: dict[Any, T_PandasOrXarrayIndex] _variables: dict[Any, Variable] __slots__ = ( + "_index_type", "_indexes", "_variables", "_dims", @@ -1404,8 +1414,9 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): def __init__( self, - indexes: dict[Any, T_PandasOrXarrayIndex], - variables: dict[Any, Variable], + indexes: Mapping[Any, T_PandasOrXarrayIndex] | None = None, + variables: Mapping[Any, Variable] | None = None, + index_type: type[Index] | type[pd.Index] = Index, ): """Constructor not for public consumption. @@ -1414,11 +1425,33 @@ def __init__( indexes : dict Indexes held by this object. variables : dict - Indexed coordinate variables in this object. + Indexed coordinate variables in this object. Entries must + match those of `indexes`. + index_type : type + The type of all indexes, i.e., either :py:class:`xarray.indexes.Index` + or :py:class:`pandas.Index`. """ - self._indexes = indexes - self._variables = variables + if indexes is None: + indexes = {} + if variables is None: + variables = {} + + unmatched_keys = set(indexes) ^ set(variables) + if unmatched_keys: + raise ValueError( + f"unmatched keys found in indexes and variables: {unmatched_keys}" + ) + + if any(not isinstance(idx, index_type) for idx in indexes.values()): + index_type_str = f"{index_type.__module__}.{index_type.__name__}" + raise TypeError( + f"values of indexes must all be instances of {index_type_str}" + ) + + self._index_type = index_type + self._indexes = dict(**indexes) + self._variables = dict(**variables) self._dims: Mapping[Hashable, int] | None = None self.__coord_name_id: dict[Any, int] | None = None @@ -1566,10 +1599,10 @@ def to_pandas_indexes(self) -> Indexes[pd.Index]: elif isinstance(idx, Index): indexes[k] = idx.to_pandas_index() - return Indexes(indexes, self._variables) + return Indexes(indexes, self._variables, index_type=pd.Index) def copy_indexes( - self, deep: bool = True, memo: dict[int, Any] | None = None + self, deep: bool = True, memo: dict[int, T_PandasOrXarrayIndex] | None = None ) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]: """Return a new dictionary with copies of indexes, preserving unique indexes. @@ -1586,6 +1619,7 @@ def copy_indexes( new_indexes = {} new_index_vars = {} + idx: T_PandasOrXarrayIndex for idx, coords in self.group_by_index(): if isinstance(idx, pd.Index): convert_new_idx = True @@ -1621,7 +1655,8 @@ def __getitem__(self, key) -> T_PandasOrXarrayIndex: return self._indexes[key] def __repr__(self): - return formatting.indexes_repr(self) + indexes = formatting._get_indexes_dict(self) + return formatting.indexes_repr(indexes) def default_indexes( diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index acab9ccc60b..6e6ce01a41f 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -143,7 +143,10 @@ def group_indexers_by_index( elif key in obj.coords: raise KeyError(f"no index found for coordinate {key!r}") elif key not in obj.dims: - raise KeyError(f"{key!r} is not a valid dimension or coordinate") + raise KeyError( + f"{key!r} is not a valid dimension or coordinate for " + f"{obj.__class__.__name__} with dimensions {obj.dims!r}" + ) elif len(options): raise ValueError( f"cannot supply selection options {options!r} for dimension {key!r}" @@ -1303,7 +1306,7 @@ def __init__(self, array): if not isinstance(array, np.ndarray): raise TypeError( "NumpyIndexingAdapter only wraps np.ndarray. " - "Trying to wrap {}".format(type(array)) + f"Trying to wrap {type(array)}" ) self.array = array diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 56e51256ba1..a8e54ad1231 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -11,7 +11,6 @@ from xarray.core.duck_array_ops import lazy_array_equiv from xarray.core.indexes import ( Index, - Indexes, create_default_index_implicit, filter_indexes_from_coords, indexes_equal, @@ -34,7 +33,7 @@ tuple[DimsLike, ArrayLike, Mapping, Mapping], ] XarrayValue = Union[DataArray, Variable, VariableLike] - DatasetLike = Union[Dataset, Mapping[Any, XarrayValue]] + DatasetLike = Union[Dataset, Coordinates, Mapping[Any, XarrayValue]] CoercibleValue = Union[XarrayValue, pd.Series, pd.DataFrame] CoercibleMapping = Union[Dataset, Mapping[Any, CoercibleValue]] @@ -311,17 +310,22 @@ def collect_variables_and_indexes( ) -> dict[Hashable, list[MergeElement]]: """Collect variables and indexes from list of mappings of xarray objects. - Mappings must either be Dataset objects, or have values of one of the - following types: + Mappings can be Dataset or Coordinates objects, in which case both + variables and indexes are extracted from it. + + It can also have values of one of the following types: - an xarray.Variable - a tuple `(dims, data[, attrs[, encoding]])` that can be converted in an xarray.Variable - or an xarray.DataArray If a mapping of indexes is given, those indexes are assigned to all variables - with a matching key/name. + with a matching key/name. For dimension variables with no matching index, a + default (pandas) index is assigned. DataArray indexes that don't match mapping + keys are also extracted. """ + from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -338,8 +342,8 @@ def append_all(variables, indexes): append(name, variable, indexes.get(name)) for mapping in list_of_mappings: - if isinstance(mapping, Dataset): - append_all(mapping.variables, mapping._indexes) + if isinstance(mapping, (Coordinates, Dataset)): + append_all(mapping.variables, mapping.xindexes) continue for name, variable in mapping.items(): @@ -466,13 +470,15 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik List of Dataset or dictionary objects. Any inputs or values in the inputs that were pandas objects have been converted into native xarray objects. """ + from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - out = [] + out: list[DatasetLike] = [] for obj in objects: - if isinstance(obj, Dataset): - variables: DatasetLike = obj + variables: DatasetLike + if isinstance(obj, (Dataset, Coordinates)): + variables = obj else: variables = {} if isinstance(obj, PANDAS_TYPES): @@ -486,7 +492,7 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik def _get_priority_vars_and_indexes( - objects: list[DatasetLike], + objects: Sequence[DatasetLike], priority_arg: int | None, compat: CompatOptions = "equals", ) -> dict[Hashable, MergeElement]: @@ -498,7 +504,7 @@ def _get_priority_vars_and_indexes( Parameters ---------- - objects : list of dict-like of Variable + objects : sequence of dict-like of Variable Dictionaries in which to find the priority variables. priority_arg : int or None Integer object whose variable should take priority. @@ -556,62 +562,6 @@ def merge_coords( return variables, out_indexes -def merge_data_and_coords( - data_vars: Mapping[Any, Any], - coords: Mapping[Any, Any], - compat: CompatOptions = "broadcast_equals", - join: JoinOptions = "outer", -) -> _MergeResult: - """Used in Dataset.__init__.""" - indexes, coords = _create_indexes_from_coords(coords, data_vars) - objects = [data_vars, coords] - explicit_coords = coords.keys() - return merge_core( - objects, - compat, - join, - explicit_coords=explicit_coords, - indexes=Indexes(indexes, coords), - ) - - -def _create_indexes_from_coords( - coords: Mapping[Any, Any], data_vars: Mapping[Any, Any] | None = None -) -> tuple[dict, dict]: - """Maybe create default indexes from a mapping of coordinates. - - Return those indexes and updated coordinates. - """ - all_variables = dict(coords) - if data_vars is not None: - all_variables.update(data_vars) - - indexes = {} - updated_coords = {} - - # this is needed for backward compatibility: when a pandas multi-index - # is given as data variable, it is promoted as index / level coordinates - # TODO: depreciate this implicit behavior - index_vars = { - k: v - for k, v in all_variables.items() - if k in coords or isinstance(v, pd.MultiIndex) - } - - for name, obj in index_vars.items(): - variable = as_variable(obj, name=name) - - if variable.dims == (name,): - idx, idx_vars = create_default_index_implicit(variable, all_variables) - indexes.update({k: idx for k in idx_vars}) - updated_coords.update(idx_vars) - all_variables.update(idx_vars) - else: - updated_coords[name] = obj - - return indexes, updated_coords - - def assert_valid_explicit_coords( variables: Mapping[Any, Any], dims: Mapping[Any, int], @@ -702,6 +652,7 @@ def merge_core( explicit_coords: Iterable[Hashable] | None = None, indexes: Mapping[Any, Any] | None = None, fill_value: object = dtypes.NA, + skip_align_args: list[int] | None = None, ) -> _MergeResult: """Core logic for merging labeled objects. @@ -727,6 +678,8 @@ def merge_core( may be cast to pandas.Index objects. fill_value : scalar, optional Value to use for newly missing values + skip_align_args : list of int, optional + Optional arguments in `objects` that are not included in alignment. Returns ------- @@ -748,10 +701,20 @@ def merge_core( _assert_compat_valid(compat) + objects = list(objects) + if skip_align_args is None: + skip_align_args = [] + + skip_align_objs = [(pos, objects.pop(pos)) for pos in skip_align_args] + coerced = coerce_pandas_values(objects) aligned = deep_align( coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value ) + + for pos, obj in skip_align_objs: + aligned.insert(pos, obj) + collected = collect_variables_and_indexes(aligned, indexes=indexes) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) variables, out_indexes = merge_collected( @@ -761,6 +724,9 @@ def merge_core( dims = calculate_dimensions(variables) coord_names, noncoord_names = determine_coords(coerced) + if compat == "minimal": + # coordinates may be dropped in merged results + coord_names.intersection_update(variables) if explicit_coords is not None: assert_valid_explicit_coords(variables, dims, explicit_coords) coord_names.update(explicit_coords) @@ -1008,18 +974,23 @@ def merge( combine_nested combine_by_coords """ + + from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset dict_like_objects = [] for obj in objects: - if not isinstance(obj, (DataArray, Dataset, dict)): + if not isinstance(obj, (DataArray, Dataset, Coordinates, dict)): raise TypeError( "objects must be an iterable containing only " "Dataset(s), DataArray(s), and dictionaries." ) - obj = obj.to_dataset(promote_attrs=True) if isinstance(obj, DataArray) else obj + if isinstance(obj, DataArray): + obj = obj.to_dataset(promote_attrs=True) + elif isinstance(obj, Coordinates): + obj = obj.to_dataset() dict_like_objects.append(obj) merge_result = merge_core( diff --git a/xarray/core/missing.py b/xarray/core/missing.py index c6efaebc04c..90a9dd2e76c 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -501,7 +501,7 @@ def _get_interpolator( ) elif method == "barycentric": interp_class = _import_interpolant("BarycentricInterpolator", method) - elif method == "krog": + elif method in ["krogh", "krog"]: interp_class = _import_interpolant("KroghInterpolator", method) elif method == "pchip": interp_class = _import_interpolant("PchipInterpolator", method) @@ -678,7 +678,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): Notes ----- - This requiers scipy installed. + This requires scipy installed. See Also -------- @@ -724,13 +724,13 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): for i in range(new_x[0].ndim) } - # if useful, re-use localize for each chunk of new_x + # if useful, reuse localize for each chunk of new_x localize = (method in ["linear", "nearest"]) and new_x0_chunks_is_not_none # scipy.interpolate.interp1d always forces to float. # Use the same check for blockwise as well: if not issubclass(var.dtype.type, np.inexact): - dtype = np.float_ + dtype = float else: dtype = var.dtype diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 3b8ddfe032d..fc7240139aa 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -4,7 +4,7 @@ import numpy as np -from xarray.core import dtypes, nputils, utils +from xarray.core import dtypes, duck_array_ops, nputils, utils from xarray.core.duck_array_ops import ( astype, count, @@ -21,12 +21,16 @@ def _maybe_null_out(result, axis, mask, min_count=1): xarray version of pandas.core.nanops._maybe_null_out """ if axis is not None and getattr(result, "ndim", False): - null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0 + null_mask = ( + np.take(mask.shape, axis).prod() + - duck_array_ops.sum(mask, axis) + - min_count + ) < 0 dtype, fill_value = dtypes.maybe_promote(result.dtype) result = where(null_mask, fill_value, astype(result, dtype)) elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES: - null_mask = mask.size - mask.sum() + null_mask = mask.size - duck_array_ops.sum(mask) result = where(null_mask < min_count, np.nan, result) return result diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 1c5b0d3d972..316a77ead6a 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -5,6 +5,13 @@ import numpy as np import pandas as pd from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] +from packaging.version import Version + +# remove once numpy 2.0 is the oldest supported version +try: + from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import RankWarning from xarray.core.options import OPTIONS from xarray.core.pycompat import is_duck_array @@ -12,11 +19,20 @@ try: import bottleneck as bn - _USE_BOTTLENECK = True + _BOTTLENECK_AVAILABLE = True except ImportError: # use numpy methods instead bn = np - _USE_BOTTLENECK = False + _BOTTLENECK_AVAILABLE = False + +try: + import numbagg + + _HAS_NUMBAGG = Version(numbagg.__version__) >= Version("0.5.0") +except ImportError: + # use numpy methods instead + numbagg = np + _HAS_NUMBAGG = False def _select_along_axis(values, idx, axis): @@ -155,13 +171,30 @@ def __setitem__(self, key, value): self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions) -def _create_bottleneck_method(name, npmodule=np): +def _create_method(name, npmodule=np): def f(values, axis=None, **kwargs): dtype = kwargs.get("dtype", None) bn_func = getattr(bn, name, None) + nba_func = getattr(numbagg, name, None) if ( - _USE_BOTTLENECK + _HAS_NUMBAGG + and OPTIONS["use_numbagg"] + and isinstance(values, np.ndarray) + and nba_func is not None + # numbagg uses ddof=1 only, but numpy uses ddof=0 by default + and (("var" in name or "std" in name) and kwargs.get("ddof", 0) == 1) + # TODO: bool? + and values.dtype.kind in "uifc" + # and values.dtype.isnative + and (dtype is None or np.dtype(dtype) == values.dtype) + ): + # numbagg does not take care dtype, ddof + kwargs.pop("dtype", None) + kwargs.pop("ddof", None) + result = nba_func(values, axis=axis, **kwargs) + elif ( + _BOTTLENECK_AVAILABLE and OPTIONS["use_bottleneck"] and isinstance(values, np.ndarray) and bn_func is not None @@ -194,7 +227,7 @@ def _nanpolyfit_1d(arr, x, rcond=None): def warn_on_deficient_rank(rank, order): if rank != order: - warnings.warn("Polyfit may be poorly conditioned", np.RankWarning, stacklevel=2) + warnings.warn("Polyfit may be poorly conditioned", RankWarning, stacklevel=2) def least_squares(lhs, rhs, rcond=None, skipna=False): @@ -227,14 +260,14 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): return coeffs, residuals -nanmin = _create_bottleneck_method("nanmin") -nanmax = _create_bottleneck_method("nanmax") -nanmean = _create_bottleneck_method("nanmean") -nanmedian = _create_bottleneck_method("nanmedian") -nanvar = _create_bottleneck_method("nanvar") -nanstd = _create_bottleneck_method("nanstd") -nanprod = _create_bottleneck_method("nanprod") -nancumsum = _create_bottleneck_method("nancumsum") -nancumprod = _create_bottleneck_method("nancumprod") -nanargmin = _create_bottleneck_method("nanargmin") -nanargmax = _create_bottleneck_method("nanargmax") +nanmin = _create_method("nanmin") +nanmax = _create_method("nanmax") +nanmean = _create_method("nanmean") +nanmedian = _create_method("nanmedian") +nanvar = _create_method("nanvar") +nanstd = _create_method("nanstd") +nanprod = _create_method("nanprod") +nancumsum = _create_method("nancumsum") +nancumprod = _create_method("nancumprod") +nanargmin = _create_method("nanargmin") +nanargmax = _create_method("nanargmax") diff --git a/xarray/core/ops.py b/xarray/core/ops.py index e1c3573841a..b23d586fb79 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -53,7 +53,6 @@ "var", "median", ] -NAN_CUM_METHODS = ["cumsum", "cumprod"] # TODO: wrap take, dot, sort @@ -263,20 +262,6 @@ def inject_reduce_methods(cls): setattr(cls, name, func) -def inject_cum_methods(cls): - methods = [(name, getattr(duck_array_ops, name), True) for name in NAN_CUM_METHODS] - for name, f, include_skipna in methods: - numeric_only = getattr(f, "numeric_only", False) - func = cls._reduce_method(f, include_skipna, numeric_only) - func.__name__ = name - func.__doc__ = _CUM_DOCSTRING_TEMPLATE.format( - name=name, - cls=cls.__name__, - extra_args=cls._cum_extra_args_docstring.format(name=name), - ) - setattr(cls, name, func) - - def op_str(name): return f"__{name}__" @@ -316,16 +301,6 @@ def __init_subclass__(cls, **kwargs): inject_reduce_methods(cls) -class IncludeCumMethods: - __slots__ = () - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - - if getattr(cls, "_reduce_method", None): - inject_cum_methods(cls) - - class IncludeNumpySameMethods: __slots__ = () diff --git a/xarray/core/options.py b/xarray/core/options.py index eb0c56c7ee0..d116c350991 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -6,10 +6,8 @@ from xarray.core.utils import FrozenDict if TYPE_CHECKING: - try: - from matplotlib.colors import Colormap - except ImportError: - Colormap = str + from matplotlib.colors import Colormap + Options = Literal[ "arithmetic_join", "cmap_divergent", @@ -29,6 +27,8 @@ "keep_attrs", "warn_for_unclosed_files", "use_bottleneck", + "use_numbagg", + "use_opt_einsum", "use_flox", ] @@ -52,6 +52,8 @@ class T_Options(TypedDict): warn_for_unclosed_files: bool use_bottleneck: bool use_flox: bool + use_numbagg: bool + use_opt_einsum: bool OPTIONS: T_Options = { @@ -74,6 +76,8 @@ class T_Options(TypedDict): "warn_for_unclosed_files": False, "use_bottleneck": True, "use_flox": True, + "use_numbagg": True, + "use_opt_einsum": True, } _JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"]) @@ -100,6 +104,8 @@ def _positive_integer(value: int) -> bool: "file_cache_maxsize": _positive_integer, "keep_attrs": lambda choice: choice in [True, False, "default"], "use_bottleneck": lambda value: isinstance(value, bool), + "use_numbagg": lambda value: isinstance(value, bool), + "use_opt_einsum": lambda value: isinstance(value, bool), "use_flox": lambda value: isinstance(value, bool), "warn_for_unclosed_files": lambda value: isinstance(value, bool), } @@ -164,11 +170,11 @@ class set_options: cmap_divergent : str or matplotlib.colors.Colormap, default: "RdBu_r" Colormap to use for divergent data plots. If string, must be matplotlib built-in colormap. Can also be a Colormap object - (e.g. mpl.cm.magma) + (e.g. mpl.colormaps["magma"]) cmap_sequential : str or matplotlib.colors.Colormap, default: "viridis" Colormap to use for nondivergent data plots. If string, must be matplotlib built-in colormap. Can also be a Colormap object - (e.g. mpl.cm.magma) + (e.g. mpl.colormaps["magma"]) display_expand_attrs : {"default", True, False} Whether to expand the attributes section for display of ``DataArray`` or ``Dataset`` objects. Can be @@ -232,6 +238,11 @@ class set_options: use_flox : bool, default: True Whether to use ``numpy_groupies`` and `flox`` to accelerate groupby and resampling reductions. + use_numbagg : bool, default: True + Whether to use ``numbagg`` to accelerate reductions. + Takes precedence over ``use_bottleneck`` when both are True. + use_opt_einsum : bool, default: True + Whether to use ``opt_einsum`` to accelerate dot products. warn_for_unclosed_files : bool, default: False Whether or not to issue a warning when unclosed files are deallocated. This is mostly useful for debugging. diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 07c3c606bf2..949576b4ee8 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -443,7 +443,7 @@ def subset_dataset_to_block( for dim in variable.dims: chunk = chunk[chunk_index[dim]] - chunk_variable_task = (f"{name}-{gname}-{chunk[0]}",) + chunk_tuple + chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple graph[chunk_variable_task] = ( tuple, [variable.dims, chunk, variable.attrs], diff --git a/xarray/core/parallelcompat.py b/xarray/core/parallelcompat.py index 26efc5fc412..333059e00ae 100644 --- a/xarray/core/parallelcompat.py +++ b/xarray/core/parallelcompat.py @@ -25,7 +25,7 @@ T_ChunkedArray = TypeVar("T_ChunkedArray") if TYPE_CHECKING: - from xarray.core.types import T_Chunks, T_NormalizedChunks + from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks @functools.lru_cache(maxsize=1) @@ -257,7 +257,7 @@ def normalize_chunks( @abstractmethod def from_array( - self, data: np.ndarray, chunks: T_Chunks, **kwargs + self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs ) -> T_ChunkedArray: """ Create a chunked array from a non-chunked numpy-like array. diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 9af5d693170..bc8b61164f1 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -101,3 +101,37 @@ def is_chunked_array(x) -> bool: def is_0d_dask_array(x): return is_duck_dask_array(x) and is_scalar(x) + + +def to_numpy(data) -> np.ndarray: + from xarray.core.indexing import ExplicitlyIndexed + from xarray.core.parallelcompat import get_chunked_array_type + + if isinstance(data, ExplicitlyIndexed): + data = data.get_duck_array() + + # TODO first attempt to call .to_numpy() once some libraries implement it + if hasattr(data, "chunks"): + chunkmanager = get_chunked_array_type(data) + data, *_ = chunkmanager.compute(data) + if isinstance(data, array_type("cupy")): + data = data.get() + # pint has to be imported dynamically as pint imports xarray + if isinstance(data, array_type("pint")): + data = data.magnitude + if isinstance(data, array_type("sparse")): + data = data.todense() + data = np.asarray(data) + + return data + + +def to_duck_array(data): + from xarray.core.indexing import ExplicitlyIndexed + + if isinstance(data, ExplicitlyIndexed): + return data.get_duck_array() + elif is_duck_array(data): + return data + else: + return np.asarray(data) diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index 43edbc08456..5e3f2f57397 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -45,7 +45,6 @@ from xarray.coding.cftime_offsets import ( BaseCFTimeOffset, - Day, MonthEnd, QuarterEnd, Tick, @@ -254,8 +253,7 @@ def _adjust_bin_edges( labels: np.ndarray, ): """This is required for determining the bin edges resampling with - daily frequencies greater than one day, month end, and year end - frequencies. + month end, quarter end, and year end frequencies. Consider the following example. Let's say you want to downsample the time series with the following coordinates to month end frequency: @@ -283,14 +281,8 @@ def _adjust_bin_edges( The labels are still: CFTimeIndex([2000-01-31 00:00:00, 2000-02-29 00:00:00], dtype='object') - - This is also required for daily frequencies longer than one day and - year-end frequencies. """ - is_super_daily = isinstance(freq, (MonthEnd, QuarterEnd, YearEnd)) or ( - isinstance(freq, Day) and freq.n > 1 - ) - if is_super_daily: + if isinstance(freq, (MonthEnd, QuarterEnd, YearEnd)): if closed == "right": datetime_bins = datetime_bins + datetime.timedelta(days=1, microseconds=-1) if datetime_bins[-2] > index.max(): diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 916fabe42ac..b85092982e3 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -61,6 +61,11 @@ class Rolling(Generic[T_Xarray]): __slots__ = ("obj", "window", "min_periods", "center", "dim") _attributes = ("window", "min_periods", "center", "dim") + dim: list[Hashable] + window: list[int] + center: list[bool] + obj: T_Xarray + min_periods: int def __init__( self, @@ -91,8 +96,8 @@ def __init__( ------- rolling : type of input argument """ - self.dim: list[Hashable] = [] - self.window: list[int] = [] + self.dim = [] + self.window = [] for d, w in windows.items(): self.dim.append(d) if w <= 0: @@ -100,7 +105,15 @@ def __init__( self.window.append(w) self.center = self._mapping_to_list(center, default=False) - self.obj: T_Xarray = obj + self.obj = obj + + missing_dims = tuple(dim for dim in self.dim if dim not in self.obj.dims) + if missing_dims: + # NOTE: we raise KeyError here but ValueError in Coarsen. + raise KeyError( + f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} " + f"dimensions {tuple(self.obj.dims)}" + ) # attributes if min_periods is not None and min_periods <= 0: @@ -467,9 +480,8 @@ def reduce( obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna ) - result = windows.reduce( - func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs - ) + dim = list(rolling_dim.values()) + result = windows.reduce(func, dim=dim, keep_attrs=keep_attrs, **kwargs) # Find valid windows based on count. counts = self._counts(keep_attrs=False) @@ -486,6 +498,7 @@ def _counts(self, keep_attrs: bool | None) -> DataArray: # array is faster to be reduced than object array. # The use of skipna==False is also faster since it does not need to # copy the strided array. + dim = list(rolling_dim.values()) counts = ( self.obj.notnull(keep_attrs=keep_attrs) .rolling( @@ -493,7 +506,7 @@ def _counts(self, keep_attrs: bool | None) -> DataArray: center={d: self.center[i] for i, d in enumerate(self.dim)}, ) .construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs) - .sum(dim=list(rolling_dim.values()), skipna=False, keep_attrs=keep_attrs) + .sum(dim=dim, skipna=False, keep_attrs=keep_attrs) ) return counts @@ -564,7 +577,7 @@ def _numpy_or_bottleneck_reduce( and not is_duck_dask_array(self.obj.data) and self.ndim == 1 ): - # TODO: renable bottleneck with dask after the issues + # TODO: re-enable bottleneck with dask after the issues # underlying https://github.com/pydata/xarray/issues/2940 are # fixed. return self._bottleneck_reduce( @@ -624,8 +637,7 @@ def __init__( xarray.DataArray.groupby """ super().__init__(obj, windows, min_periods, center) - if any(d not in self.obj.dims for d in self.dim): - raise KeyError(self.dim) + # Keep each Rolling object as a dictionary self.rollings = {} for key, da in self.obj.data_vars.items(): @@ -778,11 +790,14 @@ def construct( if not keep_attrs: dataset[key].attrs = {} + # Need to stride coords as well. TODO: is there a better way? + coords = self.obj.isel( + {d: slice(None, None, s) for d, s in zip(self.dim, strides)} + ).coords + attrs = self.obj.attrs if keep_attrs else {} - return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel( - {d: slice(None, None, s) for d, s in zip(self.dim, strides)} - ) + return Dataset(dataset, coords=coords, attrs=attrs) class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): @@ -804,6 +819,10 @@ class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): ) _attributes = ("windows", "side", "trim_excess") obj: T_Xarray + windows: Mapping[Hashable, int] + side: SideOptions | Mapping[Hashable, SideOptions] + boundary: CoarsenBoundaryOptions + coord_func: Mapping[Hashable, str | Callable] def __init__( self, @@ -839,17 +858,21 @@ def __init__( self.side = side self.boundary = boundary - absent_dims = [dim for dim in windows.keys() if dim not in self.obj.dims] - if absent_dims: + missing_dims = tuple(dim for dim in windows.keys() if dim not in self.obj.dims) + if missing_dims: raise ValueError( - f"Dimensions {absent_dims!r} not found in {self.obj.__class__.__name__}." + f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} " + f"dimensions {tuple(self.obj.dims)}" ) - if not utils.is_dict_like(coord_func): - coord_func = {d: coord_func for d in self.obj.dims} # type: ignore[misc] + + if utils.is_dict_like(coord_func): + coord_func_map = coord_func + else: + coord_func_map = {d: coord_func for d in self.obj.dims} for c in self.obj.coords: - if c not in coord_func: - coord_func[c] = duck_array_ops.mean # type: ignore[index] - self.coord_func: Mapping[Hashable, str | Callable] = coord_func + if c not in coord_func_map: + coord_func_map[c] = duck_array_ops.mean # type: ignore[index] + self.coord_func = coord_func_map def _get_keep_attrs(self, keep_attrs): if keep_attrs is None: diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 91edd3acb7c..c8160cefef3 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -6,13 +6,26 @@ import numpy as np from packaging.version import Version +from xarray.core.computation import apply_ufunc from xarray.core.options import _get_keep_attrs from xarray.core.pdcompat import count_not_none -from xarray.core.pycompat import is_duck_dask_array from xarray.core.types import T_DataWithCoords +try: + import numbagg + from numbagg import move_exp_nanmean, move_exp_nansum + + _NUMBAGG_VERSION: Version | None = Version(numbagg.__version__) +except ImportError: + _NUMBAGG_VERSION = None -def _get_alpha(com=None, span=None, halflife=None, alpha=None): + +def _get_alpha( + com: float | None = None, + span: float | None = None, + halflife: float | None = None, + alpha: float | None = None, +) -> float: # pandas defines in terms of com (converting to alpha in the algo) # so use its function to get a com and then convert to alpha @@ -20,31 +33,12 @@ def _get_alpha(com=None, span=None, halflife=None, alpha=None): return 1 / (1 + com) -def move_exp_nanmean(array, *, axis, alpha): - if is_duck_dask_array(array): - raise TypeError("rolling_exp is not currently support for dask-like arrays") - import numbagg - - # No longer needed in numbag > 0.2.0; remove in time - if axis == (): - return array.astype(np.float64) - else: - return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha) - - -def move_exp_nansum(array, *, axis, alpha): - if is_duck_dask_array(array): - raise TypeError("rolling_exp is not currently supported for dask-like arrays") - import numbagg - - # numbagg <= 0.2.0 did not have a __version__ attribute - if Version(getattr(numbagg, "__version__", "0.1.0")) < Version("0.2.0"): - raise ValueError("`rolling_exp(...).sum() requires numbagg>=0.2.1.") - - return numbagg.move_exp_nansum(array, axis=axis, alpha=alpha) - - -def _get_center_of_mass(comass, span, halflife, alpha): +def _get_center_of_mass( + comass: float | None, + span: float | None, + halflife: float | None, + alpha: float | None, +) -> float: """ Vendored from pandas.core.window.common._get_center_of_mass @@ -104,11 +98,31 @@ def __init__( obj: T_DataWithCoords, windows: Mapping[Any, int | float], window_type: str = "span", + min_weight: float = 0.0, ): + if _NUMBAGG_VERSION is None: + raise ImportError( + "numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed" + ) + elif _NUMBAGG_VERSION < Version("0.2.1"): + raise ImportError( + f"numbagg >= 0.2.1 is required for rolling_exp but currently version {_NUMBAGG_VERSION} is installed" + ) + elif _NUMBAGG_VERSION < Version("0.3.1") and min_weight > 0: + raise ImportError( + f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {_NUMBAGG_VERSION} is installed" + ) + self.obj: T_DataWithCoords = obj dim, window = next(iter(windows.items())) self.dim = dim self.alpha = _get_alpha(**{window_type: window}) + self.min_weight = min_weight + # Don't pass min_weight=0 so we can support older versions of numbagg + kwargs = dict(alpha=self.alpha, axis=-1) + if min_weight > 0: + kwargs["min_weight"] = min_weight + self.kwargs = kwargs def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: """ @@ -133,9 +147,18 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - return self.obj.reduce( - move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs - ) + dim_order = self.obj.dims + + return apply_ufunc( + move_exp_nanmean, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=keep_attrs, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: """ @@ -160,6 +183,145 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - return self.obj.reduce( - move_exp_nansum, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs - ) + dim_order = self.obj.dims + + return apply_ufunc( + move_exp_nansum, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=keep_attrs, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def std(self) -> T_DataWithCoords: + """ + Exponentially weighted moving standard deviation. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").std() + + array([ nan, 0. , 0.67936622, 0.42966892, 0.25389527]) + Dimensions without coordinates: x + """ + + if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {_NUMBAGG_VERSION} is installed" + ) + dim_order = self.obj.dims + + return apply_ufunc( + numbagg.move_exp_nanstd, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def var(self) -> T_DataWithCoords: + """ + Exponentially weighted moving variance. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").var() + + array([ nan, 0. , 0.46153846, 0.18461538, 0.06446281]) + Dimensions without coordinates: x + """ + + if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {_NUMBAGG_VERSION} is installed" + ) + dim_order = self.obj.dims + + return apply_ufunc( + numbagg.move_exp_nanvar, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def cov(self, other: T_DataWithCoords) -> T_DataWithCoords: + """ + Exponentially weighted moving covariance. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").cov(da**2) + + array([ nan, 0. , 1.38461538, 0.55384615, 0.19338843]) + Dimensions without coordinates: x + """ + + if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed" + ) + dim_order = self.obj.dims + + return apply_ufunc( + numbagg.move_exp_nancov, + self.obj, + other, + input_core_dims=[[self.dim], [self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def corr(self, other: T_DataWithCoords) -> T_DataWithCoords: + """ + Exponentially weighted moving correlation. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").corr(da.shift(x=1)) + + array([ nan, nan, nan, 0.4330127 , 0.48038446]) + Dimensions without coordinates: x + """ + + if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed" + ) + dim_order = self.obj.dims + + return apply_ufunc( + numbagg.move_exp_nancorr, + self.obj, + other, + input_core_dims=[[self.dim], [self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) diff --git a/xarray/core/types.py b/xarray/core/types.py index f3342071107..1be5b00c43f 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -1,12 +1,14 @@ from __future__ import annotations import datetime -from collections.abc import Hashable, Iterable, Sequence +import sys +from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence from typing import ( TYPE_CHECKING, Any, Callable, Literal, + Protocol, SupportsIndex, TypeVar, Union, @@ -14,18 +16,30 @@ import numpy as np import pandas as pd -from packaging.version import Version + +try: + if sys.version_info >= (3, 11): + from typing import Self, TypeAlias + else: + from typing_extensions import Self, TypeAlias +except ImportError: + if TYPE_CHECKING: + raise + else: + Self: Any = None if TYPE_CHECKING: from numpy._typing import _SupportsDType from numpy.typing import ArrayLike from xarray.backends.common import BackendEntrypoint + from xarray.core.alignment import Aligner from xarray.core.common import AbstractArray, DataWithCoords + from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.groupby import DataArrayGroupBy, GroupBy - from xarray.core.indexes import Index + from xarray.core.indexes import Index, Indexes + from xarray.core.utils import Frozen from xarray.core.variable import Variable try: @@ -43,25 +57,12 @@ except ImportError: ZarrArray = np.ndarray - # TODO: Turn on when https://github.com/python/mypy/issues/11871 is fixed. - # Can be uncommented if using pyright though. - # import sys - - # try: - # if sys.version_info >= (3, 11): - # from typing import Self - # else: - # from typing_extensions import Self - # except ImportError: - # Self: Any = None - Self: Any = None - # Anything that can be coerced to a shape tuple _ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]] _DTypeLikeNested = Any # TODO: wait for support for recursive types # Xarray requires a Mapping[Hashable, dtype] in many places which - # conflics with numpys own DTypeLike (with dtypes for fields). + # conflicts with numpys own DTypeLike (with dtypes for fields). # https://numpy.org/devdocs/reference/typing.html#numpy.typing.DTypeLike # This is a copy of this DTypeLike that allows only non-Mapping dtypes. DTypeLikeSave = Union[ @@ -89,35 +90,109 @@ CFTimeDatetime = Any DatetimeLike = Union[pd.Timestamp, datetime.datetime, np.datetime64, CFTimeDatetime] else: - Self: Any = None DTypeLikeSave: Any = None +class Alignable(Protocol): + """Represents any Xarray type that supports alignment. + + It may be ``Dataset``, ``DataArray`` or ``Coordinates``. This protocol class + is needed since those types do not all have a common base class. + + """ + + @property + def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: + ... + + @property + def sizes(self) -> Mapping[Hashable, int]: + ... + + @property + def xindexes(self) -> Indexes[Index]: + ... + + def _reindex_callback( + self, + aligner: Aligner, + dim_pos_indexers: dict[Hashable, Any], + variables: dict[Hashable, Variable], + indexes: dict[Hashable, Index], + fill_value: Any, + exclude_dims: frozenset[Hashable], + exclude_vars: frozenset[Hashable], + ) -> Self: + ... + + def _overwrite_indexes( + self, + indexes: Mapping[Any, Index], + variables: Mapping[Any, Variable] | None = None, + ) -> Self: + ... + + def __len__(self) -> int: + ... + + def __iter__(self) -> Iterator[Hashable]: + ... + + def copy( + self, + deep: bool = False, + ) -> Self: + ... + + +T_Alignable = TypeVar("T_Alignable", bound="Alignable") + T_Backend = TypeVar("T_Backend", bound="BackendEntrypoint") T_Dataset = TypeVar("T_Dataset", bound="Dataset") T_DataArray = TypeVar("T_DataArray", bound="DataArray") T_Variable = TypeVar("T_Variable", bound="Variable") +T_Coordinates = TypeVar("T_Coordinates", bound="Coordinates") T_Array = TypeVar("T_Array", bound="AbstractArray") T_Index = TypeVar("T_Index", bound="Index") +# `T_Xarray` is a type variable that can be either "DataArray" or "Dataset". When used +# in a function definition, all inputs and outputs annotated with `T_Xarray` must be of +# the same concrete type, either "DataArray" or "Dataset". This is generally preferred +# over `T_DataArrayOrSet`, given the type system can determine the exact type. +T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") + +# `T_DataArrayOrSet` is a type variable that is bounded to either "DataArray" or +# "Dataset". Use it for functions that might return either type, but where the exact +# type cannot be determined statically using the type system. T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"]) -# Maybe we rename this to T_Data or something less Fortran-y? -T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") +# For working directly with `DataWithCoords`. It will only allow using methods defined +# on `DataWithCoords`. T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") + +# Temporary placeholder for indicating an array api compliant type. +# hopefully in the future we can narrow this down more: +T_DuckArray = TypeVar("T_DuckArray", bound=Any) + ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] -DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"] -DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"] VarCompatible = Union["Variable", "ScalarOrArray"] -GroupByIncompatible = Union["Variable", "GroupBy"] +DaCompatible = Union["DataArray", "VarCompatible"] +DsCompatible = Union["Dataset", "DaCompatible"] +GroupByCompatible = Union["Dataset", "DataArray"] Dims = Union[str, Iterable[Hashable], "ellipsis", None] OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None] -T_Chunks = Union[int, dict[Any, Any], Literal["auto"], None] +# FYI in some cases we don't allow `None`, which this doesn't take account of. +T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]] +# We allow the tuple form of this (though arguably we could transition to named dims only) +T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]] T_NormalizedChunks = tuple[tuple[int, ...], ...] +DataVars = Mapping[Any, Any] + + ErrorOptions = Literal["raise", "ignore"] ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"] @@ -134,7 +209,7 @@ Interp1dOptions = Literal[ "linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial" ] -InterpolantOptions = Literal["barycentric", "krog", "pchip", "spline", "akima"] +InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"] InterpOptions = Union[Interp1dOptions, InterpolantOptions] DatetimeUnitOptions = Literal[ @@ -192,27 +267,18 @@ ] -if Version(np.__version__) >= Version("1.22.0"): - QuantileMethods = Literal[ - "inverted_cdf", - "averaged_inverted_cdf", - "closest_observation", - "interpolated_inverted_cdf", - "hazen", - "weibull", - "linear", - "median_unbiased", - "normal_unbiased", - "lower", - "higher", - "midpoint", - "nearest", - ] -else: - QuantileMethods = Literal[ # type: ignore[misc] - "linear", - "lower", - "higher", - "midpoint", - "nearest", - ] +QuantileMethods = Literal[ + "inverted_cdf", + "averaged_inverted_cdf", + "closest_observation", + "interpolated_inverted_cdf", + "hazen", + "weibull", + "linear", + "median_unbiased", + "normal_unbiased", + "lower", + "higher", + "midpoint", + "nearest", +] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index bd0ca57f33c..ad86b2c7fec 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -73,7 +73,7 @@ import pandas as pd if TYPE_CHECKING: - from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims + from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims, T_DuckArray K = TypeVar("K") V = TypeVar("V") @@ -253,7 +253,7 @@ def is_list_like(value: Any) -> TypeGuard[list | tuple]: return isinstance(value, (list, tuple)) -def is_duck_array(value: Any) -> bool: +def is_duck_array(value: Any) -> TypeGuard[T_DuckArray]: if isinstance(value, np.ndarray): return True return ( diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9d859c0d8a7..db109a40454 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -5,14 +5,14 @@ import math import numbers import warnings -from collections.abc import Hashable, Iterable, Mapping, Sequence +from collections.abc import Hashable, Mapping, Sequence from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, NoReturn, cast import numpy as np import pandas as pd from numpy.typing import ArrayLike -from packaging.version import Version import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils @@ -26,20 +26,16 @@ as_indexable, ) from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.parallelcompat import ( - get_chunked_array_type, - guess_chunkmanager, -) +from xarray.core.parallelcompat import get_chunked_array_type, guess_chunkmanager from xarray.core.pycompat import ( - array_type, integer_types, is_0d_dask_array, is_chunked_array, is_duck_dask_array, + to_numpy, ) +from xarray.core.types import T_Chunks from xarray.core.utils import ( - Frozen, - NdimSizeLenMixin, OrderedSet, _default, decode_numpy_dict_values, @@ -50,6 +46,7 @@ is_duck_array, maybe_coerce_to_str, ) +from xarray.namedarray.core import NamedArray NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( indexing.ExplicitlyIndexed, @@ -66,7 +63,8 @@ PadModeOptions, PadReflectOptions, QuantileMethods, - T_Variable, + Self, + T_DuckArray, ) NON_NANOSECOND_WARNING = ( @@ -86,7 +84,7 @@ class MissingDimensionsError(ValueError): # TODO: move this to an xarray.exceptions module? -def as_variable(obj, name=None) -> Variable | IndexVariable: +def as_variable(obj: T_DuckArray | Any, name=None) -> Variable | IndexVariable: """Convert an object into a Variable. Parameters @@ -142,7 +140,7 @@ def as_variable(obj, name=None) -> Variable | IndexVariable: elif isinstance(obj, (set, dict)): raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}") elif name is not None: - data = as_compatible_data(obj) + data: T_DuckArray = as_compatible_data(obj) if data.ndim != 1: raise MissingDimensionsError( f"cannot set variable {name!r} with {data.ndim!r}-dimensional data " @@ -230,7 +228,9 @@ def _possibly_convert_datetime_or_timedelta_index(data): return data -def as_compatible_data(data, fastpath: bool = False): +def as_compatible_data( + data: T_DuckArray | ArrayLike, fastpath: bool = False +) -> T_DuckArray: """Prepare and wrap data to put in a Variable. - If data does not have the necessary attributes, convert it to ndarray. @@ -243,7 +243,7 @@ def as_compatible_data(data, fastpath: bool = False): """ if fastpath and getattr(data, "ndim", 0) > 0: # can't use fastpath (yet) for scalars - return _maybe_wrap_data(data) + return cast("T_DuckArray", _maybe_wrap_data(data)) from xarray.core.dataarray import DataArray @@ -252,7 +252,7 @@ def as_compatible_data(data, fastpath: bool = False): if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): data = _possibly_convert_datetime_or_timedelta_index(data) - return _maybe_wrap_data(data) + return cast("T_DuckArray", _maybe_wrap_data(data)) if isinstance(data, tuple): data = utils.to_0d_object_array(data) @@ -265,7 +265,7 @@ def as_compatible_data(data, fastpath: bool = False): data = np.timedelta64(getattr(data, "value", data), "ns") # we don't want nested self-described arrays - if isinstance(data, (pd.Series, pd.Index, pd.DataFrame)): + if isinstance(data, (pd.Series, pd.DataFrame)): data = data.values if isinstance(data, np.ma.MaskedArray): @@ -279,7 +279,7 @@ def as_compatible_data(data, fastpath: bool = False): if not isinstance(data, np.ndarray) and ( hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") ): - return data + return cast("T_DuckArray", data) # validate whether the data is valid data types. data = np.asarray(data) @@ -312,7 +312,7 @@ def _as_array_or_item(data): return data -class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic): +class Variable(NamedArray, AbstractArray, VariableArithmetic): """A netcdf-like variable consisting of dimensions, data and attributes which describe a single Array. A single Variable object is not fully described outside the context of its parent Dataset (if you want such a @@ -335,7 +335,14 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic): __slots__ = ("_dims", "_data", "_attrs", "_encoding") - def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): + def __init__( + self, + dims, + data: T_DuckArray | ArrayLike, + attrs=None, + encoding=None, + fastpath=False, + ): """ Parameters ---------- @@ -355,50 +362,32 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): Well-behaved code to serialize a Variable should ignore unrecognized encoding items. """ - self._data = as_compatible_data(data, fastpath=fastpath) - self._dims = self._parse_dimensions(dims) - self._attrs = None + super().__init__( + dims=dims, data=as_compatible_data(data, fastpath=fastpath), attrs=attrs + ) + self._encoding = None - if attrs is not None: - self.attrs = attrs if encoding is not None: self.encoding = encoding - @property - def dtype(self) -> np.dtype: - """ - Data-type of the array’s elements. - - See Also - -------- - ndarray.dtype - numpy.dtype - """ - return self._data.dtype - - @property - def shape(self) -> tuple[int, ...]: - """ - Tuple of array dimensions. - - See Also - -------- - numpy.ndarray.shape - """ - return self._data.shape + def _new( + self, + dims=_default, + data=_default, + attrs=_default, + ): + dims_ = copy.copy(self._dims) if dims is _default else dims - @property - def nbytes(self) -> int: - """ - Total bytes consumed by the elements of the data array. + if attrs is _default: + attrs_ = None if self._attrs is None else self._attrs.copy() + else: + attrs_ = attrs - If the underlying data array does not include ``nbytes``, estimates - the bytes consumed based on the ``size`` and ``dtype``. - """ - if hasattr(self._data, "nbytes"): - return self._data.nbytes + if data is _default: + return type(self)(dims_, copy.copy(self._data), attrs_) else: - return self.size * self.dtype.itemsize + cls_ = type(self) + return cls_(dims_, data, attrs_) @property def _in_memory(self): @@ -410,7 +399,7 @@ def _in_memory(self): ) @property - def data(self) -> Any: + def data(self): """ The Variable's data as an array. The underlying array type (e.g. dask, sparse, pint) is preserved. @@ -429,17 +418,13 @@ def data(self) -> Any: return self.values @data.setter - def data(self, data): + def data(self, data: T_DuckArray | ArrayLike) -> None: data = as_compatible_data(data) - if data.shape != self.shape: - raise ValueError( - f"replacement data must match the Variable's shape. " - f"replacement data has shape {data.shape}; Variable has shape {self.shape}" - ) + self._check_shape(data) self._data = data def astype( - self: T_Variable, + self, dtype, *, order=None, @@ -447,7 +432,7 @@ def astype( subok=None, copy=None, keep_attrs=True, - ) -> T_Variable: + ) -> Self: """ Copy of the Variable object, with data cast to a specified type. @@ -561,41 +546,6 @@ def compute(self, **kwargs): new = self.copy(deep=False) return new.load(**kwargs) - def __dask_tokenize__(self): - # Use v.data, instead of v._data, in order to cope with the wrappers - # around NetCDF and the like - from dask.base import normalize_token - - return normalize_token((type(self), self._dims, self.data, self._attrs)) - - def __dask_graph__(self): - if is_duck_dask_array(self._data): - return self._data.__dask_graph__() - else: - return None - - def __dask_keys__(self): - return self._data.__dask_keys__() - - def __dask_layers__(self): - return self._data.__dask_layers__() - - @property - def __dask_optimize__(self): - return self._data.__dask_optimize__ - - @property - def __dask_scheduler__(self): - return self._data.__dask_scheduler__ - - def __dask_postcompute__(self): - array_func, array_args = self._data.__dask_postcompute__() - return self._dask_finalize, (array_func,) + array_args - - def __dask_postpersist__(self): - array_func, array_args = self._data.__dask_postpersist__() - return self._dask_finalize, (array_func,) + array_args - def _dask_finalize(self, results, array_func, *args, **kwargs): data = array_func(results, *args, **kwargs) return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding) @@ -657,27 +607,6 @@ def to_dict( return item - @property - def dims(self) -> tuple[Hashable, ...]: - """Tuple of dimension names with which this variable is associated.""" - return self._dims - - @dims.setter - def dims(self, value: str | Iterable[Hashable]) -> None: - self._dims = self._parse_dimensions(value) - - def _parse_dimensions(self, dims: str | Iterable[Hashable]) -> tuple[Hashable, ...]: - if isinstance(dims, str): - dims = (dims,) - else: - dims = tuple(dims) - if len(dims) != self.ndim: - raise ValueError( - f"dimensions {dims} must have the same length as the " - f"number of data dimensions, ndim={self.ndim}" - ) - return dims - def _item_key_to_tuple(self, key): if utils.is_dict_like(key): return tuple(key.get(dim, slice(None)) for dim in self.dims) @@ -759,18 +688,18 @@ def _validate_indexers(self, key): if k.ndim > 1: raise IndexError( "Unlabeled multi-dimensional array cannot be " - "used for indexing: {}".format(k) + f"used for indexing: {k}" ) if k.dtype.kind == "b": if self.shape[self.get_axis_num(dim)] != len(k): raise IndexError( - "Boolean array size {:d} is used to index array " - "with shape {:s}.".format(len(k), str(self.shape)) + f"Boolean array size {len(k):d} is used to index array " + f"with shape {str(self.shape):s}." ) if k.ndim > 1: raise IndexError( - "{}-dimensional boolean indexing is " - "not supported. ".format(k.ndim) + f"{k.ndim}-dimensional boolean indexing is " + "not supported. " ) if is_duck_dask_array(k.data): raise KeyError( @@ -783,9 +712,7 @@ def _validate_indexers(self, key): raise IndexError( "Boolean indexer should be unlabeled or on the " "same dimension to the indexed array. Indexer is " - "on {:s} but the target dimension is {:s}.".format( - str(k.dims), dim - ) + f"on {str(k.dims):s} but the target dimension is {dim:s}." ) def _broadcast_indexes_outer(self, key): @@ -812,13 +739,6 @@ def _broadcast_indexes_outer(self, key): return dims, OuterIndexer(tuple(new_key)), None - def _nonzero(self): - """Equivalent numpy's nonzero but returns a tuple of Variables.""" - # TODO we should replace dask's native nonzero - # after https://github.com/dask/dask/issues/1076 is implemented. - nonzeros = np.nonzero(self.data) - return tuple(Variable((dim), nz) for nz, dim in zip(nonzeros, self.dims)) - def _broadcast_indexes_vectorized(self, key): variables = [] out_dims_set = OrderedSet() @@ -875,7 +795,7 @@ def _broadcast_indexes_vectorized(self, key): return out_dims, VectorizedIndexer(tuple(out_key)), new_order - def __getitem__(self: T_Variable, key) -> T_Variable: + def __getitem__(self, key) -> Self: """Return a new Variable object whose contents are consistent with getting the provided key from the underlying data. @@ -894,7 +814,7 @@ def __getitem__(self: T_Variable, key) -> T_Variable: data = np.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) - def _finalize_indexing_result(self: T_Variable, dims, data) -> T_Variable: + def _finalize_indexing_result(self, dims, data) -> Self: """Used by IndexVariable to return IndexVariable objects when possible.""" return self._replace(dims=dims, data=data) @@ -968,17 +888,6 @@ def __setitem__(self, key, value): indexable = as_indexable(self._data) indexable[index_tuple] = value - @property - def attrs(self) -> dict[Any, Any]: - """Dictionary of local attributes on this variable.""" - if self._attrs is None: - self._attrs = {} - return self._attrs - - @attrs.setter - def attrs(self, value: Mapping[Any, Any]) -> None: - self._attrs = dict(value) - @property def encoding(self) -> dict[Any, Any]: """Dictionary of encodings on this variable.""" @@ -993,91 +902,40 @@ def encoding(self, value): except ValueError: raise ValueError("encoding must be castable to a dictionary") - def reset_encoding(self: T_Variable) -> T_Variable: + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: """Return a new Variable without encoding.""" return self._replace(encoding={}) - def copy( - self: T_Variable, deep: bool = True, data: ArrayLike | None = None - ) -> T_Variable: - """Returns a copy of this object. - - If `deep=True`, the data array is loaded into memory and copied onto - the new object. Dimensions, attributes and encodings are always copied. - - Use `data` to create a new object with the same structure as - original but entirely new data. - - Parameters - ---------- - deep : bool, default: True - Whether the data array is loaded into memory and copied onto - the new object. Default is True. - data : array_like, optional - Data to use in the new object. Must have same shape as original. - When `data` is used, `deep` is ignored. - - Returns - ------- - object : Variable - New object with dimensions, attributes, encodings, and optionally - data copied from original. - - Examples - -------- - Shallow copy versus deep copy - - >>> var = xr.Variable(data=[1, 2, 3], dims="x") - >>> var.copy() - - array([1, 2, 3]) - >>> var_0 = var.copy(deep=False) - >>> var_0[0] = 7 - >>> var_0 - - array([7, 2, 3]) - >>> var - - array([7, 2, 3]) - - Changing the data using the ``data`` argument maintains the - structure of the original object, but with the new data. Original - object is unaffected. - - >>> var.copy(data=[0.1, 0.2, 0.3]) - - array([0.1, 0.2, 0.3]) - >>> var - - array([7, 2, 3]) - - See Also - -------- - pandas.DataFrame.copy - """ - return self._copy(deep=deep, data=data) - def _copy( - self: T_Variable, + self, deep: bool = True, - data: ArrayLike | None = None, + data: T_DuckArray | ArrayLike | None = None, memo: dict[int, Any] | None = None, - ) -> T_Variable: + ) -> Self: if data is None: - ndata = self._data + data_old = self._data - if isinstance(ndata, indexing.MemoryCachedArray): + if not isinstance(data_old, indexing.MemoryCachedArray): + ndata = data_old + else: # don't share caching between copies - ndata = indexing.MemoryCachedArray(ndata.array) + # TODO: MemoryCachedArray doesn't match the array api: + ndata = indexing.MemoryCachedArray(data_old.array) # type: ignore[assignment] if deep: ndata = copy.deepcopy(ndata, memo) else: ndata = as_compatible_data(data) - if self.shape != ndata.shape: + if self.shape != ndata.shape: # type: ignore[attr-defined] raise ValueError( - f"Data shape {ndata.shape} must match shape of object {self.shape}" + f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined] ) attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs) @@ -1089,87 +947,33 @@ def _copy( return self._replace(data=ndata, attrs=attrs, encoding=encoding) def _replace( - self: T_Variable, + self, dims=_default, data=_default, attrs=_default, encoding=_default, - ) -> T_Variable: + ) -> Self: if dims is _default: dims = copy.copy(self._dims) if data is _default: data = copy.copy(self.data) if attrs is _default: attrs = copy.copy(self._attrs) + if encoding is _default: encoding = copy.copy(self._encoding) return type(self)(dims, data, attrs, encoding, fastpath=True) - def __copy__(self: T_Variable) -> T_Variable: - return self._copy(deep=False) - - def __deepcopy__( - self: T_Variable, memo: dict[int, Any] | None = None - ) -> T_Variable: - return self._copy(deep=True, memo=memo) - - # mutable objects should not be hashable - # https://github.com/python/mypy/issues/4266 - __hash__ = None # type: ignore[assignment] - - @property - def chunks(self) -> tuple[tuple[int, ...], ...] | None: - """ - Tuple of block lengths for this dataarray's data, in order of dimensions, or None if - the underlying data is not a dask array. - - See Also - -------- - Variable.chunk - Variable.chunksizes - xarray.unify_chunks - """ - return getattr(self._data, "chunks", None) - - @property - def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: - """ - Mapping from dimension names to block lengths for this variable's data, or None if - the underlying data is not a dask array. - Cannot be modified directly, but can be modified by calling .chunk(). - - Differs from variable.chunks because it returns a mapping of dimensions to chunk shapes - instead of a tuple of chunk shapes. - - See Also - -------- - Variable.chunk - Variable.chunks - xarray.unify_chunks - """ - if hasattr(self._data, "chunks"): - return Frozen({dim: c for dim, c in zip(self.dims, self.data.chunks)}) - else: - return {} - - _array_counter = itertools.count() - def chunk( self, - chunks: ( - int - | Literal["auto"] - | tuple[int, ...] - | tuple[tuple[int, ...], ...] - | Mapping[Any, None | int | tuple[int, ...]] - ) = {}, + chunks: T_Chunks = {}, name: str | None = None, lock: bool | None = None, inline_array: bool | None = None, chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, **chunks_kwargs: Any, - ) -> Variable: + ) -> Self: """Coerce this array's data into a dask array with the given chunks. If this variable is a non-dask array, it will be converted to dask @@ -1250,11 +1054,13 @@ def chunk( inline_array=inline_array, ) - data = self._data - if chunkmanager.is_chunked_array(data): - data = chunkmanager.rechunk(data, chunks) # type: ignore[arg-type] + data_old = self._data + if chunkmanager.is_chunked_array(data_old): + data_chunked = chunkmanager.rechunk(data_old, chunks) else: - if isinstance(data, indexing.ExplicitlyIndexed): + if not isinstance(data_old, indexing.ExplicitlyIndexed): + ndata = data_old + else: # Unambiguously handle array storage backends (like NetCDF4 and h5py) # that can't handle general array indexing. For example, in netCDF4 you # can do "outer" indexing along two dimensions independent, which works @@ -1263,81 +1069,37 @@ def chunk( # Using OuterIndexer is a pragmatic choice: dask does not yet handle # different indexing types in an explicit way: # https://github.com/dask/dask/issues/2883 - data = indexing.ImplicitToExplicitIndexingAdapter( - data, indexing.OuterIndexer + # TODO: ImplicitToExplicitIndexingAdapter doesn't match the array api: + ndata = indexing.ImplicitToExplicitIndexingAdapter( # type: ignore[assignment] + data_old, indexing.OuterIndexer ) if utils.is_dict_like(chunks): - chunks = tuple(chunks.get(n, s) for n, s in enumerate(data.shape)) + chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape)) - data = chunkmanager.from_array( - data, - chunks, # type: ignore[arg-type] + data_chunked = chunkmanager.from_array( + ndata, + chunks, **_from_array_kwargs, ) - return self._replace(data=data) + return self._replace(data=data_chunked) def to_numpy(self) -> np.ndarray: """Coerces wrapped data to numpy and returns a numpy.ndarray""" # TODO an entrypoint so array libraries can choose coercion method? - data = self.data - - # TODO first attempt to call .to_numpy() once some libraries implement it - if hasattr(data, "chunks"): - chunkmanager = get_chunked_array_type(data) - data, *_ = chunkmanager.compute(data) - if isinstance(data, array_type("cupy")): - data = data.get() - # pint has to be imported dynamically as pint imports xarray - if isinstance(data, array_type("pint")): - data = data.magnitude - if isinstance(data, array_type("sparse")): - data = data.todense() - data = np.asarray(data) + return to_numpy(self._data) - return data - - def as_numpy(self: T_Variable) -> T_Variable: + def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) - def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA): - """ - use sparse-array as backend. - """ - import sparse - - # TODO: what to do if dask-backended? - if fill_value is dtypes.NA: - dtype, fill_value = dtypes.maybe_promote(self.dtype) - else: - dtype = dtypes.result_type(self.dtype, fill_value) - - if sparse_format is _default: - sparse_format = "coo" - try: - as_sparse = getattr(sparse, f"as_{sparse_format.lower()}") - except AttributeError: - raise ValueError(f"{sparse_format} is not a valid sparse format") - - data = as_sparse(self.data.astype(dtype), fill_value=fill_value) - return self._replace(data=data) - - def _to_dense(self): - """ - Change backend from sparse to np.array - """ - if hasattr(self._data, "todense"): - return self._replace(data=self._data.todense()) - return self.copy(deep=False) - def isel( - self: T_Variable, + self, indexers: Mapping[Any, Any] | None = None, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_Variable: + ) -> Self: """Return a new array indexed along the specified dimension(s). Parameters @@ -1624,7 +1386,7 @@ def transpose( self, *dims: Hashable | ellipsis, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> Variable: + ) -> Self: """Return a new Variable object with transposed dimensions. Parameters @@ -1669,7 +1431,7 @@ def transpose( return self._replace(dims=dims, data=data) @property - def T(self) -> Variable: + def T(self) -> Self: return self.transpose() def set_dims(self, dims, shape=None): @@ -1743,7 +1505,9 @@ def _stack_once(self, dims: list[Hashable], new_dim: Hashable): new_data = duck_array_ops.reshape(reordered.data, new_shape) new_dims = reordered.dims[: len(other_dims)] + (new_dim,) - return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) + return type(self)( + new_dims, new_data, self._attrs, self._encoding, fastpath=True + ) def stack(self, dimensions=None, **dimensions_kwargs): """ @@ -1777,9 +1541,7 @@ def stack(self, dimensions=None, **dimensions_kwargs): result = result._stack_once(dims, new_dim) return result - def _unstack_once_full( - self, dims: Mapping[Any, int], old_dim: Hashable - ) -> Variable: + def _unstack_once_full(self, dims: Mapping[Any, int], old_dim: Hashable) -> Self: """ Unstacks the variable without needing an index. @@ -1812,7 +1574,9 @@ def _unstack_once_full( new_data = reordered.data.reshape(new_shape) new_dims = reordered.dims[: len(other_dims)] + new_dim_names - return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) + return type(self)( + new_dims, new_data, self._attrs, self._encoding, fastpath=True + ) def _unstack_once( self, @@ -1820,7 +1584,7 @@ def _unstack_once( dim: Hashable, fill_value=dtypes.NA, sparse: bool = False, - ) -> Variable: + ) -> Self: """ Unstacks this variable given an index to unstack and the name of the dimension to which the index refers. @@ -1837,15 +1601,20 @@ def _unstack_once( new_shape = tuple(list(reordered.shape[: len(other_dims)]) + new_dim_sizes) new_dims = reordered.dims[: len(other_dims)] + new_dim_names + create_template: Callable if fill_value is dtypes.NA: is_missing_values = math.prod(new_shape) > math.prod(self.shape) if is_missing_values: dtype, fill_value = dtypes.maybe_promote(self.dtype) + + create_template = partial(np.full_like, fill_value=fill_value) else: dtype = self.dtype fill_value = dtypes.get_fill_value(dtype) + create_template = np.empty_like else: dtype = self.dtype + create_template = partial(np.full_like, fill_value=fill_value) if sparse: # unstacking a dense multitindexed array to a sparse array @@ -1868,12 +1637,7 @@ def _unstack_once( ) else: - data = np.full_like( - self.data, - fill_value=fill_value, - shape=new_shape, - dtype=dtype, - ) + data = create_template(self.data, shape=new_shape, dtype=dtype) # Indexer is a list of lists of locations. Each list is the locations # on the new dimension. This is robust to the data being sparse; in that @@ -1941,7 +1705,7 @@ def clip(self, min=None, max=None): return apply_ufunc(np.clip, self, min, max, dask="allowed") - def reduce( + def reduce( # type: ignore[override] self, func: Callable[..., Any], dim: Dims = None, @@ -1982,57 +1746,21 @@ def reduce( Array with summarized data and the indicated dimension(s) removed. """ - if dim == ...: - dim = None - if dim is not None and axis is not None: - raise ValueError("cannot supply both 'axis' and 'dim' arguments") - - if dim is not None: - axis = self.get_axis_num(dim) - - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", r"Mean of empty slice", category=RuntimeWarning - ) - if axis is not None: - if isinstance(axis, tuple) and len(axis) == 1: - # unpack axis for the benefit of functions - # like np.argmin which can't handle tuple arguments - axis = axis[0] - data = func(self.data, axis=axis, **kwargs) - else: - data = func(self.data, **kwargs) - - if getattr(data, "shape", ()) == self.shape: - dims = self.dims - else: - removed_axes: Iterable[int] - if axis is None: - removed_axes = range(self.ndim) - else: - removed_axes = np.atleast_1d(axis) % self.ndim - if keepdims: - # Insert np.newaxis for removed dims - slices = tuple( - np.newaxis if i in removed_axes else slice(None, None) - for i in range(self.ndim) - ) - if getattr(data, "shape", None) is None: - # Reduce has produced a scalar value, not an array-like - data = np.asanyarray(data)[slices] - else: - data = data[slices] - dims = self.dims - else: - dims = tuple( - adim for n, adim in enumerate(self.dims) if n not in removed_axes - ) + keep_attrs_ = ( + _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs + ) - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) - attrs = self._attrs if keep_attrs else None + # Noe that the call order for Variable.mean is + # Variable.mean -> NamedArray.mean -> Variable.reduce + # -> NamedArray.reduce + result = super().reduce( + func=func, dim=dim, axis=axis, keepdims=keepdims, **kwargs + ) - return Variable(dims, data, attrs=attrs) + # return Variable always to support IndexVariable + return Variable( + result.dims, result._data, attrs=result._attrs if keep_attrs_ else None + ) @classmethod def concat( @@ -2119,7 +1847,7 @@ def concat( for var in variables: if var.dims != first_var_dims: raise ValueError( - f"Variable has dimensions {list(var.dims)} but first Variable has dimensions {list(first_var_dims)}" + f"Variable has dimensions {tuple(var.dims)} but first Variable has dimensions {tuple(first_var_dims)}" ) return cls(dims, data, attrs, encoding, fastpath=True) @@ -2181,7 +1909,7 @@ def quantile( keep_attrs: bool | None = None, skipna: bool | None = None, interpolation: QuantileMethods | None = None, - ) -> Variable: + ) -> Self: """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements. @@ -2198,15 +1926,15 @@ def quantile( desired quantile lies between two data points. The options sorted by their R type as summarized in the H&F paper [1]_ are: - 1. "inverted_cdf" (*) - 2. "averaged_inverted_cdf" (*) - 3. "closest_observation" (*) - 4. "interpolated_inverted_cdf" (*) - 5. "hazen" (*) - 6. "weibull" (*) + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" 7. "linear" (default) - 8. "median_unbiased" (*) - 9. "normal_unbiased" (*) + 8. "median_unbiased" + 9. "normal_unbiased" The first three methods are discontiuous. The following discontinuous variations of the default "linear" (7.) option are also available: @@ -2220,8 +1948,6 @@ def quantile( was previously called "interpolation", renamed in accordance with numpy version 1.22.0. - (*) These methods require numpy version 1.22 or newer. - keep_attrs : bool, optional If True, the variable's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -2289,14 +2015,7 @@ def _wrapper(npa, **kwargs): axis = np.arange(-1, -1 * len(dim) - 1, -1) - if Version(np.__version__) >= Version("1.22.0"): - kwargs = {"q": q, "axis": axis, "method": method} - else: - if method not in ("linear", "lower", "higher", "midpoint", "nearest"): - raise ValueError( - f"Interpolation method '{method}' requires numpy >= 1.22 or is not supported." - ) - kwargs = {"q": q, "axis": axis, "interpolation": method} + kwargs = {"q": q, "axis": axis, "method": method} result = apply_ufunc( _wrapper, @@ -2443,7 +2162,7 @@ def rolling_window( raise ValueError( f"Expected {name}={arg!r} to be a scalar like 'dim'." ) - dim = [dim] + dim = (dim,) # dim is now a list nroll = len(dim) @@ -2474,7 +2193,7 @@ def rolling_window( pads[d] = (win - 1, 0) padded = var.pad(pads, mode="constant", constant_values=fill_value) - axis = tuple(self.get_axis_num(d) for d in dim) + axis = self.get_axis_num(dim) new_dims = self.dims + tuple(window_dim) return Variable( new_dims, @@ -2559,8 +2278,8 @@ def coarsen_reshape(self, windows, boundary, side): variable = variable.pad(pad_width, mode="constant") else: raise TypeError( - "{} is invalid for boundary. Valid option is 'exact', " - "'trim' and 'pad'".format(boundary[d]) + f"{boundary[d]} is invalid for boundary. Valid option is 'exact', " + "'trim' and 'pad'" ) shape = [] @@ -2576,7 +2295,7 @@ def coarsen_reshape(self, windows, boundary, side): else: shape.append(variable.shape[i]) - return variable.data.reshape(shape), tuple(axes) + return duck_array_ops.reshape(variable.data, shape), tuple(axes) def isnull(self, keep_attrs: bool | None = None): """Test each value in the array for whether it is a missing value. @@ -2647,26 +2366,26 @@ def notnull(self, keep_attrs: bool | None = None): ) @property - def real(self): + def imag(self) -> Variable: """ - The real part of the variable. + The imaginary part of the variable. See Also -------- - numpy.ndarray.real + numpy.ndarray.imag """ - return self._replace(data=self.data.real) + return self._new(data=self.data.imag) @property - def imag(self): + def real(self) -> Variable: """ - The imaginary part of the variable. + The real part of the variable. See Also -------- - numpy.ndarray.imag + numpy.ndarray.real """ - return self._replace(data=self.data.imag) + return self._new(data=self.data.real) def __array_wrap__(self, obj, context=None): return Variable(self.dims, obj) @@ -2876,6 +2595,28 @@ def argmax( """ return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) + def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable: + """ + Use sparse-array as backend. + """ + from xarray.namedarray.utils import _default as _default_named + + if sparse_format is _default: + sparse_format = _default_named + + if fill_value is _default: + fill_value = _default_named + + out = super()._as_sparse(sparse_format, fill_value) + return cast("Variable", out) + + def _to_dense(self) -> Variable: + """ + Change backend from sparse to np.array. + """ + out = super()._to_dense() + return cast("Variable", out) + class IndexVariable(Variable): """Wrapper for accommodating a pandas.Index in an xarray.Variable. @@ -2890,6 +2631,9 @@ class IndexVariable(Variable): __slots__ = () + # TODO: PandasIndexingAdapter doesn't match the array api: + _data: PandasIndexingAdapter # type: ignore[assignment] + def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): super().__init__(dims, data, attrs, encoding, fastpath) if self.ndim != 1: @@ -3006,7 +2750,7 @@ def concat( return cls(first_var.dims, data, attrs) - def copy(self, deep: bool = True, data: ArrayLike | None = None): + def copy(self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None): """Returns a copy of this object. `deep` is ignored since data is stored in the form of @@ -3031,12 +2775,16 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None): data copied from original. """ if data is None: - ndata = self._data.copy(deep=deep) + ndata = self._data + + if deep: + ndata = copy.deepcopy(ndata, None) + else: ndata = as_compatible_data(data) - if self.shape != ndata.shape: + if self.shape != ndata.shape: # type: ignore[attr-defined] raise ValueError( - f"Data shape {ndata.shape} must match shape of object {self.shape}" + f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined] ) attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs) @@ -3126,10 +2874,6 @@ def _inplace_binary_op(self, other, f): ) -# for backwards compatibility -Coordinate = utils.alias(IndexVariable, "Coordinate") - - def _unified_dims(variables): # validate dimensions all_dims = {} diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index e21091fad6b..28740a99020 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -11,6 +11,7 @@ from xarray.core.computation import apply_ufunc, dot from xarray.core.pycompat import is_duck_dask_array from xarray.core.types import Dims, T_Xarray +from xarray.util.deprecation_helpers import _deprecate_positional_args # Weighted quantile methods are a subset of the numpy supported quantile methods. QUANTILE_METHODS = Literal[ @@ -198,10 +199,11 @@ def _check_dim(self, dim: Dims): dims = [dim] if dim else [] else: dims = list(dim) - missing_dims = set(dims) - set(self.obj.dims) - set(self.weights.dims) + all_dims = set(self.obj.dims).union(set(self.weights.dims)) + missing_dims = set(dims) - all_dims if missing_dims: raise ValueError( - f"{self.__class__.__name__} does not contain the dimensions: {missing_dims}" + f"Dimensions {tuple(missing_dims)} not found in {self.__class__.__name__} dimensions {tuple(all_dims)}" ) @staticmethod @@ -323,6 +325,7 @@ def _weighted_quantile( def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray: """Return the interpolation parameter.""" # Note that options are not yet exposed in the public API. + h: np.ndarray if method == "linear": h = (n - 1) * q + 1 elif method == "interpolated_inverted_cdf": @@ -448,18 +451,22 @@ def _weighted_quantile_1d( def _implementation(self, func, dim, **kwargs): raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`") + @_deprecate_positional_args("v2023.10.0") def sum_of_weights( self, dim: Dims = None, + *, keep_attrs: bool | None = None, ) -> T_Xarray: return self._implementation( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def sum_of_squares( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -467,9 +474,11 @@ def sum_of_squares( self._sum_of_squares, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def sum( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -477,9 +486,11 @@ def sum( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def mean( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -487,9 +498,11 @@ def mean( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def var( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -497,9 +510,11 @@ def var( self._weighted_var, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def std( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: diff --git a/xarray/namedarray/__init__.py b/xarray/namedarray/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/xarray/namedarray/_aggregations.py b/xarray/namedarray/_aggregations.py new file mode 100644 index 00000000000..76dfb18d068 --- /dev/null +++ b/xarray/namedarray/_aggregations.py @@ -0,0 +1,949 @@ +"""Mixin classes with reduction operations.""" +# This file was generated using xarray.util.generate_aggregations. Do not edit manually. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, Callable + +from xarray.core import duck_array_ops +from xarray.core.types import Dims, Self + + +class NamedArrayAggregations: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: + raise NotImplementedError() + + def count( + self, + dim: Dims = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``count`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``count`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + pandas.DataFrame.count + dask.dataframe.DataFrame.count + Dataset.count + DataArray.count + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.count() + + array(5) + """ + return self.reduce( + duck_array_ops.count, + dim=dim, + **kwargs, + ) + + def all( + self, + dim: Dims = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``all`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``all`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.all + dask.array.all + Dataset.all + DataArray.all + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + >>> na + + array([ True, True, True, True, True, False]) + + >>> na.all() + + array(False) + """ + return self.reduce( + duck_array_ops.array_all, + dim=dim, + **kwargs, + ) + + def any( + self, + dim: Dims = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``any`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``any`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.any + dask.array.any + Dataset.any + DataArray.any + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + >>> na + + array([ True, True, True, True, True, False]) + + >>> na.any() + + array(True) + """ + return self.reduce( + duck_array_ops.array_any, + dim=dim, + **kwargs, + ) + + def max( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``max`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``max`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.max + dask.array.max + Dataset.max + DataArray.max + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.max() + + array(3.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.max(skipna=False) + + array(nan) + """ + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def min( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``min`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``min`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.min + dask.array.min + Dataset.min + DataArray.min + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.min() + + array(0.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.min(skipna=False) + + array(nan) + """ + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def mean( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``mean`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``mean`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.mean + dask.array.mean + Dataset.mean + DataArray.mean + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.mean() + + array(1.6) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.mean(skipna=False) + + array(nan) + """ + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def prod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``prod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``prod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.prod + dask.array.prod + Dataset.prod + DataArray.prod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.prod() + + array(0.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.prod(skipna=False) + + array(nan) + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> na.prod(skipna=True, min_count=2) + + array(0.) + """ + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + **kwargs, + ) + + def sum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``sum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``sum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.sum + dask.array.sum + Dataset.sum + DataArray.sum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.sum() + + array(8.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.sum(skipna=False) + + array(nan) + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> na.sum(skipna=True, min_count=2) + + array(8.) + """ + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + **kwargs, + ) + + def std( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``std`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``std`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.std + dask.array.std + Dataset.std + DataArray.std + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.std() + + array(1.0198039) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.std(skipna=False) + + array(nan) + + Specify ``ddof=1`` for an unbiased estimate. + + >>> na.std(skipna=True, ddof=1) + + array(1.14017543) + """ + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + **kwargs, + ) + + def var( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``var`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``var`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.var + dask.array.var + Dataset.var + DataArray.var + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.var() + + array(1.04) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.var(skipna=False) + + array(nan) + + Specify ``ddof=1`` for an unbiased estimate. + + >>> na.var(skipna=True, ddof=1) + + array(1.3) + """ + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + **kwargs, + ) + + def median( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``median`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``median`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.median + dask.array.median + Dataset.median + DataArray.median + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.median() + + array(2.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.median(skipna=False) + + array(nan) + """ + return self.reduce( + duck_array_ops.median, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def cumsum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``cumsum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumsum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumsum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``cumsum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumsum + dask.array.cumsum + Dataset.cumsum + DataArray.cumsum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.cumsum() + + array([1., 3., 6., 6., 8., 8.]) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.cumsum(skipna=False) + + array([ 1., 3., 6., 6., 8., nan]) + """ + return self.reduce( + duck_array_ops.cumsum, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def cumprod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``cumprod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumprod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumprod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``cumprod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumprod + dask.array.cumprod + Dataset.cumprod + DataArray.cumprod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.cumprod() + + array([1., 2., 6., 0., 0., 0.]) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.cumprod(skipna=False) + + array([ 1., 2., 6., 0., 0., nan]) + """ + return self.reduce( + duck_array_ops.cumprod, + dim=dim, + skipna=skipna, + **kwargs, + ) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py new file mode 100644 index 00000000000..e205c4d4efe --- /dev/null +++ b/xarray/namedarray/_array_api.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import warnings +from types import ModuleType +from typing import Any + +import numpy as np + +from xarray.namedarray._typing import ( + _arrayapi, + _DType, + _ScalarType, + _ShapeType, + _SupportsImag, + _SupportsReal, +) +from xarray.namedarray.core import NamedArray + +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + r"The numpy.array_api submodule is still experimental", + category=UserWarning, + ) + import numpy.array_api as nxp # noqa: F401 + + +def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: + if isinstance(x._data, _arrayapi): + return x._data.__array_namespace__() + + return np + + +# %% Creation Functions + + +def astype( + x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True +) -> NamedArray[_ShapeType, _DType]: + """ + Copies an array to a specified data type irrespective of Type Promotion Rules rules. + + Parameters + ---------- + x : NamedArray + Array to cast. + dtype : _DType + Desired data type. + copy : bool, optional + Specifies whether to copy an array when the specified dtype matches the data + type of the input array x. + If True, a newly allocated array must always be returned. + If False and the specified dtype matches the data type of the input array, + the input array must be returned; otherwise, a newly allocated array must be + returned. Default: True. + + Returns + ------- + out : NamedArray + An array having the specified data type. The returned array must have the + same shape as x. + + Examples + -------- + >>> narr = NamedArray(("x",), nxp.asarray([1.5, 2.5])) + >>> narr + + Array([1.5, 2.5], dtype=float64) + >>> astype(narr, np.dtype(np.int32)) + + Array([1, 2], dtype=int32) + """ + if isinstance(x._data, _arrayapi): + xp = x._data.__array_namespace__() + return x._new(data=xp.astype(x._data, dtype, copy=copy)) + + # np.astype doesn't exist yet: + return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined] + + +# %% Elementwise Functions + + +def imag( + x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], / # type: ignore[type-var] +) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: + """ + Returns the imaginary component of a complex number for each element x_i of the + input array x. + + Parameters + ---------- + x : NamedArray + Input array. Should have a complex floating-point data type. + + Returns + ------- + out : NamedArray + An array containing the element-wise results. The returned array must have a + floating-point data type with the same floating-point precision as x + (e.g., if x is complex64, the returned array must have the floating-point + data type float32). + + Examples + -------- + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp + >>> imag(narr) + + array([2., 4.]) + """ + xp = _get_data_namespace(x) + out = x._new(data=xp.imag(x._data)) + return out + + +def real( + x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], / # type: ignore[type-var] +) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: + """ + Returns the real component of a complex number for each element x_i of the + input array x. + + Parameters + ---------- + x : NamedArray + Input array. Should have a complex floating-point data type. + + Returns + ------- + out : NamedArray + An array containing the element-wise results. The returned array must have a + floating-point data type with the same floating-point precision as x + (e.g., if x is complex64, the returned array must have the floating-point + data type float32). + + Examples + -------- + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp + >>> real(narr) + + array([1., 2.]) + """ + xp = _get_data_namespace(x) + out = x._new(data=xp.real(x._data)) + return out diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py new file mode 100644 index 00000000000..0b972e19539 --- /dev/null +++ b/xarray/namedarray/_typing.py @@ -0,0 +1,270 @@ +from __future__ import annotations + +from collections.abc import Hashable, Iterable, Mapping, Sequence +from types import ModuleType +from typing import ( + Any, + Callable, + Protocol, + SupportsIndex, + TypeVar, + Union, + overload, + runtime_checkable, +) + +import numpy as np + +# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + + +_DType = TypeVar("_DType", bound=np.dtype[Any]) +_DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any]) +# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic` + +_ScalarType = TypeVar("_ScalarType", bound=np.generic) +_ScalarType_co = TypeVar("_ScalarType_co", bound=np.generic, covariant=True) + + +# A protocol for anything with the dtype attribute +@runtime_checkable +class _SupportsDType(Protocol[_DType_co]): + @property + def dtype(self) -> _DType_co: + ... + + +_DTypeLike = Union[ + np.dtype[_ScalarType], + type[_ScalarType], + _SupportsDType[np.dtype[_ScalarType]], +] + +# For unknown shapes Dask uses np.nan, array_api uses None: +_IntOrUnknown = int +_Shape = tuple[_IntOrUnknown, ...] +_ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]] +_ShapeType = TypeVar("_ShapeType", bound=Any) +_ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True) + +_Chunks = tuple[_Shape, ...] + +_Dim = Hashable +_Dims = tuple[_Dim, ...] + +_DimsLike = Union[str, Iterable[_Dim]] +_AttrsLike = Union[Mapping[Any, Any], None] + +_dtype = np.dtype + + +class _SupportsReal(Protocol[_T_co]): + @property + def real(self) -> _T_co: + ... + + +class _SupportsImag(Protocol[_T_co]): + @property + def imag(self) -> _T_co: + ... + + +@runtime_checkable +class _array(Protocol[_ShapeType_co, _DType_co]): + """ + Minimal duck array named array uses. + + Corresponds to np.ndarray. + """ + + @property + def shape(self) -> _Shape: + ... + + @property + def dtype(self) -> _DType_co: + ... + + +@runtime_checkable +class _arrayfunction( + _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Duck array supporting NEP 18. + + Corresponds to np.ndarray. + """ + + @overload + def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]: + ... + + @overload + def __array__(self, dtype: _DType, /) -> np.ndarray[Any, _DType]: + ... + + def __array__( + self, dtype: _DType | None = ..., / + ) -> np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]: + ... + + # TODO: Should return the same subclass but with a new dtype generic. + # https://github.com/python/typing/issues/548 + def __array_ufunc__( + self, + ufunc: Any, + method: Any, + *inputs: Any, + **kwargs: Any, + ) -> Any: + ... + + # TODO: Should return the same subclass but with a new dtype generic. + # https://github.com/python/typing/issues/548 + def __array_function__( + self, + func: Callable[..., Any], + types: Iterable[type], + args: Iterable[Any], + kwargs: Mapping[str, Any], + ) -> Any: + ... + + @property + def imag(self) -> _arrayfunction[_ShapeType_co, Any]: + ... + + @property + def real(self) -> _arrayfunction[_ShapeType_co, Any]: + ... + + +@runtime_checkable +class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]): + """ + Duck array supporting NEP 47. + + Corresponds to np.ndarray. + """ + + def __array_namespace__(self) -> ModuleType: + ... + + +# NamedArray can most likely use both __array_function__ and __array_namespace__: +_arrayfunction_or_api = (_arrayfunction, _arrayapi) + +duckarray = Union[ + _arrayfunction[_ShapeType_co, _DType_co], _arrayapi[_ShapeType_co, _DType_co] +] + +# Corresponds to np.typing.NDArray: +DuckArray = _arrayfunction[Any, np.dtype[_ScalarType_co]] + + +@runtime_checkable +class _chunkedarray( + _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Minimal chunked duck array. + + Corresponds to np.ndarray. + """ + + @property + def chunks(self) -> _Chunks: + ... + + +@runtime_checkable +class _chunkedarrayfunction( + _arrayfunction[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Chunked duck array supporting NEP 18. + + Corresponds to np.ndarray. + """ + + @property + def chunks(self) -> _Chunks: + ... + + +@runtime_checkable +class _chunkedarrayapi( + _arrayapi[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Chunked duck array supporting NEP 47. + + Corresponds to np.ndarray. + """ + + @property + def chunks(self) -> _Chunks: + ... + + +# NamedArray can most likely use both __array_function__ and __array_namespace__: +_chunkedarrayfunction_or_api = (_chunkedarrayfunction, _chunkedarrayapi) +chunkedduckarray = Union[ + _chunkedarrayfunction[_ShapeType_co, _DType_co], + _chunkedarrayapi[_ShapeType_co, _DType_co], +] + + +@runtime_checkable +class _sparsearray( + _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Minimal sparse duck array. + + Corresponds to np.ndarray. + """ + + def todense(self) -> np.ndarray[Any, _DType_co]: + ... + + +@runtime_checkable +class _sparsearrayfunction( + _arrayfunction[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Sparse duck array supporting NEP 18. + + Corresponds to np.ndarray. + """ + + def todense(self) -> np.ndarray[Any, _DType_co]: + ... + + +@runtime_checkable +class _sparsearrayapi( + _arrayapi[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Sparse duck array supporting NEP 47. + + Corresponds to np.ndarray. + """ + + def todense(self) -> np.ndarray[Any, _DType_co]: + ... + + +# NamedArray can most likely use both __array_function__ and __array_namespace__: +_sparsearrayfunction_or_api = (_sparsearrayfunction, _sparsearrayapi) + +sparseduckarray = Union[ + _sparsearrayfunction[_ShapeType_co, _DType_co], + _sparsearrayapi[_ShapeType_co, _DType_co], +] diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py new file mode 100644 index 00000000000..2fef1cad3db --- /dev/null +++ b/xarray/namedarray/core.py @@ -0,0 +1,849 @@ +from __future__ import annotations + +import copy +import math +import sys +import warnings +from collections.abc import Hashable, Iterable, Mapping, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + TypeVar, + cast, + overload, +) + +import numpy as np + +# TODO: get rid of this after migrating this class to array API +from xarray.core import dtypes, formatting, formatting_html +from xarray.namedarray._aggregations import NamedArrayAggregations +from xarray.namedarray._typing import ( + _arrayapi, + _arrayfunction_or_api, + _chunkedarray, + _dtype, + _DType_co, + _ScalarType_co, + _ShapeType_co, + _sparsearrayfunction_or_api, + _SupportsImag, + _SupportsReal, +) +from xarray.namedarray.utils import _default, is_duck_dask_array, to_0d_object_array + +if TYPE_CHECKING: + from numpy.typing import ArrayLike, NDArray + + from xarray.core.types import Dims + from xarray.namedarray._typing import ( + DuckArray, + _AttrsLike, + _Chunks, + _Dim, + _Dims, + _DimsLike, + _DType, + _IntOrUnknown, + _ScalarType, + _Shape, + _ShapeType, + duckarray, + ) + from xarray.namedarray.utils import Default + + try: + from dask.typing import ( + Graph, + NestedKeys, + PostComputeCallable, + PostPersistCallable, + SchedulerGetCallable, + ) + except ImportError: + Graph: Any # type: ignore[no-redef] + NestedKeys: Any # type: ignore[no-redef] + SchedulerGetCallable: Any # type: ignore[no-redef] + PostComputeCallable: Any # type: ignore[no-redef] + PostPersistCallable: Any # type: ignore[no-redef] + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + + T_NamedArray = TypeVar("T_NamedArray", bound="_NamedArray[Any]") + T_NamedArrayInteger = TypeVar( + "T_NamedArrayInteger", bound="_NamedArray[np.integer[Any]]" + ) + + +@overload +def _new( + x: NamedArray[Any, _DType_co], + dims: _DimsLike | Default = ..., + data: duckarray[_ShapeType, _DType] = ..., + attrs: _AttrsLike | Default = ..., +) -> NamedArray[_ShapeType, _DType]: + ... + + +@overload +def _new( + x: NamedArray[_ShapeType_co, _DType_co], + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., +) -> NamedArray[_ShapeType_co, _DType_co]: + ... + + +def _new( + x: NamedArray[Any, _DType_co], + dims: _DimsLike | Default = _default, + data: duckarray[_ShapeType, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, +) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, _DType_co]: + """ + Create a new array with new typing information. + + Parameters + ---------- + x : NamedArray + Array to create a new array from + dims : Iterable of Hashable, optional + Name(s) of the dimension(s). + Will copy the dims from x by default. + data : duckarray, optional + The actual data that populates the array. Should match the + shape specified by `dims`. + Will copy the data from x by default. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Will copy the attrs from x by default. + """ + dims_ = copy.copy(x._dims) if dims is _default else dims + + attrs_: Mapping[Any, Any] | None + if attrs is _default: + attrs_ = None if x._attrs is None else x._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(x)(dims_, copy.copy(x._data), attrs_) + else: + cls_ = cast("type[NamedArray[_ShapeType, _DType]]", type(x)) + return cls_(dims_, data, attrs_) + + +@overload +def from_array( + dims: _DimsLike, + data: DuckArray[_ScalarType], + attrs: _AttrsLike = ..., +) -> _NamedArray[_ScalarType]: + ... + + +@overload +def from_array( + dims: _DimsLike, + data: ArrayLike, + attrs: _AttrsLike = ..., +) -> _NamedArray[Any]: + ... + + +def from_array( + dims: _DimsLike, + data: DuckArray[_ScalarType] | ArrayLike, + attrs: _AttrsLike = None, +) -> _NamedArray[_ScalarType] | _NamedArray[Any]: + """ + Create a Named array from an array-like object. + + Parameters + ---------- + dims : str or iterable of str + Name(s) of the dimension(s). + data : T_DuckArray or ArrayLike + The actual data that populates the array. Should match the + shape specified by `dims`. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Default is None, meaning no attributes will be stored. + """ + if isinstance(data, NamedArray): + raise TypeError( + "Array is already a Named array. Use 'data.data' to retrieve the data array" + ) + + # TODO: dask.array.ma.masked_array also exists, better way? + if isinstance(data, np.ma.MaskedArray): + mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call] + if mask.any(): + # TODO: requires refactoring/vendoring xarray.core.dtypes and + # xarray.core.duck_array_ops + raise NotImplementedError("MaskedArray is not supported yet") + + return NamedArray(dims, data, attrs) + + if isinstance(data, _arrayfunction_or_api): + return NamedArray(dims, data, attrs) + + if isinstance(data, tuple): + return NamedArray(dims, to_0d_object_array(data), attrs) + + # validate whether the data is valid data types. + return NamedArray(dims, np.asarray(data), attrs) + + +class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]): + """ + A wrapper around duck arrays with named dimensions + and attributes which describe a single Array. + Numeric operations on this object implement array broadcasting and + dimension alignment based on dimension names, + rather than axis order. + + + Parameters + ---------- + dims : str or iterable of hashable + Name(s) of the dimension(s). + data : array-like or duck-array + The actual data that populates the array. Should match the + shape specified by `dims`. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Default is None, meaning no attributes will be stored. + + Raises + ------ + ValueError + If the `dims` length does not match the number of data dimensions (ndim). + + + Examples + -------- + >>> data = np.array([1.5, 2, 3], dtype=float) + >>> narr = NamedArray(("x",), data, {"units": "m"}) # TODO: Better name than narr? + """ + + __slots__ = ("_data", "_dims", "_attrs") + + _data: duckarray[Any, _DType_co] + _dims: _Dims + _attrs: dict[Any, Any] | None + + def __init__( + self, + dims: _DimsLike, + data: duckarray[Any, _DType_co], + attrs: _AttrsLike = None, + ): + self._data = data + self._dims = self._parse_dimensions(dims) + self._attrs = dict(attrs) if attrs else None + + def __init_subclass__(cls, **kwargs: Any) -> None: + if NamedArray in cls.__bases__ and (cls._new == NamedArray._new): + # Type hinting does not work for subclasses unless _new is + # overridden with the correct class. + raise TypeError( + "Subclasses of `NamedArray` must override the `_new` method." + ) + super().__init_subclass__(**kwargs) + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: duckarray[_ShapeType, _DType] = ..., + attrs: _AttrsLike | Default = ..., + ) -> NamedArray[_ShapeType, _DType]: + ... + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., + ) -> NamedArray[_ShapeType_co, _DType_co]: + ... + + def _new( + self, + dims: _DimsLike | Default = _default, + data: duckarray[Any, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> NamedArray[_ShapeType, _DType] | NamedArray[_ShapeType_co, _DType_co]: + """ + Create a new array with new typing information. + + _new has to be reimplemented each time NamedArray is subclassed, + otherwise type hints will not be correct. The same is likely true + for methods that relied on _new. + + Parameters + ---------- + dims : Iterable of Hashable, optional + Name(s) of the dimension(s). + Will copy the dims from x by default. + data : duckarray, optional + The actual data that populates the array. Should match the + shape specified by `dims`. + Will copy the data from x by default. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Will copy the attrs from x by default. + """ + return _new(self, dims, data, attrs) + + def _replace( + self, + dims: _DimsLike | Default = _default, + data: duckarray[_ShapeType_co, _DType_co] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> Self: + """ + Create a new array with the same typing information. + + The types for each argument cannot change, + use self._new if that is a risk. + + Parameters + ---------- + dims : Iterable of Hashable, optional + Name(s) of the dimension(s). + Will copy the dims from x by default. + data : duckarray, optional + The actual data that populates the array. Should match the + shape specified by `dims`. + Will copy the data from x by default. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Will copy the attrs from x by default. + """ + return cast("Self", self._new(dims, data, attrs)) + + def _copy( + self, + deep: bool = True, + data: duckarray[_ShapeType_co, _DType_co] | None = None, + memo: dict[int, Any] | None = None, + ) -> Self: + if data is None: + ndata = self._data + if deep: + ndata = copy.deepcopy(ndata, memo=memo) + else: + ndata = data + self._check_shape(ndata) + + attrs = ( + copy.deepcopy(self._attrs, memo=memo) if deep else copy.copy(self._attrs) + ) + + return self._replace(data=ndata, attrs=attrs) + + def __copy__(self) -> Self: + return self._copy(deep=False) + + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: + return self._copy(deep=True, memo=memo) + + def copy( + self, + deep: bool = True, + data: duckarray[_ShapeType_co, _DType_co] | None = None, + ) -> Self: + """Returns a copy of this object. + + If `deep=True`, the data array is loaded into memory and copied onto + the new object. Dimensions, attributes and encodings are always copied. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, default: True + Whether the data array is loaded into memory and copied onto + the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored. + + Returns + ------- + object : NamedArray + New object with dimensions, attributes, and optionally + data copied from original. + + + """ + return self._copy(deep=deep, data=data) + + @property + def ndim(self) -> int: + """ + Number of array dimensions. + + See Also + -------- + numpy.ndarray.ndim + """ + return len(self.shape) + + @property + def size(self) -> _IntOrUnknown: + """ + Number of elements in the array. + + Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions. + + See Also + -------- + numpy.ndarray.size + """ + return math.prod(self.shape) + + def __len__(self) -> _IntOrUnknown: + try: + return self.shape[0] + except Exception as exc: + raise TypeError("len() of unsized object") from exc + + @property + def dtype(self) -> _DType_co: + """ + Data-type of the array’s elements. + + See Also + -------- + ndarray.dtype + numpy.dtype + """ + return self._data.dtype + + @property + def shape(self) -> _Shape: + """ + Get the shape of the array. + + Returns + ------- + shape : tuple of ints + Tuple of array dimensions. + + See Also + -------- + numpy.ndarray.shape + """ + return self._data.shape + + @property + def nbytes(self) -> _IntOrUnknown: + """ + Total bytes consumed by the elements of the data array. + + If the underlying data array does not include ``nbytes``, estimates + the bytes consumed based on the ``size`` and ``dtype``. + """ + if hasattr(self._data, "nbytes"): + return self._data.nbytes # type: ignore[no-any-return] + else: + return self.size * self.dtype.itemsize + + @property + def dims(self) -> _Dims: + """Tuple of dimension names with which this NamedArray is associated.""" + return self._dims + + @dims.setter + def dims(self, value: _DimsLike) -> None: + self._dims = self._parse_dimensions(value) + + def _parse_dimensions(self, dims: _DimsLike) -> _Dims: + dims = (dims,) if isinstance(dims, str) else tuple(dims) + if len(dims) != self.ndim: + raise ValueError( + f"dimensions {dims} must have the same length as the " + f"number of data dimensions, ndim={self.ndim}" + ) + return dims + + @property + def attrs(self) -> dict[Any, Any]: + """Dictionary of local attributes on this NamedArray.""" + if self._attrs is None: + self._attrs = {} + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + self._attrs = dict(value) + + def _check_shape(self, new_data: duckarray[Any, _DType_co]) -> None: + if new_data.shape != self.shape: + raise ValueError( + f"replacement data must match the {self.__class__.__name__}'s shape. " + f"replacement data has shape {new_data.shape}; {self.__class__.__name__} has shape {self.shape}" + ) + + @property + def data(self) -> duckarray[Any, _DType_co]: + """ + The NamedArray's data as an array. The underlying array type + (e.g. dask, sparse, pint) is preserved. + + """ + + return self._data + + @data.setter + def data(self, data: duckarray[Any, _DType_co]) -> None: + self._check_shape(data) + self._data = data + + @property + def imag( + self: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var] + ) -> NamedArray[_ShapeType, _dtype[_ScalarType]]: + """ + The imaginary part of the array. + + See Also + -------- + numpy.ndarray.imag + """ + if isinstance(self._data, _arrayapi): + from xarray.namedarray._array_api import imag + + return imag(self) + + return self._new(data=self._data.imag) + + @property + def real( + self: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var] + ) -> NamedArray[_ShapeType, _dtype[_ScalarType]]: + """ + The real part of the array. + + See Also + -------- + numpy.ndarray.real + """ + if isinstance(self._data, _arrayapi): + from xarray.namedarray._array_api import real + + return real(self) + return self._new(data=self._data.real) + + def __dask_tokenize__(self) -> Hashable: + # Use v.data, instead of v._data, in order to cope with the wrappers + # around NetCDF and the like + from dask.base import normalize_token + + s, d, a, attrs = type(self), self._dims, self.data, self.attrs + return normalize_token((s, d, a, attrs)) # type: ignore[no-any-return] + + def __dask_graph__(self) -> Graph | None: + if is_duck_dask_array(self._data): + return self._data.__dask_graph__() + else: + # TODO: Should this method just raise instead? + # raise NotImplementedError("Method requires self.data to be a dask array") + return None + + def __dask_keys__(self) -> NestedKeys: + if is_duck_dask_array(self._data): + return self._data.__dask_keys__() + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_layers__(self) -> Sequence[str]: + if is_duck_dask_array(self._data): + return self._data.__dask_layers__() + else: + raise AttributeError("Method requires self.data to be a dask array.") + + @property + def __dask_optimize__( + self, + ) -> Callable[..., dict[Any, Any]]: + if is_duck_dask_array(self._data): + return self._data.__dask_optimize__ # type: ignore[no-any-return] + else: + raise AttributeError("Method requires self.data to be a dask array.") + + @property + def __dask_scheduler__(self) -> SchedulerGetCallable: + if is_duck_dask_array(self._data): + return self._data.__dask_scheduler__ + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_postcompute__( + self, + ) -> tuple[PostComputeCallable, tuple[Any, ...]]: + if is_duck_dask_array(self._data): + array_func, array_args = self._data.__dask_postcompute__() # type: ignore[no-untyped-call] + return self._dask_finalize, (array_func,) + array_args + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_postpersist__( + self, + ) -> tuple[ + Callable[ + [Graph, PostPersistCallable[Any], Any, Any], + Self, + ], + tuple[Any, ...], + ]: + if is_duck_dask_array(self._data): + a: tuple[PostPersistCallable[Any], tuple[Any, ...]] + a = self._data.__dask_postpersist__() # type: ignore[no-untyped-call] + array_func, array_args = a + + return self._dask_finalize, (array_func,) + array_args + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def _dask_finalize( + self, + results: Graph, + array_func: PostPersistCallable[Any], + *args: Any, + **kwargs: Any, + ) -> Self: + data = array_func(results, *args, **kwargs) + return type(self)(self._dims, data, attrs=self._attrs) + + def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: + """Return axis number(s) corresponding to dimension(s) in this array. + + Parameters + ---------- + dim : str or iterable of str + Dimension name(s) for which to lookup axes. + + Returns + ------- + int or tuple of int + Axis number or numbers corresponding to the given dimensions. + """ + if not isinstance(dim, str) and isinstance(dim, Iterable): + return tuple(self._get_axis_num(d) for d in dim) + else: + return self._get_axis_num(dim) + + def _get_axis_num(self: Any, dim: Hashable) -> int: + try: + return self.dims.index(dim) # type: ignore[no-any-return] + except ValueError: + raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") + + @property + def chunks(self) -> _Chunks | None: + """ + Tuple of block lengths for this NamedArray's data, in order of dimensions, or None if + the underlying data is not a dask array. + + See Also + -------- + NamedArray.chunk + NamedArray.chunksizes + xarray.unify_chunks + """ + data = self._data + if isinstance(data, _chunkedarray): + return data.chunks + else: + return None + + @property + def chunksizes( + self, + ) -> Mapping[_Dim, _Shape]: + """ + Mapping from dimension names to block lengths for this namedArray's data, or None if + the underlying data is not a dask array. + Cannot be modified directly, but can be modified by calling .chunk(). + + Differs from NamedArray.chunks because it returns a mapping of dimensions to chunk shapes + instead of a tuple of chunk shapes. + + See Also + -------- + NamedArray.chunk + NamedArray.chunks + xarray.unify_chunks + """ + data = self._data + if isinstance(data, _chunkedarray): + return dict(zip(self.dims, data.chunks)) + else: + return {} + + @property + def sizes(self) -> dict[_Dim, _IntOrUnknown]: + """Ordered mapping from dimension names to lengths.""" + return dict(zip(self.dims, self.shape)) + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> NamedArray[Any, Any]: + """Reduce this array by applying `func` along some dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of reducing an + np.ndarray over an integer valued axis. + dim : "...", str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. By default `func` is + applied over all dimensions. + axis : int or Sequence of int, optional + Axis(es) over which to apply `func`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + the reduction is calculated over the flattened array (by calling + `func(x)` without an axis argument). + keepdims : bool, default: False + If True, the dimensions which are reduced are left in the result + as dimensions of size one + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Array + Array with summarized data and the indicated dimension(s) + removed. + """ + if dim == ...: + dim = None + if dim is not None and axis is not None: + raise ValueError("cannot supply both 'axis' and 'dim' arguments") + + if dim is not None: + axis = self.get_axis_num(dim) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", r"Mean of empty slice", category=RuntimeWarning + ) + if axis is not None: + if isinstance(axis, tuple) and len(axis) == 1: + # unpack axis for the benefit of functions + # like np.argmin which can't handle tuple arguments + axis = axis[0] + data = func(self.data, axis=axis, **kwargs) + else: + data = func(self.data, **kwargs) + + if getattr(data, "shape", ()) == self.shape: + dims = self.dims + else: + removed_axes: Iterable[int] + if axis is None: + removed_axes = range(self.ndim) + else: + removed_axes = np.atleast_1d(axis) % self.ndim + if keepdims: + # Insert np.newaxis for removed dims + slices = tuple( + np.newaxis if i in removed_axes else slice(None, None) + for i in range(self.ndim) + ) + if getattr(data, "shape", None) is None: + # Reduce has produced a scalar value, not an array-like + data = np.asanyarray(data)[slices] + else: + data = data[slices] + dims = self.dims + else: + dims = tuple( + adim for n, adim in enumerate(self.dims) if n not in removed_axes + ) + + # Return NamedArray to handle IndexVariable when data is nD + return from_array(dims, data, attrs=self._attrs) + + def _nonzero(self: T_NamedArrayInteger) -> tuple[T_NamedArrayInteger, ...]: + """Equivalent numpy's nonzero but returns a tuple of NamedArrays.""" + # TODO: we should replace dask's native nonzero + # after https://github.com/dask/dask/issues/1076 is implemented. + # TODO: cast to ndarray and back to T_DuckArray is a workaround + nonzeros = np.nonzero(cast("NDArray[np.integer[Any]]", self.data)) + _attrs = self.attrs + return tuple( + cast("T_NamedArrayInteger", self._new((dim,), nz, _attrs)) + for nz, dim in zip(nonzeros, self.dims) + ) + + def __repr__(self) -> str: + return formatting.array_repr(self) + + def _repr_html_(self) -> str: + return formatting_html.array_repr(self) + + def _as_sparse( + self, + sparse_format: Literal["coo"] | Default = _default, + fill_value: ArrayLike | Default = _default, + ) -> NamedArray[Any, _DType_co]: + """ + Use sparse-array as backend. + """ + import sparse + + from xarray.namedarray._array_api import astype + + # TODO: what to do if dask-backended? + if fill_value is _default: + dtype, fill_value = dtypes.maybe_promote(self.dtype) + else: + dtype = dtypes.result_type(self.dtype, fill_value) + + if sparse_format is _default: + sparse_format = "coo" + try: + as_sparse = getattr(sparse, f"as_{sparse_format.lower()}") + except AttributeError as exc: + raise ValueError(f"{sparse_format} is not a valid sparse format") from exc + + data = as_sparse(astype(self, dtype).data, fill_value=fill_value) + return self._new(data=data) + + def _to_dense(self) -> NamedArray[Any, _DType_co]: + """ + Change backend from sparse to np.array. + """ + if isinstance(self._data, _sparsearrayfunction_or_api): + data_dense: np.ndarray[Any, _DType_co] = self._data.todense() + return self._new(data=data_dense) + else: + raise TypeError("self.data is not a sparse array") + + +_NamedArray = NamedArray[Any, np.dtype[_ScalarType_co]] diff --git a/xarray/namedarray/dtypes.py b/xarray/namedarray/dtypes.py new file mode 100644 index 00000000000..7a83bd17064 --- /dev/null +++ b/xarray/namedarray/dtypes.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import functools +import sys +from typing import Any, Literal + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + +import numpy as np + +from xarray.namedarray import utils + +# Use as a sentinel value to indicate a dtype appropriate NA value. +NA = utils.ReprObject("") + + +@functools.total_ordering +class AlwaysGreaterThan: + def __gt__(self, other: Any) -> Literal[True]: + return True + + def __eq__(self, other: Any) -> bool: + return isinstance(other, type(self)) + + +@functools.total_ordering +class AlwaysLessThan: + def __lt__(self, other: Any) -> Literal[True]: + return True + + def __eq__(self, other: Any) -> bool: + return isinstance(other, type(self)) + + +# Equivalence to np.inf (-np.inf) for object-type +INF = AlwaysGreaterThan() +NINF = AlwaysLessThan() + + +# Pairs of types that, if both found, should be promoted to object dtype +# instead of following NumPy's own type-promotion rules. These type promotion +# rules match pandas instead. For reference, see the NumPy type hierarchy: +# https://numpy.org/doc/stable/reference/arrays.scalars.html +PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ( + (np.number, np.character), # numpy promotes to character + (np.bool_, np.character), # numpy promotes to character + (np.bytes_, np.str_), # numpy promotes to unicode +) + + +def maybe_promote(dtype: np.dtype[np.generic]) -> tuple[np.dtype[np.generic], Any]: + """Simpler equivalent of pandas.core.common._maybe_promote + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + dtype : Promoted dtype that can hold missing values. + fill_value : Valid missing value for the promoted dtype. + """ + # N.B. these casting rules should match pandas + dtype_: np.typing.DTypeLike + fill_value: Any + if np.issubdtype(dtype, np.floating): + dtype_ = dtype + fill_value = np.nan + elif np.issubdtype(dtype, np.timedelta64): + # See https://github.com/numpy/numpy/issues/10685 + # np.timedelta64 is a subclass of np.integer + # Check np.timedelta64 before np.integer + fill_value = np.timedelta64("NaT") + dtype_ = dtype + elif np.issubdtype(dtype, np.integer): + dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64 + fill_value = np.nan + elif np.issubdtype(dtype, np.complexfloating): + dtype_ = dtype + fill_value = np.nan + np.nan * 1j + elif np.issubdtype(dtype, np.datetime64): + dtype_ = dtype + fill_value = np.datetime64("NaT") + else: + dtype_ = object + fill_value = np.nan + + dtype_out = np.dtype(dtype_) + fill_value = dtype_out.type(fill_value) + return dtype_out, fill_value + + +NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} + + +def get_fill_value(dtype: np.dtype[np.generic]) -> Any: + """Return an appropriate fill value for this dtype. + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + fill_value : Missing value corresponding to this dtype. + """ + _, fill_value = maybe_promote(dtype) + return fill_value + + +def get_pos_infinity( + dtype: np.dtype[np.generic], max_for_int: bool = False +) -> float | complex | AlwaysGreaterThan: + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + max_for_int : bool + Return np.iinfo(dtype).max instead of np.inf + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. + """ + if issubclass(dtype.type, np.floating): + return np.inf + + if issubclass(dtype.type, np.integer): + return np.iinfo(dtype.type).max if max_for_int else np.inf + if issubclass(dtype.type, np.complexfloating): + return np.inf + 1j * np.inf + + return INF + + +def get_neg_infinity( + dtype: np.dtype[np.generic], min_for_int: bool = False +) -> float | complex | AlwaysLessThan: + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + min_for_int : bool + Return np.iinfo(dtype).min instead of -np.inf + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. + """ + if issubclass(dtype.type, np.floating): + return -np.inf + + if issubclass(dtype.type, np.integer): + return np.iinfo(dtype.type).min if min_for_int else -np.inf + if issubclass(dtype.type, np.complexfloating): + return -np.inf - 1j * np.inf + + return NINF + + +def is_datetime_like( + dtype: np.dtype[np.generic], +) -> TypeGuard[np.datetime64 | np.timedelta64]: + """Check if a dtype is a subclass of the numpy datetime types""" + return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) + + +def result_type( + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, +) -> np.dtype[np.generic]: + """Like np.result_type, but with type promotion rules matching pandas. + + Examples of changed behavior: + number + string -> object (not string) + bytes + unicode -> object (not unicode) + + Parameters + ---------- + *arrays_and_dtypes : list of arrays and dtypes + The dtype is extracted from both numpy and dask arrays. + + Returns + ------- + numpy.dtype for the result. + """ + types = {np.result_type(t).type for t in arrays_and_dtypes} + + for left, right in PROMOTE_TO_OBJECT: + if any(issubclass(t, left) for t in types) and any( + issubclass(t, right) for t in types + ): + return np.dtype(object) + + return np.result_type(*arrays_and_dtypes) diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py new file mode 100644 index 00000000000..03eb0134231 --- /dev/null +++ b/xarray/namedarray/utils.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import sys +from collections.abc import Hashable +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Final, +) + +import numpy as np + +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import TypeGuard + else: + from typing_extensions import TypeGuard + + from numpy.typing import NDArray + + from xarray.namedarray._typing import ( + duckarray, + ) + + try: + from dask.array.core import Array as DaskArray + from dask.typing import DaskCollection + except ImportError: + DaskArray = NDArray # type: ignore + DaskCollection: Any = NDArray # type: ignore + + +# Singleton type, as per https://github.com/python/typing/pull/240 +class Default(Enum): + token: Final = 0 + + +_default = Default.token + + +def module_available(module: str) -> bool: + """Checks whether a module is installed without importing it. + + Use this for a lightweight check and lazy imports. + + Parameters + ---------- + module : str + Name of the module. + + Returns + ------- + available : bool + Whether the module is installed. + """ + from importlib.util import find_spec + + return find_spec(module) is not None + + +def is_dask_collection(x: object) -> TypeGuard[DaskCollection]: + if module_available("dask"): + from dask.typing import DaskCollection + + return isinstance(x, DaskCollection) + return False + + +def is_duck_dask_array(x: duckarray[Any, Any]) -> TypeGuard[DaskArray]: + return is_dask_collection(x) + + +def to_0d_object_array( + value: object, +) -> NDArray[np.object_]: + """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object.""" + result = np.empty((), dtype=object) + result[()] = value + return result + + +class ReprObject: + """Object that prints as the given value, for use with sentinel values.""" + + __slots__ = ("_value",) + + _value: str + + def __init__(self, value: str): + self._value = value + + def __repr__(self) -> str: + return self._value + + def __eq__(self, other: ReprObject | Any) -> bool: + # TODO: What type can other be? ArrayLike? + return self._value == other._value if isinstance(other, ReprObject) else False + + def __hash__(self) -> int: + return hash((type(self), self._value)) + + def __dask_tokenize__(self) -> Hashable: + from dask.base import normalize_token + + return normalize_token((type(self), self._value)) # type: ignore[no-any-return] diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index ff707602545..203bae2691f 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -16,6 +16,7 @@ from matplotlib.container import BarContainer from matplotlib.contour import QuadContourSet from matplotlib.image import AxesImage + from matplotlib.patches import Polygon from matplotlib.quiver import Quiver from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection from numpy.typing import ArrayLike @@ -47,11 +48,13 @@ def __call__(self, **kwargs) -> Any: return dataarray_plot.plot(self._da, **kwargs) @functools.wraps(dataarray_plot.hist) - def hist(self, *args, **kwargs) -> tuple[np.ndarray, np.ndarray, BarContainer]: + def hist( + self, *args, **kwargs + ) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]: return dataarray_plot.hist(self._da, *args, **kwargs) @overload - def line( # type: ignore[misc] # None is hashable :( + def line( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, row: None = None, # no wrap -> primitive @@ -69,8 +72,8 @@ def line( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -96,8 +99,8 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -123,20 +126,20 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.line) + @functools.wraps(dataarray_plot.line, assigned=("__doc__",)) def line(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.line(self._da, *args, **kwargs) @overload - def step( # type: ignore[misc] # None is hashable :( + def step( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, where: Literal["pre", "post", "mid"] = "pre", @@ -174,12 +177,12 @@ def step( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.step) + @functools.wraps(dataarray_plot.step, assigned=("__doc__",)) def step(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.step(self._da, *args, **kwargs) @overload - def scatter( + def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -207,8 +210,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -248,8 +251,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -289,8 +292,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -301,12 +304,12 @@ def scatter( ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.scatter) - def scatter(self, *args, **kwargs): + @functools.wraps(dataarray_plot.scatter, assigned=("__doc__",)) + def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: return dataarray_plot.scatter(self._da, *args, **kwargs) @overload - def imshow( + def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -338,8 +341,8 @@ def imshow( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> AxesImage: @@ -378,8 +381,8 @@ def imshow( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: @@ -418,19 +421,19 @@ def imshow( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.imshow) - def imshow(self, *args, **kwargs) -> AxesImage: + @functools.wraps(dataarray_plot.imshow, assigned=("__doc__",)) + def imshow(self, *args, **kwargs) -> AxesImage | FacetGrid[DataArray]: return dataarray_plot.imshow(self._da, *args, **kwargs) @overload - def contour( + def contour( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -462,8 +465,8 @@ def contour( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> QuadContourSet: @@ -502,8 +505,8 @@ def contour( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: @@ -542,19 +545,19 @@ def contour( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.contour) - def contour(self, *args, **kwargs) -> QuadContourSet: + @functools.wraps(dataarray_plot.contour, assigned=("__doc__",)) + def contour(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contour(self._da, *args, **kwargs) @overload - def contourf( + def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -586,8 +589,8 @@ def contourf( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> QuadContourSet: @@ -626,8 +629,8 @@ def contourf( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid[DataArray]: @@ -666,19 +669,19 @@ def contourf( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid: ... - @functools.wraps(dataarray_plot.contourf) - def contourf(self, *args, **kwargs) -> QuadContourSet: + @functools.wraps(dataarray_plot.contourf, assigned=("__doc__",)) + def contourf(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contourf(self._da, *args, **kwargs) @overload - def pcolormesh( + def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -710,8 +713,8 @@ def pcolormesh( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> QuadMesh: @@ -750,11 +753,11 @@ def pcolormesh( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[DataArray]: ... @overload @@ -790,15 +793,15 @@ def pcolormesh( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.pcolormesh) - def pcolormesh(self, *args, **kwargs) -> QuadMesh: + @functools.wraps(dataarray_plot.pcolormesh, assigned=("__doc__",)) + def pcolormesh(self, *args, **kwargs) -> QuadMesh | FacetGrid[DataArray]: return dataarray_plot.pcolormesh(self._da, *args, **kwargs) @overload @@ -834,8 +837,8 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> Poly3DCollection: @@ -874,8 +877,8 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid: @@ -914,14 +917,14 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> FacetGrid: ... - @functools.wraps(dataarray_plot.surface) + @functools.wraps(dataarray_plot.surface, assigned=("__doc__",)) def surface(self, *args, **kwargs) -> Poly3DCollection: return dataarray_plot.surface(self._da, *args, **kwargs) @@ -945,7 +948,7 @@ def __call__(self, *args, **kwargs) -> NoReturn: ) @overload - def scatter( + def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -973,8 +976,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -1014,8 +1017,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -1023,7 +1026,7 @@ def scatter( extend=None, levels=None, **kwargs: Any, - ) -> FacetGrid[DataArray]: + ) -> FacetGrid[Dataset]: ... @overload @@ -1055,8 +1058,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -1064,15 +1067,15 @@ def scatter( extend=None, levels=None, **kwargs: Any, - ) -> FacetGrid[DataArray]: + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.scatter) - def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: + @functools.wraps(dataset_plot.scatter, assigned=("__doc__",)) + def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[Dataset]: return dataset_plot.scatter(self._ds, *args, **kwargs) @overload - def quiver( + def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -1142,7 +1145,7 @@ def quiver( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... @overload @@ -1179,15 +1182,15 @@ def quiver( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.quiver) - def quiver(self, *args, **kwargs) -> Quiver | FacetGrid: + @functools.wraps(dataset_plot.quiver, assigned=("__doc__",)) + def quiver(self, *args, **kwargs) -> Quiver | FacetGrid[Dataset]: return dataset_plot.quiver(self._ds, *args, **kwargs) @overload - def streamplot( + def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -1257,7 +1260,7 @@ def streamplot( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... @overload @@ -1294,9 +1297,9 @@ def streamplot( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.streamplot) - def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid: + @functools.wraps(dataset_plot.streamplot, assigned=("__doc__",)) + def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid[Dataset]: return dataset_plot.streamplot(self._ds, *args, **kwargs) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index d2c0a8e2af6..61f2014fbc3 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -3,7 +3,7 @@ import functools import warnings from collections.abc import Hashable, Iterable, MutableMapping -from typing import TYPE_CHECKING, Any, Callable, Literal, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Literal, Union, cast, overload import numpy as np import pandas as pd @@ -29,7 +29,6 @@ _resolve_intervals_2dplot, _update_axes, get_axis, - import_matplotlib_pyplot, label_from_attrs, ) @@ -40,6 +39,7 @@ from matplotlib.container import BarContainer from matplotlib.contour import QuadContourSet from matplotlib.image import AxesImage + from matplotlib.patches import Polygon from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection from numpy.typing import ArrayLike @@ -53,7 +53,7 @@ ) from xarray.plot.facetgrid import FacetGrid -_styles: MutableMapping[str, Any] = { +_styles: dict[str, Any] = { # Add a white border to make it easier seeing overlapping markers: "scatter.edgecolors": "w", } @@ -186,7 +186,7 @@ def _prepare_plot1d_data( # dimensions so the plotter can plot anything: if darray.ndim > 1: # When stacking dims the lines will continue connecting. For floats - # this can be solved by adding a nan element inbetween the flattening + # this can be solved by adding a nan element in between the flattening # points: dims_T = [] if np.issubdtype(darray.dtype, np.floating): @@ -264,7 +264,9 @@ def plot( -------- xarray.DataArray.squeeze """ - darray = darray.squeeze().compute() + darray = darray.squeeze( + d for d, s in darray.sizes.items() if s == 1 and d not in (row, col, hue) + ).compute() plot_dims = set(darray.dims) plot_dims.discard(row) @@ -307,7 +309,7 @@ def plot( @overload -def line( # type: ignore[misc] # None is hashable :( +def line( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, row: None = None, # no wrap -> primitive @@ -325,8 +327,8 @@ def line( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, @@ -336,7 +338,7 @@ def line( # type: ignore[misc] # None is hashable :( @overload def line( - darray, + darray: T_DataArray, *args: Any, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, @@ -353,18 +355,18 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def line( - darray, + darray: T_DataArray, *args: Any, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid @@ -381,19 +383,19 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... # This function signature should not change so that it can use # matplotlib format strings def line( - darray: DataArray, + darray: T_DataArray, *args: Any, row: Hashable | None = None, col: Hashable | None = None, @@ -410,12 +412,12 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> list[Line3D] | FacetGrid[DataArray]: +) -> list[Line3D] | FacetGrid[T_DataArray]: """ Line plot of DataArray values. @@ -459,7 +461,7 @@ def line( Specifies scaling for the *x*- and *y*-axis, respectively. xticks, yticks : array-like, optional Specify tick locations for *x*- and *y*-axis. - xlim, ylim : array-like, optional + xlim, ylim : tuple[float, float], optional Specify *x*- and *y*-axis limits. add_legend : bool, default: True Add legend with *y* axis coordinates (2D inputs only). @@ -486,8 +488,8 @@ def line( if ndims > 2: raise ValueError( "Line plots are for 1- or 2-dimensional DataArrays. " - "Passed DataArray has {ndims} " - "dimensions".format(ndims=ndims) + f"Passed DataArray has {ndims} " + "dimensions" ) # The allargs dict passed to _easy_facetgrid above contains args @@ -538,7 +540,7 @@ def line( @overload -def step( # type: ignore[misc] # None is hashable :( +def step( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, where: Literal["pre", "post", "mid"] = "pre", @@ -654,10 +656,10 @@ def hist( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, **kwargs: Any, -) -> tuple[np.ndarray, np.ndarray, BarContainer]: +) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]: """ Histogram of DataArray. @@ -691,7 +693,7 @@ def hist( Specifies scaling for the *x*- and *y*-axis, respectively. xticks, yticks : array-like, optional Specify tick locations for *x*- and *y*-axis. - xlim, ylim : array-like, optional + xlim, ylim : tuple[float, float], optional Specify *x*- and *y*-axis limits. **kwargs : optional Additional keyword arguments to :py:func:`matplotlib:matplotlib.pyplot.hist`. @@ -708,14 +710,17 @@ def hist( no_nan = np.ravel(darray.to_numpy()) no_nan = no_nan[pd.notnull(no_nan)] - primitive = ax.hist(no_nan, **kwargs) + n, bins, patches = cast( + tuple[np.ndarray, np.ndarray, Union["BarContainer", "Polygon"]], + ax.hist(no_nan, **kwargs), + ) ax.set_title(darray._title_for_slice()) ax.set_xlabel(label_from_attrs(darray)) _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) - return primitive + return n, bins, patches def _plot1d(plotfunc): @@ -779,9 +784,9 @@ def _plot1d(plotfunc): Specify tick locations for x-axes. yticks : ArrayLike or None, optional Specify tick locations for y-axes. - xlim : ArrayLike or None, optional + xlim : tuple[float, float] or None, optional Specify x-axes limits. - ylim : ArrayLike or None, optional + ylim : tuple[float, float] or None, optional Specify y-axes limits. cmap : matplotlib colormap name or colormap, optional The mapping from data values to color space. Either a @@ -789,7 +794,7 @@ def _plot1d(plotfunc): be either ``'viridis'`` (if the function infers a sequential dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset). - See :doc:`Choosing Colormaps in Matplotlib ` + See :doc:`Choosing Colormaps in Matplotlib ` for more information. If *seaborn* is installed, ``cmap`` may also be a @@ -866,8 +871,8 @@ def newplotfunc( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, @@ -879,7 +884,7 @@ def newplotfunc( # All 1d plots in xarray share this function signature. # Method signature below should be consistent. - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt if subplot_kws is None: subplot_kws = dict() @@ -992,9 +997,13 @@ def newplotfunc( with plt.rc_context(_styles): if z is not None: + import mpl_toolkits + if ax is None: subplot_kws.update(projection="3d") ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + # Using 30, 30 minimizes rotation of the plot. Making it easier to # build on your intuition from 2D plots: ax.view_init(azim=30, elev=30, vertical_axis="y") @@ -1043,9 +1052,7 @@ def newplotfunc( else: hueplt_norm_values: list[np.ndarray | None] if hueplt_norm.data is not None: - hueplt_norm_values = list( - cast("DataArray", hueplt_norm.data).to_numpy() - ) + hueplt_norm_values = list(hueplt_norm.data.to_numpy()) else: hueplt_norm_values = [hueplt_norm.data] @@ -1078,12 +1085,12 @@ def newplotfunc( def _add_labels( add_labels: bool | Iterable[bool], - darrays: Iterable[DataArray], + darrays: Iterable[DataArray | None], suffixes: Iterable[str], rotate_labels: Iterable[bool], ax: Axes, ) -> None: - # Set x, y, z labels: + """Set x, y, z labels.""" add_labels = [add_labels] * 3 if isinstance(add_labels, bool) else add_labels for axis, add_label, darray, suffix, rotate_label in zip( ("x", "y", "z"), add_labels, darrays, suffixes, rotate_labels @@ -1107,7 +1114,7 @@ def _add_labels( @overload -def scatter( +def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, x: Hashable | None = None, @@ -1150,7 +1157,7 @@ def scatter( @overload def scatter( - darray: DataArray, + darray: T_DataArray, *args: Any, x: Hashable | None = None, y: Hashable | None = None, @@ -1186,13 +1193,13 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def scatter( - darray: DataArray, + darray: T_DataArray, *args: Any, x: Hashable | None = None, y: Hashable | None = None, @@ -1228,7 +1235,7 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -1257,15 +1264,24 @@ def scatter( if sizeplt is not None: kwargs.update(s=sizeplt.to_numpy().ravel()) - axis_order = ["x", "y", "z"] + plts_or_none = (xplt, yplt, zplt) + _add_labels(add_labels, plts_or_none, ("", "", ""), (True, False, False), ax) - plts_dict: dict[str, DataArray | None] = dict(x=xplt, y=yplt, z=zplt) - plts_or_none = [plts_dict[v] for v in axis_order] - plts = [p for p in plts_or_none if p is not None] - primitive = ax.scatter(*[p.to_numpy().ravel() for p in plts], **kwargs) - _add_labels(add_labels, plts, ("", "", ""), (True, False, False), ax) + xplt_np = None if xplt is None else xplt.to_numpy().ravel() + yplt_np = None if yplt is None else yplt.to_numpy().ravel() + zplt_np = None if zplt is None else zplt.to_numpy().ravel() + plts_np = tuple(p for p in (xplt_np, yplt_np, zplt_np) if p is not None) - return primitive + if len(plts_np) == 3: + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + return ax.scatter(xplt_np, yplt_np, zplt_np, **kwargs) + + if len(plts_np) == 2: + return ax.scatter(plts_np[0], plts_np[1], **kwargs) + + raise ValueError("At least two variables required for a scatter plot.") def _plot2d(plotfunc): @@ -1325,14 +1341,14 @@ def _plot2d(plotfunc): The mapping from data values to color space. If not provided, this will be either be ``'viridis'`` (if the function infers a sequential dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset). - See :doc:`Choosing Colormaps in Matplotlib ` + See :doc:`Choosing Colormaps in Matplotlib ` for more information. If *seaborn* is installed, ``cmap`` may also be a `seaborn color palette `_. Note: if ``cmap`` is a seaborn color palette and the plot type is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified. - center : float, optional + center : float or False, optional The value at which to center the colormap. Passing this value implies use of a diverging colormap. Setting it to ``False`` prevents use of a diverging colormap. @@ -1374,9 +1390,9 @@ def _plot2d(plotfunc): Specify tick locations for x-axes. yticks : ArrayLike or None, optional Specify tick locations for y-axes. - xlim : ArrayLike or None, optional + xlim : tuple[float, float] or None, optional Specify x-axes limits. - ylim : ArrayLike or None, optional + ylim : tuple[float, float] or None, optional Specify y-axes limits. norm : matplotlib.colors.Normalize, optional If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding @@ -1416,7 +1432,7 @@ def newplotfunc( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1429,8 +1445,8 @@ def newplotfunc( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> Any: @@ -1502,8 +1518,6 @@ def newplotfunc( # TypeError to be consistent with pandas raise TypeError("No numeric data to plot.") - plt = import_matplotlib_pyplot() - if ( plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False) @@ -1616,6 +1630,9 @@ def newplotfunc( ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) ax.set_title(darray._title_for_slice()) if plotfunc.__name__ == "surface": + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) ax.set_zlabel(label_from_attrs(darray)) if add_colorbar: @@ -1656,7 +1673,7 @@ def newplotfunc( @overload -def imshow( +def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -1675,7 +1692,7 @@ def imshow( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1698,7 +1715,7 @@ def imshow( @overload def imshow( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1716,7 +1733,7 @@ def imshow( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1733,13 +1750,13 @@ def imshow( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def imshow( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1757,7 +1774,7 @@ def imshow( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1774,7 +1791,7 @@ def imshow( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -1875,7 +1892,7 @@ def _center_pixels(x): @overload -def contour( +def contour( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -1894,7 +1911,7 @@ def contour( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1917,7 +1934,7 @@ def contour( @overload def contour( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1935,7 +1952,7 @@ def contour( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1952,13 +1969,13 @@ def contour( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def contour( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1976,7 +1993,7 @@ def contour( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1993,7 +2010,7 @@ def contour( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -2011,7 +2028,7 @@ def contour( @overload -def contourf( +def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -2030,7 +2047,7 @@ def contourf( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2053,7 +2070,7 @@ def contourf( @overload def contourf( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2071,7 +2088,7 @@ def contourf( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2088,13 +2105,13 @@ def contourf( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def contourf( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2112,7 +2129,7 @@ def contourf( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2129,7 +2146,7 @@ def contourf( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -2147,7 +2164,7 @@ def contourf( @overload -def pcolormesh( +def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -2166,7 +2183,7 @@ def pcolormesh( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2189,7 +2206,7 @@ def pcolormesh( @overload def pcolormesh( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2207,7 +2224,7 @@ def pcolormesh( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2224,13 +2241,13 @@ def pcolormesh( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def pcolormesh( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2248,7 +2265,7 @@ def pcolormesh( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2265,7 +2282,7 @@ def pcolormesh( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -2353,7 +2370,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2376,7 +2393,7 @@ def surface( @overload def surface( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2394,7 +2411,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2411,13 +2428,13 @@ def surface( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @overload def surface( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2435,7 +2452,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2452,7 +2469,7 @@ def surface( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: +) -> FacetGrid[T_DataArray]: ... @@ -2465,5 +2482,8 @@ def surface( Wraps :py:meth:`matplotlib:mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`. """ + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) primitive = ax.plot_surface(x, y, z, **kwargs) return primitive diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index b0774c31b17..a3ca201eec4 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -103,7 +103,7 @@ def _dsplot(plotfunc): be either ``'viridis'`` (if the function infers a sequential dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset). - See :doc:`Choosing Colormaps in Matplotlib ` + See :doc:`Choosing Colormaps in Matplotlib ` for more information. If *seaborn* is installed, ``cmap`` may also be a @@ -321,7 +321,7 @@ def newplotfunc( @overload -def quiver( +def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, @@ -475,7 +475,7 @@ def quiver( @overload -def streamplot( +def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, @@ -632,7 +632,6 @@ def streamplot( du = du.transpose(ydim, xdim) dv = dv.transpose(ydim, xdim) - args = [dx.values, dy.values, du.values, dv.values] hue = kwargs.pop("hue") cmap_params = kwargs.pop("cmap_params") @@ -646,7 +645,9 @@ def streamplot( ) kwargs.pop("hue_style") - hdl = ax.streamplot(*args, **kwargs, **cmap_params) + hdl = ax.streamplot( + dx.values, dy.values, du.values, dv.values, **kwargs, **cmap_params + ) # Return .lines so colorbar creation works properly return hdl.lines @@ -730,7 +731,7 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr coords = dict(ds.coords) # Add extra coords to the DataArray from valid kwargs, if using all - # kwargs there is a risk that we add unneccessary dataarrays as + # kwargs there is a risk that we add unnecessary dataarrays as # coords straining RAM further for example: # ds.both and extend="both" would add ds.both to the coords: valid_coord_kwargs = {"x", "z", "markersize", "hue", "row", "col", "u", "v"} @@ -748,7 +749,7 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr @overload -def scatter( +def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 93a328836d0..faf809a8a74 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -9,7 +9,7 @@ import numpy as np from xarray.core.formatting import format_item -from xarray.core.types import HueStyleOptions, T_Xarray +from xarray.core.types import HueStyleOptions, T_DataArrayOrSet from xarray.plot.utils import ( _LINEWIDTH_RANGE, _MARKERSIZE_RANGE, @@ -21,7 +21,6 @@ _Normalize, _parse_size, _process_cmap_cbar_kwargs, - import_matplotlib_pyplot, label_from_attrs, ) @@ -60,7 +59,7 @@ def _nicetitle(coord, value, maxchar, template): T_FacetGrid = TypeVar("T_FacetGrid", bound="FacetGrid") -class FacetGrid(Generic[T_Xarray]): +class FacetGrid(Generic[T_DataArrayOrSet]): """ Initialize the Matplotlib figure and FacetGrid object. @@ -76,7 +75,7 @@ class FacetGrid(Generic[T_Xarray]): The general approach to plotting here is called "small multiples", where the same kind of plot is repeated multiple times, and the specific use of small multiples to display the same relationship - conditioned on one ore more other variables is often called a "trellis + conditioned on one or more other variables is often called a "trellis plot". The basic workflow is to initialize the :class:`FacetGrid` object with @@ -101,7 +100,7 @@ class FacetGrid(Generic[T_Xarray]): sometimes the rightmost grid positions in the bottom row. """ - data: T_Xarray + data: T_DataArrayOrSet name_dicts: np.ndarray fig: Figure axs: np.ndarray @@ -126,7 +125,7 @@ class FacetGrid(Generic[T_Xarray]): def __init__( self, - data: T_Xarray, + data: T_DataArrayOrSet, col: Hashable | None = None, row: Hashable | None = None, col_wrap: int | None = None, @@ -166,7 +165,7 @@ def __init__( """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt # Handle corner case of nonunique coordinates rep_col = col is not None and not data[col].to_index().is_unique @@ -681,7 +680,10 @@ def _finalize_grid(self, *axlabels: Hashable) -> None: def _adjust_fig_for_guide(self, guide) -> None: # Draw the plot to set the bounding boxes correctly - renderer = self.fig.canvas.get_renderer() + if hasattr(self.fig.canvas, "get_renderer"): + renderer = self.fig.canvas.get_renderer() + else: + raise RuntimeError("MPL backend has no renderer") self.fig.draw(renderer) # Calculate and set the new width of the figure so the legend fits @@ -731,6 +733,9 @@ def add_colorbar(self, **kwargs: Any) -> None: if hasattr(self._mappables[-1], "extend"): kwargs.pop("extend", None) if "label" not in kwargs: + from xarray import DataArray + + assert isinstance(self.data, DataArray) kwargs.setdefault("label", label_from_attrs(self.data)) self.cbar = self.fig.colorbar( self._mappables[-1], ax=list(self.axs.flat), **kwargs @@ -985,7 +990,7 @@ def map( self : FacetGrid object """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt for ax, namedict in zip(self.axs.flat, self.name_dicts.flat): if namedict is not None: @@ -1004,7 +1009,7 @@ def map( def _easy_facetgrid( - data: T_Xarray, + data: T_DataArrayOrSet, plotfunc: Callable, kind: Literal["line", "dataarray", "dataset", "plot1d"], x: Hashable | None = None, @@ -1020,7 +1025,7 @@ def _easy_facetgrid( ax: Axes | None = None, figsize: Iterable[float] | None = None, **kwargs: Any, -) -> FacetGrid[T_Xarray]: +) -> FacetGrid[T_DataArrayOrSet]: """ Convenience method to call xarray.plot.FacetGrid from 2d plotting methods diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 2c58fe83cef..5694acc06e8 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -47,14 +47,6 @@ _LINEWIDTH_RANGE = (1.5, 1.5, 6.0) -def import_matplotlib_pyplot(): - """import pyplot""" - # TODO: This function doesn't do anything (after #6109), remove it? - import matplotlib.pyplot as plt - - return plt - - def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax @@ -505,28 +497,29 @@ def _maybe_gca(**subplot_kws: Any) -> Axes: return plt.axes(**subplot_kws) -def _get_units_from_attrs(da) -> str: +def _get_units_from_attrs(da: DataArray) -> str: """Extracts and formats the unit/units from a attributes.""" pint_array_type = DuckArrayModule("pint").type units = " [{}]" if isinstance(da.data, pint_array_type): - units = units.format(str(da.data.units)) - elif da.attrs.get("units"): - units = units.format(da.attrs["units"]) - elif da.attrs.get("unit"): - units = units.format(da.attrs["unit"]) - else: - units = "" - return units + return units.format(str(da.data.units)) + if "units" in da.attrs: + return units.format(da.attrs["units"]) + if "unit" in da.attrs: + return units.format(da.attrs["unit"]) + return "" -def label_from_attrs(da, extra: str = "") -> str: +def label_from_attrs(da: DataArray | None, extra: str = "") -> str: """Makes informative labels if variable metadata (attrs) follows CF conventions.""" + if da is None: + return "" + name: str = "{}" - if da.attrs.get("long_name"): + if "long_name" in da.attrs: name = name.format(da.attrs["long_name"]) - elif da.attrs.get("standard_name"): + elif "standard_name" in da.attrs: name = name.format(da.attrs["standard_name"]) elif da.name is not None: name = name.format(da.name) @@ -774,8 +767,8 @@ def _update_axes( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, ) -> None: """ Update axes with provided parameters @@ -1131,7 +1124,7 @@ def _get_color_and_size(value): # Labels are not numerical so modifying label_values is not # possible, instead filter the array with nicely distributed # indexes: - if type(num) == int: + if type(num) == int: # noqa: E721 loc = mpl.ticker.LinearLocator(num) else: raise ValueError("`num` only supports integers for non-numeric labels.") @@ -1166,7 +1159,7 @@ def _get_color_and_size(value): def _legend_add_subtitle(handles, labels, text): """Add a subtitle to legend handles.""" - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt if text and len(handles) > 1: # Create a blank handle that's not visible, the @@ -1184,7 +1177,7 @@ def _legend_add_subtitle(handles, labels, text): def _adjust_legend_subtitles(legend): """Make invisible-handle "subtitles" entries look more like titles.""" - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt # Legend title not in rcParams until 3.0 font_size = plt.rcParams.get("legend.title_fontsize", None) @@ -1461,7 +1454,7 @@ def _calc_widths(self, y: DataArray) -> DataArray: def _calc_widths(self, y: np.ndarray | DataArray) -> np.ndarray | DataArray: """ - Normalize the values so they're inbetween self._width. + Normalize the values so they're in between self._width. """ if self._width is None: return y @@ -1473,7 +1466,7 @@ def _calc_widths(self, y: np.ndarray | DataArray) -> np.ndarray | DataArray: # Use default with if y is constant: widths = xdefault + 0 * y else: - # Normalize inbetween xmin and xmax: + # Normalize in between xmin and xmax: k = (y - np.min(y)) / diff_maxy_miny widths = xmin + k * (xmax - xmin) return widths @@ -1640,7 +1633,7 @@ def format(self) -> FuncFormatter: >>> aa.format(1) '3.0' """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt def _func(x: Any, pos: None | Any = None): return f"{self._lookup_arr([x])[0]}" @@ -1821,8 +1814,8 @@ def _guess_coords_to_plot( ) # If dims_plot[k] isn't defined then fill with one of the available dims, unless - # one of related mpl kwargs has been used. This should have similiar behaviour as - # * plt.plot(x, y) -> Multple lines with different colors if y is 2d. + # one of related mpl kwargs has been used. This should have similar behaviour as + # * plt.plot(x, y) -> Multiple lines with different colors if y is 2d. # * plt.plot(x, y, color="red") -> Multiple red lines if y is 2d. for k, dim, ign_kws in zip(default_guess, available_coords, ignore_guess_kwargs): if coords_to_plot.get(k, None) is None and all( diff --git a/xarray/testing/testing.py b/xarray/testing/testing.py index f04e888f364..faa595a64b6 100644 --- a/xarray/testing/testing.py +++ b/xarray/testing/testing.py @@ -8,6 +8,7 @@ import pandas as pd from xarray.core import duck_array_ops, formatting, utils +from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes @@ -59,9 +60,9 @@ def assert_equal(a, b): Parameters ---------- - a : xarray.Dataset, xarray.DataArray or xarray.Variable + a : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The first object to compare. - b : xarray.Dataset, xarray.DataArray or xarray.Variable + b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The second object to compare. See Also @@ -70,11 +71,15 @@ def assert_equal(a, b): numpy.testing.assert_array_equal """ __tracebackhide__ = True - assert type(a) == type(b) + assert ( + type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates) + ) if isinstance(a, (Variable, DataArray)): assert a.equals(b), formatting.diff_array_repr(a, b, "equals") elif isinstance(a, Dataset): assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals") + elif isinstance(a, Coordinates): + assert a.equals(b), formatting.diff_coords_repr(a, b, "equals") else: raise TypeError(f"{type(a)} not supported by assertion comparison") @@ -88,9 +93,9 @@ def assert_identical(a, b): Parameters ---------- - a : xarray.Dataset, xarray.DataArray or xarray.Variable + a : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The first object to compare. - b : xarray.Dataset, xarray.DataArray or xarray.Variable + b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The second object to compare. See Also @@ -98,7 +103,9 @@ def assert_identical(a, b): assert_equal, assert_allclose, Dataset.equals, DataArray.equals """ __tracebackhide__ = True - assert type(a) == type(b) + assert ( + type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates) + ) if isinstance(a, Variable): assert a.identical(b), formatting.diff_array_repr(a, b, "identical") elif isinstance(a, DataArray): @@ -106,6 +113,8 @@ def assert_identical(a, b): assert a.identical(b), formatting.diff_array_repr(a, b, "identical") elif isinstance(a, (Dataset, Variable)): assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical") + elif isinstance(a, Coordinates): + assert a.identical(b), formatting.diff_coords_repr(a, b, "identical") else: raise TypeError(f"{type(a)} not supported by assertion comparison") @@ -346,7 +355,7 @@ def _assert_dataset_invariants(ds: Dataset, check_default_indexes: bool): set(ds._variables), ) - assert type(ds._dims) is dict, ds._dims + assert type(ds._dims) is dict, ds._dims # noqa: E721 assert all(isinstance(v, int) for v in ds._dims.values()), ds._dims var_dims: set[Hashable] = set() for v in ds._variables.values(): @@ -391,6 +400,10 @@ def _assert_internal_invariants( _assert_dataset_invariants( xarray_obj, check_default_indexes=check_default_indexes ) + elif isinstance(xarray_obj, Coordinates): + _assert_dataset_invariants( + xarray_obj.to_dataset(), check_default_indexes=check_default_indexes + ) else: raise TypeError( f"{type(xarray_obj)} is not a supported type for xarray invariant checks" diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 7e1b964ecba..07ba0be6a8c 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -75,7 +75,7 @@ def _importorskip( has_zarr, requires_zarr = _importorskip("zarr") has_fsspec, requires_fsspec = _importorskip("fsspec") has_iris, requires_iris = _importorskip("iris") -has_numbagg, requires_numbagg = _importorskip("numbagg") +has_numbagg, requires_numbagg = _importorskip("numbagg", "0.4.0") has_seaborn, requires_seaborn = _importorskip("seaborn") has_sparse, requires_sparse = _importorskip("sparse") has_cupy, requires_cupy = _importorskip("cupy") diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index 168d3232f81..dc325a84748 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -279,22 +279,20 @@ def test_case_bytes() -> None: def test_case_str() -> None: # This string includes some unicode characters # that are common case management corner cases - value = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) - - exp_capitalized = xr.DataArray(["Some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.unicode_) - exp_lowered = xr.DataArray(["some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.unicode_) - exp_swapped = xr.DataArray(["soME WoRD dž SS ᾛ σς FFI⁵å ç ⅰ"]).astype(np.unicode_) - exp_titled = xr.DataArray(["Some Word Dž Ss ᾛ Σς Ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) - exp_uppered = xr.DataArray(["SOME WORD DŽ SS ἫΙ ΣΣ FFI⁵Å Ç Ⅰ"]).astype(np.unicode_) - exp_casefolded = xr.DataArray(["some word dž ss ἣι σσ ffi⁵å ç ⅰ"]).astype( - np.unicode_ - ) - - exp_norm_nfc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) - exp_norm_nfkc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(np.unicode_) - exp_norm_nfd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + value = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.str_) + + exp_capitalized = xr.DataArray(["Some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.str_) + exp_lowered = xr.DataArray(["some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.str_) + exp_swapped = xr.DataArray(["soME WoRD dž SS ᾛ σς FFI⁵å ç ⅰ"]).astype(np.str_) + exp_titled = xr.DataArray(["Some Word Dž Ss ᾛ Σς Ffi⁵Å Ç Ⅰ"]).astype(np.str_) + exp_uppered = xr.DataArray(["SOME WORD DŽ SS ἫΙ ΣΣ FFI⁵Å Ç Ⅰ"]).astype(np.str_) + exp_casefolded = xr.DataArray(["some word dž ss ἣι σσ ffi⁵å ç ⅰ"]).astype(np.str_) + + exp_norm_nfc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.str_) + exp_norm_nfkc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(np.str_) + exp_norm_nfd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.str_) exp_norm_nfkd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype( - np.unicode_ + np.str_ ) res_capitalized = value.str.capitalize() @@ -680,7 +678,7 @@ def test_extract_extractall_name_collision_raises(dtype) -> None: def test_extract_single_case(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -728,7 +726,7 @@ def test_extract_single_case(dtype) -> None: def test_extract_single_nocase(dtype) -> None: pat_str = r"(\w+)?_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.IGNORECASE) @@ -770,7 +768,7 @@ def test_extract_single_nocase(dtype) -> None: def test_extract_multi_case(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -810,7 +808,7 @@ def test_extract_multi_case(dtype) -> None: def test_extract_multi_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.IGNORECASE) @@ -876,7 +874,7 @@ def test_extract_broadcast(dtype) -> None: def test_extractall_single_single_case(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -908,7 +906,7 @@ def test_extractall_single_single_case(dtype) -> None: def test_extractall_single_single_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.I) @@ -937,7 +935,7 @@ def test_extractall_single_single_nocase(dtype) -> None: def test_extractall_single_multi_case(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -983,7 +981,7 @@ def test_extractall_single_multi_case(dtype) -> None: def test_extractall_single_multi_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.I) @@ -1030,7 +1028,7 @@ def test_extractall_single_multi_nocase(dtype) -> None: def test_extractall_multi_single_case(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -1065,7 +1063,7 @@ def test_extractall_multi_single_case(dtype) -> None: def test_extractall_multi_single_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.I) @@ -1097,7 +1095,7 @@ def test_extractall_multi_single_nocase(dtype) -> None: def test_extractall_multi_multi_case(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -1147,7 +1145,7 @@ def test_extractall_multi_multi_case(dtype) -> None: def test_extractall_multi_multi_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.I) @@ -3419,12 +3417,12 @@ def test_cat_multi() -> None: values_4 = "" - values_5 = np.array("", dtype=np.unicode_) + values_5 = np.array("", dtype=np.str_) sep = xr.DataArray( [" ", ", "], dims=["ZZ"], - ).astype(np.unicode_) + ).astype(np.str_) expected = xr.DataArray( [ @@ -3440,7 +3438,7 @@ def test_cat_multi() -> None: ], ], dims=["X", "Y", "ZZ"], - ).astype(np.unicode_) + ).astype(np.str_) res = values_1.str.cat(values_2, values_3, values_4, values_5, sep=sep) @@ -3561,7 +3559,7 @@ def test_format_scalar() -> None: values = xr.DataArray( ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) pos0 = 1 pos1 = 1.2 @@ -3574,7 +3572,7 @@ def test_format_scalar() -> None: expected = xr.DataArray( ["1.X.None", "1,1.2,'test','test'", "'test'-X-None"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) @@ -3586,7 +3584,7 @@ def test_format_broadcast() -> None: values = xr.DataArray( ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) pos0 = 1 pos1 = 1.2 @@ -3608,7 +3606,7 @@ def test_format_broadcast() -> None: ["'test'-X-None", "'test'-X-None"], ], dims=["X", "YY"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) @@ -3620,7 +3618,7 @@ def test_mod_scalar() -> None: values = xr.DataArray( ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) pos0 = 1 pos1 = 1.2 @@ -3629,7 +3627,7 @@ def test_mod_scalar() -> None: expected = xr.DataArray( ["1.1.2.2.3", "1,1.2,2.3", "1-1.2-2.3"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str % (pos0, pos1, pos2) @@ -3641,7 +3639,7 @@ def test_mod_dict() -> None: values = xr.DataArray( ["%(a)s.%(a)s.%(b)s", "%(b)s,%(c)s,%(b)s", "%(c)s-%(b)s-%(a)s"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) a = 1 b = 1.2 @@ -3650,7 +3648,7 @@ def test_mod_dict() -> None: expected = xr.DataArray( ["1.1.1.2", "1.2,2.3,1.2", "2.3-1.2-1"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str % {"a": a, "b": b, "c": c} @@ -3662,7 +3660,7 @@ def test_mod_broadcast_single() -> None: values = xr.DataArray( ["%s_1", "%s_2", "%s_3"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) pos = xr.DataArray( ["2.3", "3.44444"], @@ -3672,7 +3670,7 @@ def test_mod_broadcast_single() -> None: expected = xr.DataArray( [["2.3_1", "3.44444_1"], ["2.3_2", "3.44444_2"], ["2.3_3", "3.44444_3"]], dims=["X", "YY"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str % pos @@ -3684,7 +3682,7 @@ def test_mod_broadcast_multi() -> None: values = xr.DataArray( ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) pos0 = 1 pos1 = 1.2 @@ -3701,7 +3699,7 @@ def test_mod_broadcast_multi() -> None: ["1-1.2-2.3", "1-1.2-3.44444"], ], dims=["X", "YY"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str % (pos0, pos1, pos2) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index dad2d668ff8..59e9f655b2e 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -13,11 +13,13 @@ import tempfile import uuid import warnings -from collections.abc import Iterator +from collections.abc import Generator, Iterator from contextlib import ExitStack from io import BytesIO +from os import listdir from pathlib import Path from typing import TYPE_CHECKING, Any, Final, cast +from unittest.mock import patch import numpy as np import pandas as pd @@ -53,7 +55,6 @@ from xarray.core.options import set_options from xarray.core.pycompat import array_type from xarray.tests import ( - arm_xfail, assert_allclose, assert_array_equal, assert_equal, @@ -525,7 +526,6 @@ def test_roundtrip_string_encoded_characters(self) -> None: assert_identical(expected, actual) assert actual["x"].encoding["_Encoding"] == "ascii" - @arm_xfail def test_roundtrip_numpy_datetime_data(self) -> None: times = pd.to_datetime(["2000-01-01", "2000-01-02", "NaT"]) expected = Dataset({"t": ("t", times), "t0": times[0]}) @@ -715,9 +715,6 @@ def multiple_indexing(indexers): ] multiple_indexing(indexers5) - @pytest.mark.xfail( - reason="zarr without dask handles negative steps in slices incorrectly", - ) def test_vectorized_indexing_negative_step(self) -> None: # use dask explicitly when present open_kwargs: dict[str, Any] | None @@ -813,7 +810,7 @@ def test_array_type_after_indexing(self) -> None: def test_dropna(self) -> None: # regression test for GH:issue:1694 a = np.random.randn(4, 3) - a[1, 1] = np.NaN + a[1, 1] = np.nan in_memory = xr.Dataset( {"a": (("y", "x"), a)}, coords={"y": np.arange(4), "x": np.arange(3)} ) @@ -1535,6 +1532,83 @@ def test_keep_chunksizes_if_no_original_shape(self) -> None: ds["x"].encoding["chunksizes"], actual["x"].encoding["chunksizes"] ) + def test_preferred_chunks_is_present(self) -> None: + ds = Dataset({"x": [1, 2, 3]}) + chunksizes = (2,) + ds.variables["x"].encoding = {"chunksizes": chunksizes} + + with self.roundtrip(ds) as actual: + assert actual["x"].encoding["preferred_chunks"] == {"x": 2} + + @requires_dask + def test_auto_chunking_is_based_on_disk_chunk_sizes(self) -> None: + x_size = y_size = 1000 + y_chunksize = y_size + x_chunksize = 10 + + with dask.config.set({"array.chunk-size": "100KiB"}): + with self.chunked_roundtrip( + (1, y_size, x_size), + (1, y_chunksize, x_chunksize), + open_kwargs={"chunks": "auto"}, + ) as ds: + t_chunks, y_chunks, x_chunks = ds["image"].data.chunks + assert all(np.asanyarray(y_chunks) == y_chunksize) + # Check that the chunk size is a multiple of the file chunk size + assert all(np.asanyarray(x_chunks) % x_chunksize == 0) + + @requires_dask + def test_base_chunking_uses_disk_chunk_sizes(self) -> None: + x_size = y_size = 1000 + y_chunksize = y_size + x_chunksize = 10 + + with self.chunked_roundtrip( + (1, y_size, x_size), + (1, y_chunksize, x_chunksize), + open_kwargs={"chunks": {}}, + ) as ds: + for chunksizes, expected in zip( + ds["image"].data.chunks, (1, y_chunksize, x_chunksize) + ): + assert all(np.asanyarray(chunksizes) == expected) + + @contextlib.contextmanager + def chunked_roundtrip( + self, + array_shape: tuple[int, int, int], + chunk_sizes: tuple[int, int, int], + open_kwargs: dict[str, Any] | None = None, + ) -> Generator[Dataset, None, None]: + t_size, y_size, x_size = array_shape + t_chunksize, y_chunksize, x_chunksize = chunk_sizes + + image = xr.DataArray( + np.arange(t_size * x_size * y_size, dtype=np.int16).reshape( + (t_size, y_size, x_size) + ), + dims=["t", "y", "x"], + ) + image.encoding = {"chunksizes": (t_chunksize, y_chunksize, x_chunksize)} + dataset = xr.Dataset(dict(image=image)) + + with self.roundtrip(dataset, open_kwargs=open_kwargs) as ds: + yield ds + + def test_preferred_chunks_are_disk_chunk_sizes(self) -> None: + x_size = y_size = 1000 + y_chunksize = y_size + x_chunksize = 10 + + with self.chunked_roundtrip( + (1, y_size, x_size), (1, y_chunksize, x_chunksize) + ) as ds: + assert ds["image"].encoding["preferred_chunks"] == { + "t": 1, + "y": y_chunksize, + "x": x_chunksize, + } + def test_encoding_chunksizes_unlimited(self) -> None: # regression test for GH1225 ds = Dataset({"x": [1, 2, 3], "y": ("x", [2, 3, 4])}) @@ -1766,8 +1840,8 @@ def test_unsorted_index_raises(self) -> None: # dask first pulls items by block. pass + @pytest.mark.skip(reason="caching behavior differs for dask") def test_dataset_caching(self) -> None: - # caching behavior differs for dask pass def test_write_inconsistent_chunks(self) -> None: @@ -1888,7 +1962,7 @@ def test_auto_chunk(self) -> None: assert v.chunks == original[k].chunks @requires_dask - @pytest.mark.filterwarnings("ignore:The specified Dask chunks separate") + @pytest.mark.filterwarnings("ignore:The specified chunks separate:UserWarning") def test_manual_chunk(self) -> None: original = create_test_data().chunk({"dim1": 3, "dim2": 4, "dim3": 3}) @@ -2185,9 +2259,6 @@ def test_encoding_kwarg_fixed_width_string(self) -> None: # not relevant for zarr, since we don't use EncodedStringCoder pass - # TODO: someone who understand caching figure out whether caching - # makes sense for Zarr backend - @pytest.mark.xfail(reason="Zarr caching not implemented") def test_dataset_caching(self) -> None: super().test_dataset_caching() @@ -2636,6 +2707,14 @@ def test_attributes(self, obj) -> None: with pytest.raises(TypeError, match=r"Invalid attribute in Dataset.attrs."): ds.to_zarr(store_target, **self.version_kwargs) + def test_vectorized_indexing_negative_step(self) -> None: + if not has_dask: + pytest.xfail( + reason="zarr without dask handles negative steps in slices incorrectly" + ) + + super().test_vectorized_indexing_negative_step() + @requires_zarr class TestZarrDictStore(ZarrBase): @@ -2665,6 +2744,86 @@ def create_store(self): yield group +@requires_zarr +class TestZarrWriteEmpty(TestZarrDirectoryStore): + @contextlib.contextmanager + def temp_dir(self) -> Iterator[tuple[str, str]]: + with tempfile.TemporaryDirectory() as d: + store = os.path.join(d, "test.zarr") + yield d, store + + @contextlib.contextmanager + def roundtrip_dir( + self, + data, + store, + save_kwargs=None, + open_kwargs=None, + allow_cleanup_failure=False, + ) -> Iterator[Dataset]: + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} + + data.to_zarr(store, **save_kwargs, **self.version_kwargs) + with xr.open_dataset( + store, engine="zarr", **open_kwargs, **self.version_kwargs + ) as ds: + yield ds + + @pytest.mark.parametrize("write_empty", [True, False]) + def test_write_empty(self, write_empty: bool) -> None: + if not write_empty: + expected = ["0.1.0", "1.1.0"] + else: + expected = [ + "0.0.0", + "0.0.1", + "0.1.0", + "0.1.1", + "1.0.0", + "1.0.1", + "1.1.0", + "1.1.1", + ] + + ds = xr.Dataset( + data_vars={ + "test": ( + ("Z", "Y", "X"), + np.array([np.nan, np.nan, 1.0, np.nan]).reshape((1, 2, 2)), + ) + } + ) + + if has_dask: + ds["test"] = ds["test"].chunk(1) + encoding = None + else: + encoding = {"test": {"chunks": (1, 1, 1)}} + + with self.temp_dir() as (d, store): + ds.to_zarr( + store, + mode="w", + encoding=encoding, + write_empty_chunks=write_empty, + ) + + with self.roundtrip_dir( + ds, + store, + {"mode": "a", "append_dim": "Z", "write_empty_chunks": write_empty}, + ) as a_ds: + expected_ds = xr.concat([ds, ds], dim="Z") + + assert_identical(a_ds, expected_ds) + + ls = listdir(os.path.join(store, "test")) + assert set(expected) == set([file for file in ls if file[0] != "."]) + + class ZarrBaseV3(ZarrBase): zarr_version = 3 @@ -2704,6 +2863,47 @@ def create_zarr_target(self): yield tmp +@requires_zarr +class TestZarrArrayWrapperCalls(TestZarrKVStoreV3): + def test_avoid_excess_metadata_calls(self) -> None: + """Test that chunk requests do not trigger redundant metadata requests. + + This test targets logic in backends.zarr.ZarrArrayWrapper, asserting that calls + to retrieve chunk data after initialization do not trigger additional + metadata requests. + + https://github.com/pydata/xarray/issues/8290 + """ + + import zarr + + ds = xr.Dataset(data_vars={"test": (("Z",), np.array([123]).reshape(1))}) + + # The call to retrieve metadata performs a group lookup. We patch Group.__getitem__ + # so that we can inspect calls to this method - specifically count of calls. + # Use of side_effect means that calls are passed through to the original method + # rather than a mocked method. + Group = zarr.hierarchy.Group + with ( + self.create_zarr_target() as store, + patch.object( + Group, "__getitem__", side_effect=Group.__getitem__, autospec=True + ) as mock, + ): + ds.to_zarr(store, mode="w") + + # We expect this to request array metadata information, so call_count should be >= 1, + # At time of writing, 2 calls are made + xrds = xr.open_zarr(store) + call_count = mock.call_count + assert call_count > 0 + + # compute() requests array data, which should not trigger additional metadata requests + # we assert that the number of calls has not increased after fetchhing the array + xrds.test.compute(scheduler="sync") + assert mock.call_count == call_count + + @requires_zarr @requires_fsspec def test_zarr_storage_options() -> None: @@ -2928,8 +3128,9 @@ def create_store(self): def test_complex(self) -> None: expected = Dataset({"x": ("y", np.ones(5) + 1j * np.ones(5))}) save_kwargs = {"invalid_netcdf": True} - with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: - assert_equal(expected, actual) + with pytest.warns(UserWarning, match="You are writing invalid netcdf features"): + with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: + assert_equal(expected, actual) @pytest.mark.parametrize("invalid_netcdf", [None, False]) def test_complex_error(self, invalid_netcdf) -> None: @@ -2943,14 +3144,14 @@ def test_complex_error(self, invalid_netcdf) -> None: with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: assert_equal(expected, actual) - @pytest.mark.filterwarnings("ignore:You are writing invalid netcdf features") def test_numpy_bool_(self) -> None: # h5netcdf loads booleans as numpy.bool_, this type needs to be supported # when writing invalid_netcdf datasets in order to support a roundtrip expected = Dataset({"x": ("y", np.ones(5), {"numpy_bool": np.bool_(True)})}) save_kwargs = {"invalid_netcdf": True} - with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: - assert_identical(expected, actual) + with pytest.warns(UserWarning, match="You are writing invalid netcdf features"): + with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: + assert_identical(expected, actual) def test_cross_engine_read_write_netcdf4(self) -> None: # Drop dim3, because its labels include strings. These appear to be @@ -3221,8 +3422,8 @@ def roundtrip( ) as ds: yield ds + @pytest.mark.skip(reason="caching behavior differs for dask") def test_dataset_caching(self) -> None: - # caching behavior differs for dask pass def test_write_inconsistent_chunks(self) -> None: @@ -3300,6 +3501,7 @@ def skip_if_not_engine(engine): @requires_dask @pytest.mark.filterwarnings("ignore:use make_scale(name) instead") +@pytest.mark.xfail(reason="Flaky test. Very open to contributions on fixing this") def test_open_mfdataset_manyfiles( readengine, nfiles, parallel, chunks, file_cache_maxsize ): @@ -3825,7 +4027,6 @@ def test_open_mfdataset_raise_on_bad_combine_args(self) -> None: with pytest.raises(ValueError, match="`concat_dim` has no effect"): open_mfdataset([tmp1, tmp2], concat_dim="x") - @pytest.mark.xfail(reason="mfdataset loses encoding currently.") def test_encoding_mfdataset(self) -> None: original = Dataset( { @@ -4038,7 +4239,6 @@ def test_dataarray_compute(self) -> None: assert computed._in_memory assert_allclose(actual, computed, decode_bytes=False) - @pytest.mark.xfail def test_save_mfdataset_compute_false_roundtrip(self) -> None: from dask.delayed import Delayed @@ -4968,15 +5168,17 @@ def test_open_fsspec() -> None: ds2 = open_dataset(url, engine="zarr") xr.testing.assert_equal(ds0, ds2) - # multi dataset - url = "memory://out*.zarr" - ds2 = open_mfdataset(url, engine="zarr") - xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) + # open_mfdataset requires dask + if has_dask: + # multi dataset + url = "memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) - # multi dataset with caching - url = "simplecache::memory://out*.zarr" - ds2 = open_mfdataset(url, engine="zarr") - xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) + # multi dataset with caching + url = "simplecache::memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) @requires_h5netcdf @@ -5035,7 +5237,7 @@ def test_open_dataset_chunking_zarr(chunks, tmp_path: Path) -> None: @pytest.mark.parametrize( "chunks", ["auto", -1, {}, {"x": "auto"}, {"x": -1}, {"x": "auto", "y": -1}] ) -@pytest.mark.filterwarnings("ignore:The specified Dask chunks separate") +@pytest.mark.filterwarnings("ignore:The specified chunks separate") def test_chunking_consintency(chunks, tmp_path: Path) -> None: encoded_chunks: dict[str, Any] = {} dask_arr = da.from_array( diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index 24ffab305ad..5f13415d925 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -1,6 +1,7 @@ from __future__ import annotations from itertools import product +from typing import Callable, Literal import numpy as np import pandas as pd @@ -1215,7 +1216,7 @@ def test_cftime_range_name(): @pytest.mark.parametrize( - ("start", "end", "periods", "freq", "closed"), + ("start", "end", "periods", "freq", "inclusive"), [ (None, None, 5, "A", None), ("2000", None, None, "A", None), @@ -1226,9 +1227,22 @@ def test_cftime_range_name(): ("2000", "2001", 5, "A", None), ], ) -def test_invalid_cftime_range_inputs(start, end, periods, freq, closed): +def test_invalid_cftime_range_inputs( + start: str | None, + end: str | None, + periods: int | None, + freq: str | None, + inclusive: Literal["up", None], +) -> None: with pytest.raises(ValueError): - cftime_range(start, end, periods, freq, closed=closed) + cftime_range(start, end, periods, freq, inclusive=inclusive) # type: ignore[arg-type] + + +def test_invalid_cftime_arg() -> None: + with pytest.warns( + FutureWarning, match="Following pandas, the `closed` parameter is deprecated" + ): + cftime_range("2000", "2001", None, "A", closed="left") _CALENDAR_SPECIFIC_MONTH_END_TESTS = [ @@ -1246,7 +1260,9 @@ def test_invalid_cftime_range_inputs(start, end, periods, freq, closed): _CALENDAR_SPECIFIC_MONTH_END_TESTS, ids=_id_func, ) -def test_calendar_specific_month_end(freq, calendar, expected_month_day): +def test_calendar_specific_month_end( + freq: str, calendar: str, expected_month_day: list[tuple[int, int]] +) -> None: year = 2000 # Use a leap-year to highlight calendar differences result = cftime_range( start="2000-02", end="2001", freq=freq, calendar=calendar @@ -1273,26 +1289,28 @@ def test_calendar_specific_month_end(freq, calendar, expected_month_day): ("julian", "2001", "2002", 365), ], ) -def test_calendar_year_length(calendar, start, end, expected_number_of_days): - result = cftime_range(start, end, freq="D", closed="left", calendar=calendar) +def test_calendar_year_length( + calendar: str, start: str, end: str, expected_number_of_days: int +) -> None: + result = cftime_range(start, end, freq="D", inclusive="left", calendar=calendar) assert len(result) == expected_number_of_days @pytest.mark.parametrize("freq", ["A", "M", "D"]) -def test_dayofweek_after_cftime_range(freq): +def test_dayofweek_after_cftime_range(freq: str) -> None: result = cftime_range("2000-02-01", periods=3, freq=freq).dayofweek expected = pd.date_range("2000-02-01", periods=3, freq=freq).dayofweek np.testing.assert_array_equal(result, expected) @pytest.mark.parametrize("freq", ["A", "M", "D"]) -def test_dayofyear_after_cftime_range(freq): +def test_dayofyear_after_cftime_range(freq: str) -> None: result = cftime_range("2000-02-01", periods=3, freq=freq).dayofyear expected = pd.date_range("2000-02-01", periods=3, freq=freq).dayofyear np.testing.assert_array_equal(result, expected) -def test_cftime_range_standard_calendar_refers_to_gregorian(): +def test_cftime_range_standard_calendar_refers_to_gregorian() -> None: from cftime import DatetimeGregorian (result,) = cftime_range("2000", periods=1) @@ -1310,7 +1328,9 @@ def test_cftime_range_standard_calendar_refers_to_gregorian(): ("3400-01-01", "standard", None, CFTimeIndex), ], ) -def test_date_range(start, calendar, use_cftime, expected_type): +def test_date_range( + start: str, calendar: str, use_cftime: bool | None, expected_type +) -> None: dr = date_range( start, periods=14, freq="D", calendar=calendar, use_cftime=use_cftime ) @@ -1318,7 +1338,7 @@ def test_date_range(start, calendar, use_cftime, expected_type): assert isinstance(dr, expected_type) -def test_date_range_errors(): +def test_date_range_errors() -> None: with pytest.raises(ValueError, match="Date range is invalid"): date_range( "1400-01-01", periods=1, freq="D", calendar="standard", use_cftime=False @@ -1412,7 +1432,7 @@ def as_timedelta_not_implemented_error(): @pytest.mark.parametrize("function", [cftime_range, date_range]) -def test_cftime_or_date_range_closed_and_inclusive_error(function) -> None: +def test_cftime_or_date_range_closed_and_inclusive_error(function: Callable) -> None: if function == cftime_range and not has_cftime: pytest.skip("requires cftime") @@ -1421,22 +1441,29 @@ def test_cftime_or_date_range_closed_and_inclusive_error(function) -> None: @pytest.mark.parametrize("function", [cftime_range, date_range]) -def test_cftime_or_date_range_invalid_closed_value(function) -> None: +def test_cftime_or_date_range_invalid_inclusive_value(function: Callable) -> None: if function == cftime_range and not has_cftime: pytest.skip("requires cftime") - with pytest.raises(ValueError, match="Argument `closed` must be"): - function("2000", periods=3, closed="foo") + with pytest.raises(ValueError, match="nclusive"): + function("2000", periods=3, inclusive="foo") -@pytest.mark.parametrize("function", [cftime_range, date_range]) +@pytest.mark.parametrize( + "function", + [ + pytest.param(cftime_range, id="cftime", marks=requires_cftime), + pytest.param(date_range, id="date"), + ], +) @pytest.mark.parametrize( ("closed", "inclusive"), [(None, "both"), ("left", "left"), ("right", "right")] ) -def test_cftime_or_date_range_closed(function, closed, inclusive) -> None: - if function == cftime_range and not has_cftime: - pytest.skip("requires cftime") - +def test_cftime_or_date_range_closed( + function: Callable, + closed: Literal["left", "right", None], + inclusive: Literal["left", "right", "both"], +) -> None: with pytest.warns(FutureWarning, match="Following pandas"): result_closed = function("2000-01-01", "2000-01-04", freq="D", closed=closed) result_inclusive = function( diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index a676b1f07f1..1a1df6b81fe 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -7,7 +7,6 @@ import numpy as np import pandas as pd import pytest -from packaging.version import Version import xarray as xr from xarray.coding.cftimeindex import ( @@ -33,12 +32,7 @@ # cftime 1.5.2 renames "gregorian" to "standard" standard_or_gregorian = "" if has_cftime: - import cftime - - if Version(cftime.__version__) >= Version("1.5.2"): - standard_or_gregorian = "standard" - else: - standard_or_gregorian = "gregorian" + standard_or_gregorian = "standard" def date_dict(year=None, month=None, day=None, hour=None, minute=None, second=None): @@ -1141,7 +1135,6 @@ def test_to_datetimeindex_feb_29(calendar): @requires_cftime -@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/24263") def test_multiindex(): index = xr.cftime_range("2001-01-01", periods=100, calendar="360_day") mindex = pd.MultiIndex.from_arrays([index]) diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index 07bc14f8983..f8e6e80452a 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import pytest +from packaging.version import Version import xarray as xr from xarray.core.pdcompat import _convert_base_to_offset @@ -122,6 +123,16 @@ def da(index) -> xr.DataArray: ) def test_resample(freqs, closed, label, base, offset) -> None: initial_freq, resample_freq = freqs + if ( + resample_freq == "4001D" + and closed == "right" + and Version(pd.__version__) < Version("2.2") + ): + pytest.skip( + "Pandas fixed a bug in this test case in version 2.2, which we " + "ported to xarray, so this test no longer produces the same " + "result as pandas for earlier pandas versions." + ) start = "2000-01-01T12:07:01" loffset = "12H" origin = "start" @@ -169,7 +180,7 @@ def test_closed_label_defaults(freq, expected) -> None: @pytest.mark.parametrize( "calendar", ["gregorian", "noleap", "all_leap", "360_day", "julian"] ) -def test_calendars(calendar) -> None: +def test_calendars(calendar: str) -> None: # Limited testing for non-standard calendars freq, closed, label, base = "8001T", None, None, 17 loffset = datetime.timedelta(hours=12) diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py index d58361afdd3..01d5393e289 100644 --- a/xarray/tests/test_coarsen.py +++ b/xarray/tests/test_coarsen.py @@ -6,6 +6,7 @@ import xarray as xr from xarray import DataArray, Dataset, set_options +from xarray.core import duck_array_ops from xarray.tests import ( assert_allclose, assert_equal, @@ -17,7 +18,10 @@ def test_coarsen_absent_dims_error(ds: Dataset) -> None: - with pytest.raises(ValueError, match=r"not found in Dataset."): + with pytest.raises( + ValueError, + match=r"Window dimensions \('foo',\) not found in Dataset dimensions", + ): ds.coarsen(foo=2) @@ -269,21 +273,24 @@ def test_coarsen_construct(self, dask: bool) -> None: expected = xr.Dataset(attrs={"foo": "bar"}) expected["vart"] = ( ("year", "month"), - ds.vart.data.reshape((-1, 12)), + duck_array_ops.reshape(ds.vart.data, (-1, 12)), {"a": "b"}, ) expected["varx"] = ( ("x", "x_reshaped"), - ds.varx.data.reshape((-1, 5)), + duck_array_ops.reshape(ds.varx.data, (-1, 5)), {"a": "b"}, ) expected["vartx"] = ( ("x", "x_reshaped", "year", "month"), - ds.vartx.data.reshape(2, 5, 4, 12), + duck_array_ops.reshape(ds.vartx.data, (2, 5, 4, 12)), {"a": "b"}, ) expected["vary"] = ds.vary - expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12))) + expected.coords["time"] = ( + ("year", "month"), + duck_array_ops.reshape(ds.time.data, (-1, 12)), + ) with raise_if_dask_computes(): actual = ds.coarsen(time=12, x=5).construct( diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index 0c9f67e77ad..f1eca00f9a1 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -33,7 +33,7 @@ def test_vlen_dtype() -> None: assert strings.check_vlen_dtype(dtype) is bytes # check h5py variant ("vlen") - dtype = np.dtype("O", metadata={"vlen": str}) # type: ignore[call-overload] + dtype = np.dtype("O", metadata={"vlen": str}) # type: ignore[call-overload,unused-ignore] assert strings.check_vlen_dtype(dtype) is str assert strings.check_vlen_dtype(np.dtype(object)) is None diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 580de878fe6..423e48bd155 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -20,6 +20,7 @@ ) from xarray.coding.times import ( _encode_datetime_with_cftime, + _numpy_to_netcdf_timeunit, _should_cftime_be_used, cftime_to_nptime, decode_cf_datetime, @@ -110,6 +111,7 @@ def _all_cftime_date_types(): @requires_cftime @pytest.mark.filterwarnings("ignore:Ambiguous reference date string") +@pytest.mark.filterwarnings("ignore:Times can't be serialized faithfully") @pytest.mark.parametrize(["num_dates", "units", "calendar"], _CF_DATETIME_TESTS) def test_cf_datetime(num_dates, units, calendar) -> None: import cftime @@ -567,6 +569,7 @@ def test_infer_cftime_datetime_units(calendar, date_args, expected) -> None: assert expected == coding.times.infer_datetime_units(dates) +@pytest.mark.filterwarnings("ignore:Timedeltas can't be serialized faithfully") @pytest.mark.parametrize( ["timedeltas", "units", "numbers"], [ @@ -576,10 +579,10 @@ def test_infer_cftime_datetime_units(calendar, date_args, expected) -> None: ("1ms", "milliseconds", np.int64(1)), ("1us", "microseconds", np.int64(1)), ("1ns", "nanoseconds", np.int64(1)), - (["NaT", "0s", "1s"], None, [np.nan, 0, 1]), + (["NaT", "0s", "1s"], None, [np.iinfo(np.int64).min, 0, 1]), (["30m", "60m"], "hours", [0.5, 1.0]), - ("NaT", "days", np.nan), - (["NaT", "NaT"], "days", [np.nan, np.nan]), + ("NaT", "days", np.iinfo(np.int64).min), + (["NaT", "NaT"], "days", [np.iinfo(np.int64).min, np.iinfo(np.int64).min]), ], ) def test_cf_timedelta(timedeltas, units, numbers) -> None: @@ -1020,6 +1023,7 @@ def test_decode_ambiguous_time_warns(calendar) -> None: np.testing.assert_array_equal(result, expected) +@pytest.mark.filterwarnings("ignore:Times can't be serialized faithfully") @pytest.mark.parametrize("encoding_units", FREQUENCIES_TO_ENCODING_UNITS.values()) @pytest.mark.parametrize("freq", FREQUENCIES_TO_ENCODING_UNITS.keys()) @pytest.mark.parametrize("date_range", [pd.date_range, cftime_range]) @@ -1032,7 +1036,7 @@ def test_encode_cf_datetime_defaults_to_correct_dtype( pytest.skip("Nanosecond frequency is not valid for cftime dates.") times = date_range("2000", periods=3, freq=freq) units = f"{encoding_units} since 2000-01-01" - encoded, _, _ = coding.times.encode_cf_datetime(times, units) + encoded, _units, _ = coding.times.encode_cf_datetime(times, units) numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units) encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit) @@ -1191,3 +1195,195 @@ def test_contains_cftime_lazy() -> None: ) array = FirstElementAccessibleArray(times) assert _contains_cftime_datetimes(array) + + +@pytest.mark.parametrize( + "timestr, timeunit, dtype, fill_value, use_encoding", + [ + ("1677-09-21T00:12:43.145224193", "ns", np.int64, 20, True), + ("1970-09-21T00:12:44.145224808", "ns", np.float64, 1e30, True), + ( + "1677-09-21T00:12:43.145225216", + "ns", + np.float64, + -9.223372036854776e18, + True, + ), + ("1677-09-21T00:12:43.145224193", "ns", np.int64, None, False), + ("1677-09-21T00:12:43.145225", "us", np.int64, None, False), + ("1970-01-01T00:00:01.000001", "us", np.int64, None, False), + ("1677-09-21T00:21:52.901038080", "ns", np.float32, 20.0, True), + ], +) +def test_roundtrip_datetime64_nanosecond_precision( + timestr: str, + timeunit: str, + dtype: np.typing.DTypeLike, + fill_value: int | float | None, + use_encoding: bool, +) -> None: + # test for GH7817 + time = np.datetime64(timestr, timeunit) + times = [np.datetime64("1970-01-01T00:00:00", timeunit), np.datetime64("NaT"), time] + + if use_encoding: + encoding = dict(dtype=dtype, _FillValue=fill_value) + else: + encoding = {} + + var = Variable(["time"], times, encoding=encoding) + assert var.dtype == np.dtype(" None: + # test warning if times can't be serialized faithfully + times = [ + np.datetime64("1970-01-01T00:01:00", "ns"), + np.datetime64("NaT"), + np.datetime64("1970-01-02T00:01:00", "ns"), + ] + units = "days since 1970-01-10T01:01:00" + needed_units = "hours" + new_units = f"{needed_units} since 1970-01-10T01:01:00" + + encoding = dict(dtype=None, _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with pytest.warns(UserWarning, match=f"Resolution of {needed_units!r} needed."): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with pytest.warns( + UserWarning, match=f"Serializing with units {new_units!r} instead." + ): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="float64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=new_units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + +@pytest.mark.parametrize( + "dtype, fill_value", + [(np.int64, 20), (np.int64, np.iinfo(np.int64).min), (np.float64, 1e30)], +) +def test_roundtrip_timedelta64_nanosecond_precision( + dtype: np.typing.DTypeLike, fill_value: int | float +) -> None: + # test for GH7942 + one_day = np.timedelta64(1, "ns") + nat = np.timedelta64("nat", "ns") + timedelta_values = (np.arange(5) * one_day).astype("timedelta64[ns]") + timedelta_values[2] = nat + timedelta_values[4] = nat + + encoding = dict(dtype=dtype, _FillValue=fill_value) + var = Variable(["time"], timedelta_values, encoding=encoding) + + encoded_var = conventions.encode_cf_variable(var) + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + + assert_identical(var, decoded_var) + + +def test_roundtrip_timedelta64_nanosecond_precision_warning() -> None: + # test warning if timedeltas can't be serialized faithfully + one_day = np.timedelta64(1, "D") + nat = np.timedelta64("nat", "ns") + timedelta_values = (np.arange(5) * one_day).astype("timedelta64[ns]") + timedelta_values[2] = nat + timedelta_values[4] = np.timedelta64(12, "h").astype("timedelta64[ns]") + + units = "days" + needed_units = "hours" + wmsg = ( + f"Timedeltas can't be serialized faithfully with requested units {units!r}. " + f"Serializing with units {needed_units!r} instead." + ) + encoding = dict(dtype=np.int64, _FillValue=20, units=units) + var = Variable(["time"], timedelta_values, encoding=encoding) + with pytest.warns(UserWarning, match=wmsg): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == needed_units + assert encoded_var.attrs["_FillValue"] == 20 + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + assert decoded_var.encoding["dtype"] == np.int64 + + +def test_roundtrip_float_times() -> None: + # Regression test for GitHub issue #8271 + fill_value = 20.0 + times = [ + np.datetime64("1970-01-01 00:00:00", "ns"), + np.datetime64("1970-01-01 06:00:00", "ns"), + np.datetime64("NaT", "ns"), + ] + + units = "days since 1960-01-01" + var = Variable( + ["time"], + times, + encoding=dict(dtype=np.float64, _FillValue=fill_value, units=units), + ) + + encoded_var = conventions.encode_cf_variable(var) + np.testing.assert_array_equal(encoded_var, np.array([3653, 3653.25, 20.0])) + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == fill_value + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + assert decoded_var.encoding["units"] == units + assert decoded_var.encoding["_FillValue"] == fill_value diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 3c10e3f27ab..e7eac068e97 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -257,6 +257,168 @@ def func(x): assert_identical(out1, dataset) +def test_apply_missing_dims() -> None: + ## Single arg + + def add_one(a, core_dims, on_missing_core_dim): + return apply_ufunc( + lambda x: x + 1, + a, + input_core_dims=core_dims, + output_core_dims=core_dims, + on_missing_core_dim=on_missing_core_dim, + ) + + array = np.arange(6).reshape(2, 3) + variable = xr.Variable(["x", "y"], array) + variable_no_y = xr.Variable(["x", "z"], array) + + ds = xr.Dataset({"x_y": variable, "x_z": variable_no_y}) + + # Check the standard stuff works OK + assert_identical( + add_one(ds[["x_y"]], core_dims=[["y"]], on_missing_core_dim="raise"), + ds[["x_y"]] + 1, + ) + + # `raise` — should raise on a missing dim + with pytest.raises(ValueError): + add_one(ds, core_dims=[["y"]], on_missing_core_dim="raise") + + # `drop` — should drop the var with the missing dim + assert_identical( + add_one(ds, core_dims=[["y"]], on_missing_core_dim="drop"), + (ds + 1).drop_vars("x_z"), + ) + + # `copy` — should not add one to the missing with `copy` + copy_result = add_one(ds, core_dims=[["y"]], on_missing_core_dim="copy") + assert_identical(copy_result["x_y"], (ds + 1)["x_y"]) + assert_identical(copy_result["x_z"], ds["x_z"]) + + ## Multiple args + + def sum_add(a, b, core_dims, on_missing_core_dim): + return apply_ufunc( + lambda a, b, axis=None: a.sum(axis) + b.sum(axis), + a, + b, + input_core_dims=core_dims, + on_missing_core_dim=on_missing_core_dim, + ) + + # Check the standard stuff works OK + assert_identical( + sum_add( + ds[["x_y"]], + ds[["x_y"]], + core_dims=[["x", "y"], ["x", "y"]], + on_missing_core_dim="raise", + ), + ds[["x_y"]].sum() * 2, + ) + + # `raise` — should raise on a missing dim + with pytest.raises( + ValueError, + match=r".*Missing core dims \{'y'\} from arg number 1 on a variable named `x_z`:\n.* None: data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) @@ -1028,7 +1190,7 @@ def test_apply_dask() -> None: # unknown setting for dask array handling with pytest.raises(ValueError): - apply_ufunc(identity, array, dask="unknown") + apply_ufunc(identity, array, dask="unknown") # type: ignore def dask_safe_identity(x): return apply_ufunc(identity, x, dask="allowed") @@ -1087,7 +1249,7 @@ def check(x, y): assert actual.data.chunks == array.chunks assert_identical(data_array, actual) - check(data_array, 0), + check(data_array, 0) check(0, data_array) check(data_array, xr.DataArray(0)) check(data_array, 0 * data_array) @@ -1666,7 +1828,10 @@ def identity(x): def tuple3x(x): return (x, x, x) - with pytest.raises(ValueError, match=r"number of outputs"): + with pytest.raises( + ValueError, + match=r"number of outputs.* Received a with 10 elements. Expected a tuple of 2 elements:\n\narray\(\[0", + ): apply_ufunc(identity, variable, output_core_dims=[(), ()]) with pytest.raises(ValueError, match=r"number of outputs"): @@ -1682,7 +1847,10 @@ def add_dim(x): def remove_dim(x): return x[..., 0] - with pytest.raises(ValueError, match=r"unexpected number of dimensions"): + with pytest.raises( + ValueError, + match=r"unexpected number of dimensions.*from:\n\n.*array\(\[\[0", + ): apply_ufunc(add_dim, variable, output_core_dims=[("y", "z")]) with pytest.raises(ValueError, match=r"unexpected number of dimensions"): diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 030f653e031..11d0d38594d 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -9,6 +9,7 @@ from xarray import DataArray, Dataset, Variable, concat from xarray.core import dtypes, merge +from xarray.core.coordinates import Coordinates from xarray.core.indexes import PandasIndex from xarray.tests import ( InaccessibleArray, @@ -614,11 +615,14 @@ def test_concat_errors(self): with pytest.raises(ValueError, match=r"must supply at least one"): concat([], "dim1") - with pytest.raises(ValueError, match=r"are not coordinates"): + with pytest.raises(ValueError, match=r"are not found in the coordinates"): concat([data, data], "new_dim", coords=["not_found"]) + with pytest.raises(ValueError, match=r"are not found in the data variables"): + concat([data, data], "new_dim", data_vars=["not_found"]) + with pytest.raises(ValueError, match=r"global attributes not"): - # call deepcopy seperately to get unique attrs + # call deepcopy separately to get unique attrs data0 = deepcopy(split_data[0]) data1 = deepcopy(split_data[1]) data1.attrs["foo"] = "bar" @@ -906,8 +910,9 @@ def test_concat_dim_is_dataarray(self) -> None: assert_identical(actual, expected) def test_concat_multiindex(self) -> None: - x = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]]) - expected = Dataset(coords={"x": x}) + midx = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + expected = Dataset(coords=midx_coords) actual = concat( [expected.isel(x=slice(2)), expected.isel(x=slice(2, None))], "x" ) @@ -917,8 +922,9 @@ def test_concat_multiindex(self) -> None: def test_concat_along_new_dim_multiindex(self) -> None: # see https://github.com/pydata/xarray/issues/6881 level_names = ["x_level_0", "x_level_1"] - x = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]], names=level_names) - ds = Dataset(coords={"x": x}) + midx = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]], names=level_names) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + ds = Dataset(coords=midx_coords) concatenated = concat([ds], "new") actual = list(concatenated.xindexes.get_all_coords("x")) expected = ["x"] + level_names diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 424b7db5ac4..d6d1303a696 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -80,6 +80,28 @@ def test_decode_cf_with_conflicting_fill_missing_value() -> None: assert_identical(actual, expected) +def test_decode_cf_variable_with_mismatched_coordinates() -> None: + # tests for decoding mismatched coordinates attributes + # see GH #1809 + zeros1 = np.zeros((1, 5, 3)) + orig = Dataset( + { + "XLONG": (["x", "y"], zeros1.squeeze(0), {}), + "XLAT": (["x", "y"], zeros1.squeeze(0), {}), + "foo": (["time", "x", "y"], zeros1, {"coordinates": "XTIME XLONG XLAT"}), + "time": ("time", [0.0], {"units": "hours since 2017-01-01"}), + } + ) + decoded = conventions.decode_cf(orig, decode_coords=True) + assert decoded["foo"].encoding["coordinates"] == "XTIME XLONG XLAT" + assert list(decoded.coords.keys()) == ["XLONG", "XLAT", "time"] + + decoded = conventions.decode_cf(orig, decode_coords=False) + assert "coordinates" not in decoded["foo"].encoding + assert decoded["foo"].attrs.get("coordinates") == "XTIME XLONG XLAT" + assert list(decoded.coords.keys()) == ["time"] + + @requires_cftime class TestEncodeCFVariable: def test_incompatible_attributes(self) -> None: @@ -129,9 +151,9 @@ def test_multidimensional_coordinates(self) -> None: foo1_coords = enc["foo1"].attrs.get("coordinates", "") foo2_coords = enc["foo2"].attrs.get("coordinates", "") foo3_coords = enc["foo3"].attrs.get("coordinates", "") - assert set(foo1_coords.split()) == {"lat1", "lon1"} - assert set(foo2_coords.split()) == {"lat2", "lon2"} - assert set(foo3_coords.split()) == {"lat3", "lon3"} + assert foo1_coords == "lon1 lat1" + assert foo2_coords == "lon2 lat2" + assert foo3_coords == "lon3 lat3" # Should not have any global coordinates. assert "coordinates" not in attrs @@ -150,11 +172,12 @@ def test_var_with_coord_attr(self) -> None: enc, attrs = conventions.encode_dataset_coordinates(orig) # Make sure we have the right coordinates for each variable. values_coords = enc["values"].attrs.get("coordinates", "") - assert set(values_coords.split()) == {"time", "lat", "lon"} + assert values_coords == "time lon lat" # Should not have any global coordinates. assert "coordinates" not in attrs def test_do_not_overwrite_user_coordinates(self) -> None: + # don't overwrite user-defined "coordinates" encoding orig = Dataset( coords={"x": [0, 1, 2], "y": ("x", [5, 6, 7]), "z": ("x", [8, 9, 10])}, data_vars={"a": ("x", [1, 2, 3]), "b": ("x", [3, 5, 6])}, @@ -168,6 +191,18 @@ def test_do_not_overwrite_user_coordinates(self) -> None: with pytest.raises(ValueError, match=r"'coordinates' found in both attrs"): conventions.encode_dataset_coordinates(orig) + def test_deterministic_coords_encoding(self) -> None: + # the coordinates attribute is sorted when set by xarray.conventions ... + # ... on a variable's coordinates attribute + ds = Dataset({"foo": 0}, coords={"baz": 0, "bar": 0}) + vars, attrs = conventions.encode_dataset_coordinates(ds) + assert vars["foo"].attrs["coordinates"] == "bar baz" + assert attrs.get("coordinates") is None + # ... on the global coordinates attribute + ds = ds.drop_vars("foo") + vars, attrs = conventions.encode_dataset_coordinates(ds) + assert attrs["coordinates"] == "bar baz" + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_emit_coordinates_attribute_in_attrs(self) -> None: orig = Dataset( @@ -233,9 +268,12 @@ def test_dataset(self) -> None: assert_identical(expected, actual) def test_invalid_coordinates(self) -> None: - # regression test for GH308 + # regression test for GH308, GH1809 original = Dataset({"foo": ("t", [1, 2], {"coordinates": "invalid"})}) + decoded = Dataset({"foo": ("t", [1, 2], {}, {"coordinates": "invalid"})}) actual = conventions.decode_cf(original) + assert_identical(decoded, actual) + actual = conventions.decode_cf(original, decode_coords=False) assert_identical(original, actual) def test_decode_coordinates(self) -> None: @@ -295,6 +333,17 @@ def test_invalid_time_units_raises_eagerly(self) -> None: with pytest.raises(ValueError, match=r"unable to decode time"): decode_cf(ds) + @pytest.mark.parametrize("decode_times", [True, False]) + def test_invalid_timedelta_units_do_not_decode(self, decode_times) -> None: + # regression test for #8269 + ds = Dataset( + {"time": ("time", [0, 1, 20], {"units": "days invalid", "_FillValue": 20})} + ) + expected = Dataset( + {"time": ("time", [0.0, 1.0, np.nan], {"units": "days invalid"})} + ) + assert_identical(expected, decode_cf(ds, decode_times=decode_times)) + @requires_cftime def test_dataset_repr_with_netcdf4_datetimes(self) -> None: # regression test for #347 diff --git a/xarray/tests/test_coordinates.py b/xarray/tests/test_coordinates.py new file mode 100644 index 00000000000..ef73371dfe4 --- /dev/null +++ b/xarray/tests/test_coordinates.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import pandas as pd +import pytest + +from xarray.core.alignment import align +from xarray.core.coordinates import Coordinates +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.indexes import PandasIndex, PandasMultiIndex +from xarray.tests import assert_identical, source_ndarray + + +class TestCoordinates: + def test_init_noindex(self) -> None: + coords = Coordinates(coords={"foo": ("x", [0, 1, 2])}) + expected = Dataset(coords={"foo": ("x", [0, 1, 2])}) + assert_identical(coords.to_dataset(), expected) + + def test_init_default_index(self) -> None: + coords = Coordinates(coords={"x": [1, 2]}) + expected = Dataset(coords={"x": [1, 2]}) + assert_identical(coords.to_dataset(), expected) + assert "x" in coords.xindexes + + def test_init_no_default_index(self) -> None: + # dimension coordinate with no default index (explicit) + coords = Coordinates(coords={"x": [1, 2]}, indexes={}) + assert "x" not in coords.xindexes + + def test_init_from_coords(self) -> None: + expected = Dataset(coords={"foo": ("x", [0, 1, 2])}) + coords = Coordinates(coords=expected.coords) + assert_identical(coords.to_dataset(), expected) + + # test variables copied + assert coords.variables["foo"] is not expected.variables["foo"] + + # test indexes are extracted + expected = Dataset(coords={"x": [0, 1, 2]}) + coords = Coordinates(coords=expected.coords) + assert_identical(coords.to_dataset(), expected) + assert expected.xindexes == coords.xindexes + + # coords + indexes not supported + with pytest.raises( + ValueError, match="passing both.*Coordinates.*indexes.*not allowed" + ): + coords = Coordinates( + coords=expected.coords, indexes={"x": PandasIndex([0, 1, 2], "x")} + ) + + def test_init_empty(self) -> None: + coords = Coordinates() + assert len(coords) == 0 + + def test_init_index_error(self) -> None: + idx = PandasIndex([1, 2, 3], "x") + with pytest.raises(ValueError, match="no coordinate variables found"): + Coordinates(indexes={"x": idx}) + + with pytest.raises(TypeError, match=".* is not an `xarray.indexes.Index`"): + Coordinates(coords={"x": ("x", [1, 2, 3])}, indexes={"x": "not_an_xarray_index"}) # type: ignore + + def test_init_dim_sizes_conflict(self) -> None: + with pytest.raises(ValueError): + Coordinates(coords={"foo": ("x", [1, 2]), "bar": ("x", [1, 2, 3, 4])}) + + def test_from_pandas_multiindex(self) -> None: + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + + assert isinstance(coords.xindexes["x"], PandasMultiIndex) + assert coords.xindexes["x"].index.equals(midx) + assert coords.xindexes["x"].dim == "x" + + expected = PandasMultiIndex(midx, "x").create_variables() + assert list(coords.variables) == list(expected) + for name in ("x", "one", "two"): + assert_identical(expected[name], coords.variables[name]) + + def test_dims(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert coords.dims == {"x": 3} + + def test_sizes(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert coords.sizes == {"x": 3} + + def test_dtypes(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert coords.dtypes == {"x": int} + + def test_getitem(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert_identical( + coords["x"], + DataArray([0, 1, 2], coords={"x": [0, 1, 2]}, name="x"), + ) + + def test_delitem(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + del coords["x"] + assert "x" not in coords + + with pytest.raises( + KeyError, match="'nonexistent' is not in coordinate variables" + ): + del coords["nonexistent"] + + def test_update(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + + coords.update({"y": ("y", [4, 5, 6])}) + assert "y" in coords + assert "y" in coords.xindexes + expected = DataArray([4, 5, 6], coords={"y": [4, 5, 6]}, name="y") + assert_identical(coords["y"], expected) + + def test_equals(self): + coords = Coordinates(coords={"x": [0, 1, 2]}) + + assert coords.equals(coords) + assert not coords.equals("not_a_coords") + + def test_identical(self): + coords = Coordinates(coords={"x": [0, 1, 2]}) + + assert coords.identical(coords) + assert not coords.identical("not_a_coords") + + def test_assign(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + expected = Coordinates(coords={"x": [0, 1, 2], "y": [3, 4]}) + + actual = coords.assign(y=[3, 4]) + assert_identical(actual, expected) + + actual = coords.assign({"y": [3, 4]}) + assert_identical(actual, expected) + + def test_copy(self) -> None: + no_index_coords = Coordinates({"foo": ("x", [1, 2, 3])}) + copied = no_index_coords.copy() + assert_identical(no_index_coords, copied) + v0 = no_index_coords.variables["foo"] + v1 = copied.variables["foo"] + assert v0 is not v1 + assert source_ndarray(v0.data) is source_ndarray(v1.data) + + deep_copied = no_index_coords.copy(deep=True) + assert_identical(no_index_coords.to_dataset(), deep_copied.to_dataset()) + v0 = no_index_coords.variables["foo"] + v1 = deep_copied.variables["foo"] + assert v0 is not v1 + assert source_ndarray(v0.data) is not source_ndarray(v1.data) + + def test_align(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + + left = coords + + # test Coordinates._reindex_callback + right = coords.to_dataset().isel(x=[0, 1]).coords + left2, right2 = align(left, right, join="inner") + assert_identical(left2, right2) + + # test Coordinates._overwrite_indexes + right.update({"x": ("x", [4, 5, 6])}) + left2, right2 = align(left, right, join="override") + assert_identical(left2, left) + assert_identical(left2, right2) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index ed18718043b..1c2511427ac 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -193,11 +193,9 @@ def test_binary_op_bitshift(self) -> None: def test_repr(self): expected = dedent( - """\ + f"""\ - {!r}""".format( - self.lazy_var.data - ) + {self.lazy_var.data!r}""" ) assert expected == repr(self.lazy_var) @@ -301,6 +299,17 @@ def test_persist(self): self.assertLazyAndAllClose(u + 1, v) self.assertLazyAndAllClose(u + 1, v2) + def test_tokenize_empty_attrs(self) -> None: + # Issue #6970 + assert self.eager_var._attrs is None + expected = dask.base.tokenize(self.eager_var) + assert self.eager_var.attrs == self.eager_var._attrs == {} + assert ( + expected + == dask.base.tokenize(self.eager_var) + == dask.base.tokenize(self.lazy_var.compute()) + ) + @requires_pint def test_tokenize_duck_dask_array(self): import pint @@ -656,14 +665,12 @@ def test_dataarray_repr(self): nonindex_coord = build_dask_array("coord") a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)}) expected = dedent( - """\ + f"""\ - {!r} + {data!r} Coordinates: y (x) int64 dask.array - Dimensions without coordinates: x""".format( - data - ) + Dimensions without coordinates: x""" ) assert expected == repr(a) assert kernel_call_count == 0 # should not evaluate dask array diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index cee5afa56a4..26537766f4d 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1,6 +1,7 @@ from __future__ import annotations import pickle +import re import sys import warnings from collections.abc import Hashable @@ -13,6 +14,12 @@ import pytest from packaging.version import Version +# remove once numpy 2.0 is the oldest supported version +try: + from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import RankWarning + import xarray as xr from xarray import ( DataArray, @@ -27,9 +34,11 @@ from xarray.convert import from_cdms2 from xarray.core import dtypes from xarray.core.common import full_like +from xarray.core.coordinates import Coordinates from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords from xarray.core.types import QueryEngineOptions, QueryParserOptions from xarray.core.utils import is_scalar +from xarray.testing import _assert_internal_invariants from xarray.tests import ( InaccessibleArray, ReturnItem, @@ -278,7 +287,7 @@ def test_encoding(self) -> None: self.dv.encoding = expected2 assert expected2 is not self.dv.encoding - def test_reset_encoding(self) -> None: + def test_drop_encoding(self) -> None: array = self.mda encoding = {"scale_factor": 10} array.encoding = encoding @@ -287,7 +296,7 @@ def test_reset_encoding(self) -> None: assert array.encoding == encoding assert array["x"].encoding == encoding - actual = array.reset_encoding() + actual = array.drop_encoding() # did not modify in place assert array.encoding == encoding @@ -407,9 +416,6 @@ def test_constructor_invalid(self) -> None: with pytest.raises(ValueError, match=r"conflicting MultiIndex"): DataArray(np.random.rand(4, 4), [("x", self.mindex), ("level_1", range(4))]) - with pytest.raises(ValueError, match=r"matching the dimension size"): - DataArray(data, coords={"x": 0}, dims=["x", "y"]) - def test_constructor_from_self_described(self) -> None: data = [[-0.1, 21], [0, 2]] expected = DataArray( @@ -486,6 +492,32 @@ def test_constructor_dask_coords(self) -> None: expected = DataArray(data, coords={"x": ecoord, "y": ecoord}, dims=["x", "y"]) assert_equal(actual, expected) + def test_constructor_no_default_index(self) -> None: + # explicitly passing a Coordinates object skips the creation of default index + da = DataArray(range(3), coords=Coordinates({"x": [1, 2, 3]}, indexes={})) + assert "x" in da.coords + assert "x" not in da.xindexes + + def test_constructor_multiindex(self) -> None: + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + + da = DataArray(range(4), coords=coords, dims="x") + assert_identical(da.coords, coords) + + def test_constructor_custom_index(self) -> None: + class CustomIndex(Index): + ... + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + da = DataArray(range(3), coords=coords) + assert isinstance(da.xindexes["x"], CustomIndex) + + # test coordinate variables copied + assert da.coords["x"] is not coords.variables["x"] + def test_equals_and_identical(self) -> None: orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") @@ -807,6 +839,27 @@ def get_data(): ) da[dict(x=ind)] = value # should not raise + def test_setitem_vectorized(self) -> None: + # Regression test for GH:7030 + # Positional indexing + v = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + b = xr.DataArray([[0, 0], [1, 0]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + v[index] = w + assert (v[index] == w).all() + + # Indexing with coordinates + v = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + v.coords["b"] = [2, 4, 6] + b = xr.DataArray([[2, 2], [4, 2]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + v.loc[index] = w + assert (v.loc[index] == w).all() + def test_contains(self) -> None: data_array = DataArray([1, 2]) assert 1 in data_array @@ -826,13 +879,14 @@ def test_chunk(self) -> None: assert blocked.chunks == ((3,), (4,)) first_dask_name = blocked.data.name - blocked = unblocked.chunk(chunks=((2, 1), (2, 2))) - assert blocked.chunks == ((2, 1), (2, 2)) - assert blocked.data.name != first_dask_name + with pytest.warns(DeprecationWarning): + blocked = unblocked.chunk(chunks=((2, 1), (2, 2))) # type: ignore + assert blocked.chunks == ((2, 1), (2, 2)) + assert blocked.data.name != first_dask_name - blocked = unblocked.chunk(chunks=(3, 3)) - assert blocked.chunks == ((3,), (3, 1)) - assert blocked.data.name != first_dask_name + blocked = unblocked.chunk(chunks=(3, 3)) + assert blocked.chunks == ((3,), (3, 1)) + assert blocked.data.name != first_dask_name # name doesn't change when rechunking by same amount # this fails if ReprObject doesn't have __dask_tokenize__ defined @@ -1408,7 +1462,7 @@ def test_coords(self) -> None: assert_identical(da, expected) with pytest.raises( - ValueError, match=r"cannot set or update variable.*corrupt.*index " + ValueError, match=r"cannot drop or update coordinate.*corrupt.*index " ): self.mda["level_1"] = ("x", np.arange(4)) self.mda.coords["level_1"] = ("x", np.arange(4)) @@ -1528,7 +1582,7 @@ def test_assign_coords(self) -> None: assert_identical(actual, expected) with pytest.raises( - ValueError, match=r"cannot set or update variable.*corrupt.*index " + ValueError, match=r"cannot drop or update coordinate.*corrupt.*index " ): self.mda.assign_coords(level_1=("x", range(4))) @@ -1543,9 +1597,29 @@ def test_assign_coords(self) -> None: def test_assign_coords_existing_multiindex(self) -> None: data = self.mda - with pytest.warns(FutureWarning, match=r"Updating MultiIndexed coordinate"): + with pytest.warns( + FutureWarning, match=r"updating coordinate.*MultiIndex.*inconsistent" + ): data.assign_coords(x=range(4)) + def test_assign_coords_custom_index(self) -> None: + class CustomIndex(Index): + pass + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + da = xr.DataArray([0, 1, 2], dims="x") + actual = da.assign_coords(coords) + assert isinstance(actual.xindexes["x"], CustomIndex) + + def test_assign_coords_no_default_index(self) -> None: + coords = Coordinates({"y": [1, 2, 3]}, indexes={}) + da = DataArray([1, 2, 3], dims="y") + actual = da.assign_coords(coords) + assert_identical(actual.coords, coords, check_default_indexes=False) + assert "y" not in actual.xindexes + def test_coords_alignment(self) -> None: lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])]) @@ -1563,7 +1637,7 @@ def test_set_coords_update_index(self) -> None: def test_set_coords_multiindex_level(self) -> None: with pytest.raises( - ValueError, match=r"cannot set or update variable.*corrupt.*index " + ValueError, match=r"cannot drop or update coordinate.*corrupt.*index " ): self.mda["level_1"] = range(4) @@ -1810,6 +1884,16 @@ def test_rename_dimension_coord_warnings(self) -> None: ): da.rename(x="y") + # No operation should not raise a warning + da = xr.DataArray( + data=np.ones((2, 3)), + dims=["x", "y"], + coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + da.rename(x="x") + def test_init_value(self) -> None: expected = DataArray( np.full((3, 4), 3), dims=["x", "y"], coords=[range(3), range(4)] @@ -2445,6 +2529,19 @@ def test_unstack_pandas_consistency(self) -> None: actual = DataArray(s, dims="z").unstack("z") assert_identical(expected, actual) + @pytest.mark.filterwarnings("error") + def test_unstack_roundtrip_integer_array(self) -> None: + arr = xr.DataArray( + np.arange(6).reshape(2, 3), + coords={"x": ["a", "b"], "y": [0, 1, 2]}, + dims=["x", "y"], + ) + + stacked = arr.stack(z=["x", "y"]) + roundtripped = stacked.unstack() + + assert_identical(arr, roundtripped) + def test_stack_nonunique_consistency(self, da) -> None: da = da.isel(time=0, drop=True) # 2D actual = da.stack(z=["a", "x"]) @@ -2631,6 +2728,14 @@ def test_where_lambda(self) -> None: actual = arr.where(lambda x: x.y < 2, drop=True) assert_identical(actual, expected) + def test_where_other_lambda(self) -> None: + arr = DataArray(np.arange(4), dims="y") + expected = xr.concat( + [arr.sel(y=slice(2)), arr.sel(y=slice(2, None)) + 1], dim="y" + ) + actual = arr.where(lambda x: x.y < 2, lambda x: x + 1) + assert_identical(actual, expected) + def test_where_string(self) -> None: array = DataArray(["a", "b"]) expected = DataArray(np.array(["a", np.nan], dtype=object)) @@ -2785,7 +2890,7 @@ def test_reduce_out(self) -> None: ) def test_quantile(self, q, axis, dim, skipna) -> None: va = self.va.copy(deep=True) - va[0, 0] = np.NaN + va[0, 0] = np.nan actual = DataArray(va).quantile(q, dim=dim, keep_attrs=True, skipna=skipna) _percentile_func = np.nanpercentile if skipna in (True, None) else np.percentile @@ -2803,10 +2908,7 @@ def test_quantile_method(self, method) -> None: q = [0.25, 0.5, 0.75] actual = DataArray(self.va).quantile(q, method=method) - if Version(np.__version__) >= Version("1.22.0"): - expected = np.nanquantile(self.dv.values, np.array(q), method=method) - else: - expected = np.nanquantile(self.dv.values, np.array(q), interpolation=method) + expected = np.nanquantile(self.dv.values, np.array(q), method=method) np.testing.assert_allclose(actual.values, expected) @@ -3066,10 +3168,10 @@ def test_align_str_dtype(self) -> None: b = DataArray([1, 2], dims=["x"], coords={"x": ["b", "c"]}) expected_a = DataArray( - [0, 1, np.NaN], dims=["x"], coords={"x": ["a", "b", "c"]} + [0, 1, np.nan], dims=["x"], coords={"x": ["a", "b", "c"]} ) expected_b = DataArray( - [np.NaN, 1, 2], dims=["x"], coords={"x": ["a", "b", "c"]} + [np.nan, 1, 2], dims=["x"], coords={"x": ["a", "b", "c"]} ) actual_a, actual_b = xr.align(a, b, join="outer") @@ -3926,7 +4028,7 @@ def test_dot(self) -> None: assert_equal(expected5, actual5) with pytest.raises(NotImplementedError): - da.dot(dm3.to_dataset(name="dm")) # type: ignore + da.dot(dm3.to_dataset(name="dm")) with pytest.raises(TypeError): da.dot(dm3.values) # type: ignore @@ -4153,7 +4255,7 @@ def test_polyfit(self, use_dask, use_datetime) -> None: # Full output and deficient rank with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.RankWarning) + warnings.simplefilter("ignore", RankWarning) out = da.polyfit("x", 12, full=True) assert out.polyfit_residuals.isnull().all() @@ -4174,7 +4276,7 @@ def test_polyfit(self, use_dask, use_datetime) -> None: np.testing.assert_almost_equal(out.polyfit_residuals, [0, 0]) with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.RankWarning) + warnings.simplefilter("ignore", RankWarning) out = da.polyfit("x", 8, full=True) np.testing.assert_array_equal(out.polyfit_residuals.isnull(), [True, False]) @@ -4195,7 +4297,7 @@ def test_pad_constant(self) -> None: ar = xr.DataArray([9], dims="x") actual = ar.pad(x=1) - expected = xr.DataArray([np.NaN, 9, np.NaN], dims="x") + expected = xr.DataArray([np.nan, 9, np.nan], dims="x") assert_identical(actual, expected) actual = ar.pad(x=1, constant_values=1.23456) @@ -4203,7 +4305,7 @@ def test_pad_constant(self) -> None: assert_identical(actual, expected) with pytest.raises(ValueError, match="cannot convert float NaN to integer"): - ar.pad(x=1, constant_values=np.NaN) + ar.pad(x=1, constant_values=np.nan) def test_pad_coords(self) -> None: ar = DataArray( @@ -4641,10 +4743,10 @@ def setup(self): np.array([0.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0]), 5, 2, None, id="float" ), pytest.param( - np.array([1.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0]), 5, 2, 1, id="nan" + np.array([1.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0]), 5, 2, 1, id="nan" ), pytest.param( - np.array([1.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0]).astype("object"), + np.array([1.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0]).astype("object"), 5, 2, 1, @@ -4653,7 +4755,7 @@ def setup(self): ), id="obj", ), - pytest.param(np.array([np.NaN, np.NaN]), np.NaN, np.NaN, 0, id="allnan"), + pytest.param(np.array([np.nan, np.nan]), np.nan, np.nan, 0, id="allnan"), pytest.param( np.array( ["2015-12-31", "2020-01-02", "2020-01-01", "2016-01-01"], @@ -4829,8 +4931,10 @@ def test_idxmin( else: ar0 = ar0_raw - # dim doesn't exist - with pytest.raises(KeyError): + with pytest.raises( + KeyError, + match=r"'spam' not found in array dimensions", + ): ar0.idxmin(dim="spam") # Scalar Dataarray @@ -4846,7 +4950,7 @@ def test_idxmin( if hasna: coordarr1[...] = 1 - fill_value_0 = np.NaN + fill_value_0 = np.nan else: fill_value_0 = 1 @@ -4860,7 +4964,7 @@ def test_idxmin( assert_identical(result0, expected0) # Manually specify NaN fill_value - result1 = ar0.idxmin(fill_value=np.NaN) + result1 = ar0.idxmin(fill_value=np.nan) assert_identical(result1, expected0) # keep_attrs @@ -4942,8 +5046,10 @@ def test_idxmax( else: ar0 = ar0_raw - # dim doesn't exist - with pytest.raises(KeyError): + with pytest.raises( + KeyError, + match=r"'spam' not found in array dimensions", + ): ar0.idxmax(dim="spam") # Scalar Dataarray @@ -4959,7 +5065,7 @@ def test_idxmax( if hasna: coordarr1[...] = 1 - fill_value_0 = np.NaN + fill_value_0 = np.nan else: fill_value_0 = 1 @@ -4973,7 +5079,7 @@ def test_idxmax( assert_identical(result0, expected0) # Manually specify NaN fill_value - result1 = ar0.idxmax(fill_value=np.NaN) + result1 = ar0.idxmax(fill_value=np.nan) assert_identical(result1, expected0) # keep_attrs @@ -5138,12 +5244,12 @@ def test_argmax_dim( np.array( [ [2.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0], - [-4.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0], - [np.NaN] * 7, + [-4.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0], + [np.nan] * 7, ] ), - [5, 0, np.NaN], - [0, 2, np.NaN], + [5, 0, np.nan], + [0, 2, np.nan], [None, 1, 0], id="nan", ), @@ -5151,12 +5257,12 @@ def test_argmax_dim( np.array( [ [2.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0], - [-4.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0], - [np.NaN] * 7, + [-4.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0], + [np.nan] * 7, ] ).astype("object"), - [5, 0, np.NaN], - [0, 2, np.NaN], + [5, 0, np.nan], + [0, 2, np.nan], [None, 1, 0], marks=pytest.mark.filterwarnings( "ignore:invalid value encountered in reduce:RuntimeWarning:" @@ -5431,7 +5537,7 @@ def test_idxmin( coordarr1[hasna, :] = 1 minindex0 = [x if not np.isnan(x) else 0 for x in minindex] - nan_mult_0 = np.array([np.NaN if x else 1 for x in hasna])[:, None] + nan_mult_0 = np.array([np.nan if x else 1 for x in hasna])[:, None] expected0list = [ (coordarr1 * nan_mult_0).isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(minindex0) @@ -5446,7 +5552,7 @@ def test_idxmin( # Manually specify NaN fill_value with raise_if_dask_computes(max_computes=max_computes): - result1 = ar0.idxmin(dim="x", fill_value=np.NaN) + result1 = ar0.idxmin(dim="x", fill_value=np.nan) assert_identical(result1, expected0) # keep_attrs @@ -5573,7 +5679,7 @@ def test_idxmax( coordarr1[hasna, :] = 1 maxindex0 = [x if not np.isnan(x) else 0 for x in maxindex] - nan_mult_0 = np.array([np.NaN if x else 1 for x in hasna])[:, None] + nan_mult_0 = np.array([np.nan if x else 1 for x in hasna])[:, None] expected0list = [ (coordarr1 * nan_mult_0).isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(maxindex0) @@ -5588,7 +5694,7 @@ def test_idxmax( # Manually specify NaN fill_value with raise_if_dask_computes(max_computes=max_computes): - result1 = ar0.idxmax(dim="x", fill_value=np.NaN) + result1 = ar0.idxmax(dim="x", fill_value=np.nan) assert_identical(result1, expected0) # keep_attrs @@ -5847,31 +5953,31 @@ def test_argmax_dim( np.array( [ [[2.0, 1.0, 2.0, 0.0], [-2.0, -4.0, 2.0, 0.0]], - [[-4.0, np.NaN, 2.0, np.NaN], [-2.0, -4.0, 2.0, 0.0]], - [[np.NaN] * 4, [np.NaN] * 4], + [[-4.0, np.nan, 2.0, np.nan], [-2.0, -4.0, 2.0, 0.0]], + [[np.nan] * 4, [np.nan] * 4], ] ), {"x": np.array([[1, 0, 0, 0], [0, 0, 0, 0]])}, { "y": np.array( - [[1, 1, 0, 0], [0, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + [[1, 1, 0, 0], [0, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] ) }, - {"z": np.array([[3, 1], [0, 1], [np.NaN, np.NaN]])}, + {"z": np.array([[3, 1], [0, 1], [np.nan, np.nan]])}, {"x": np.array([1, 0, 0, 0]), "y": np.array([0, 1, 0, 0])}, {"x": np.array([1, 0]), "z": np.array([0, 1])}, - {"y": np.array([1, 0, np.NaN]), "z": np.array([1, 0, np.NaN])}, + {"y": np.array([1, 0, np.nan]), "z": np.array([1, 0, np.nan])}, {"x": np.array(0), "y": np.array(1), "z": np.array(1)}, {"x": np.array([[0, 0, 0, 0], [0, 0, 0, 0]])}, { "y": np.array( - [[0, 0, 0, 0], [1, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + [[0, 0, 0, 0], [1, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] ) }, - {"z": np.array([[0, 2], [2, 2], [np.NaN, np.NaN]])}, + {"z": np.array([[0, 2], [2, 2], [np.nan, np.nan]])}, {"x": np.array([0, 0, 0, 0]), "y": np.array([0, 0, 0, 0])}, {"x": np.array([0, 0]), "z": np.array([2, 2])}, - {"y": np.array([0, 0, np.NaN]), "z": np.array([0, 2, np.NaN])}, + {"y": np.array([0, 0, np.nan]), "z": np.array([0, 2, np.nan])}, {"x": np.array(0), "y": np.array(0), "z": np.array(0)}, {"x": np.array([[2, 1, 2, 1], [2, 2, 2, 2]])}, { @@ -5890,31 +5996,31 @@ def test_argmax_dim( np.array( [ [[2.0, 1.0, 2.0, 0.0], [-2.0, -4.0, 2.0, 0.0]], - [[-4.0, np.NaN, 2.0, np.NaN], [-2.0, -4.0, 2.0, 0.0]], - [[np.NaN] * 4, [np.NaN] * 4], + [[-4.0, np.nan, 2.0, np.nan], [-2.0, -4.0, 2.0, 0.0]], + [[np.nan] * 4, [np.nan] * 4], ] ).astype("object"), {"x": np.array([[1, 0, 0, 0], [0, 0, 0, 0]])}, { "y": np.array( - [[1, 1, 0, 0], [0, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + [[1, 1, 0, 0], [0, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] ) }, - {"z": np.array([[3, 1], [0, 1], [np.NaN, np.NaN]])}, + {"z": np.array([[3, 1], [0, 1], [np.nan, np.nan]])}, {"x": np.array([1, 0, 0, 0]), "y": np.array([0, 1, 0, 0])}, {"x": np.array([1, 0]), "z": np.array([0, 1])}, - {"y": np.array([1, 0, np.NaN]), "z": np.array([1, 0, np.NaN])}, + {"y": np.array([1, 0, np.nan]), "z": np.array([1, 0, np.nan])}, {"x": np.array(0), "y": np.array(1), "z": np.array(1)}, {"x": np.array([[0, 0, 0, 0], [0, 0, 0, 0]])}, { "y": np.array( - [[0, 0, 0, 0], [1, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + [[0, 0, 0, 0], [1, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] ) }, - {"z": np.array([[0, 2], [2, 2], [np.NaN, np.NaN]])}, + {"z": np.array([[0, 2], [2, 2], [np.nan, np.nan]])}, {"x": np.array([0, 0, 0, 0]), "y": np.array([0, 0, 0, 0])}, {"x": np.array([0, 0]), "z": np.array([2, 2])}, - {"y": np.array([0, 0, np.NaN]), "z": np.array([0, 2, np.NaN])}, + {"y": np.array([0, 0, np.nan]), "z": np.array([0, 2, np.nan])}, {"x": np.array(0), "y": np.array(0), "z": np.array(0)}, {"x": np.array([[2, 1, 2, 1], [2, 2, 2, 2]])}, { @@ -6460,12 +6566,12 @@ def test_isin(da) -> None: def test_raise_no_warning_for_nan_in_binary_ops() -> None: with assert_no_warnings(): - xr.DataArray([1, 2, np.NaN]) > 0 + xr.DataArray([1, 2, np.nan]) > 0 @pytest.mark.filterwarnings("error") def test_no_warning_for_all_nan() -> None: - _ = xr.DataArray([np.NaN, np.NaN]).mean() + _ = xr.DataArray([np.nan, np.nan]).mean() def test_name_in_masking() -> None: @@ -6505,7 +6611,7 @@ def test_to_and_from_iris(self) -> None: ) # Set a bad value to test the masking logic - original.data[0, 2] = np.NaN + original.data[0, 2] = np.nan original.attrs["cell_methods"] = "height: mean (comment: A cell method)" actual = original.to_iris() @@ -6897,7 +7003,12 @@ def test_drop_duplicates_1d(self, keep) -> None: result = da.drop_duplicates("time", keep=keep) assert_equal(expected, result) - with pytest.raises(ValueError, match="['space'] not found"): + with pytest.raises( + ValueError, + match=re.escape( + "Dimensions ('space',) not found in data dimensions ('time',)" + ), + ): da.drop_duplicates("space", keep=keep) def test_drop_duplicates_2d(self) -> None: @@ -7018,3 +7129,35 @@ def test_error_on_ellipsis_without_list(self) -> None: da = DataArray([[1, 2], [1, 2]], dims=("x", "y")) with pytest.raises(ValueError): da.stack(flat=...) # type: ignore + + +def test_nD_coord_dataarray() -> None: + # should succeed + da = DataArray( + np.ones((2, 4)), + dims=("x", "y"), + coords={ + "x": (("x", "y"), np.arange(8).reshape((2, 4))), + "y": ("y", np.arange(4)), + }, + ) + _assert_internal_invariants(da, check_default_indexes=True) + + da2 = DataArray(np.ones(4), dims=("y"), coords={"y": ("y", np.arange(4))}) + da3 = DataArray(np.ones(4), dims=("z")) + + _, actual = xr.align(da, da2) + assert_identical(da2, actual) + + expected = da.drop_vars("x") + _, actual = xr.broadcast(da, da2) + assert_identical(expected, actual) + + actual, _ = xr.broadcast(da, da3) + expected = da.expand_dims(z=4, axis=-1) + assert_identical(actual, expected) + + da4 = DataArray(np.ones((2, 4)), coords={"x": 0}, dims=["x", "y"]) + _assert_internal_invariants(da4, check_default_indexes=True) + assert "x" not in da4.xindexes + assert "x" in da4.coords diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f7f91d0e134..687aae8f1dc 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -15,6 +15,12 @@ import pytest from pandas.core.indexes.datetimes import DatetimeIndex +# remove once numpy 2.0 is the oldest supported version +try: + from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import RankWarning + import xarray as xr from xarray import ( DataArray, @@ -31,7 +37,7 @@ from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import dtypes, indexing, utils from xarray.core.common import duck_array_ops, full_like -from xarray.core.coordinates import DatasetCoordinates +from xarray.core.coordinates import Coordinates, DatasetCoordinates from xarray.core.indexes import Index, PandasIndex from xarray.core.pycompat import array_type, integer_types from xarray.core.utils import is_scalar @@ -118,7 +124,7 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: ), "unicode_var": xr.DataArray( unicode_var, coords=[time1], dims=["time"] - ).astype(np.unicode_), + ).astype(np.str_), "datetime_var": xr.DataArray( datetime_var, coords=[time1], dims=["time"] ), @@ -141,7 +147,7 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: ), "unicode_var": xr.DataArray( unicode_var[:nt2], coords=[time2], dims=["time"] - ).astype(np.unicode_), + ).astype(np.str_), "datetime_var": xr.DataArray( datetime_var_to_append, coords=[time2], dims=["time"] ), @@ -199,7 +205,7 @@ def create_test_multiindex() -> Dataset: mindex = pd.MultiIndex.from_product( [["a", "b"], [1, 2]], names=("level_1", "level_2") ) - return Dataset({}, {"x": mindex}) + return Dataset({}, Coordinates.from_pandas_multiindex(mindex, "x")) def create_test_stacked_array() -> tuple[DataArray, DataArray]: @@ -353,10 +359,11 @@ def test_repr_multiindex(self) -> None: assert expected == actual # verify that long level names are not truncated - mindex = pd.MultiIndex.from_product( + midx = pd.MultiIndex.from_product( [["a", "b"], [1, 2]], names=("a_quite_long_level_name", "level_2") ) - data = Dataset({}, {"x": mindex}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + data = Dataset({}, midx_coords) expected = dedent( """\ @@ -374,7 +381,7 @@ def test_repr_multiindex(self) -> None: def test_repr_period_index(self) -> None: data = create_test_data(seed=456) - data.coords["time"] = pd.period_range("2000-01-01", periods=20, freq="B") + data.coords["time"] = pd.period_range("2000-01-01", periods=20, freq="D") # check that creating the repr doesn't raise an error #GH645 repr(data) @@ -404,10 +411,14 @@ def test_repr_nep18(self) -> None: class Array: def __init__(self): self.shape = (2,) + self.ndim = 1 self.dtype = np.dtype(np.float64) def __array_function__(self, *args, **kwargs): - pass + return NotImplemented + + def __array_ufunc__(self, *args, **kwargs): + return NotImplemented def __repr__(self): return "Custom\nArray" @@ -631,8 +642,50 @@ def test_constructor_with_coords(self) -> None: [["a", "b"], [1, 2]], names=("level_1", "level_2") ) with pytest.raises(ValueError, match=r"conflicting MultiIndex"): - Dataset({}, {"x": mindex, "y": mindex}) - Dataset({}, {"x": mindex, "level_1": range(4)}) + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + Dataset({}, {"x": mindex, "y": mindex}) + Dataset({}, {"x": mindex, "level_1": range(4)}) + + def test_constructor_no_default_index(self) -> None: + # explicitly passing a Coordinates object skips the creation of default index + ds = Dataset(coords=Coordinates({"x": [1, 2, 3]}, indexes={})) + assert "x" in ds + assert "x" not in ds.xindexes + + def test_constructor_multiindex(self) -> None: + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + + ds = Dataset(coords=coords) + assert_identical(ds, coords.to_dataset()) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + Dataset(data_vars={"x": midx}) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + Dataset(coords={"x": midx}) + + def test_constructor_custom_index(self) -> None: + class CustomIndex(Index): + ... + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + ds = Dataset(coords=coords) + assert isinstance(ds.xindexes["x"], CustomIndex) + + # test coordinate variables copied + assert ds.variables["x"] is not coords.variables["x"] def test_properties(self) -> None: ds = create_test_data() @@ -642,7 +695,7 @@ def test_properties(self) -> None: # change them inadvertently: assert isinstance(ds.dims, utils.Frozen) assert isinstance(ds.dims.mapping, dict) - assert type(ds.dims.mapping) is dict + assert type(ds.dims.mapping) is dict # noqa: E721 assert ds.dims == {"dim1": 8, "dim2": 9, "dim3": 10, "time": 20} assert ds.sizes == ds.dims @@ -841,7 +894,7 @@ def test_coords_modify(self) -> None: assert_array_equal(actual["z"], ["a", "b"]) actual = data.copy(deep=True) - with pytest.raises(ValueError, match=r"conflicting sizes"): + with pytest.raises(ValueError, match=r"conflicting dimension sizes"): actual.coords["x"] = ("x", [-1]) assert_identical(actual, data) # should not be modified @@ -878,9 +931,7 @@ def test_coords_setitem_with_new_dimension(self) -> None: def test_coords_setitem_multiindex(self) -> None: data = create_test_multiindex() - with pytest.raises( - ValueError, match=r"cannot set or update variable.*corrupt.*index " - ): + with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "): data.coords["level_1"] = range(4) def test_coords_set(self) -> None: @@ -1010,6 +1061,17 @@ def test_data_vars_properties(self) -> None: "bar": np.dtype("float64"), } + # len + ds.coords["x"] = [1] + assert len(ds.data_vars) == 2 + + # https://github.com/pydata/xarray/issues/7588 + with pytest.raises( + AssertionError, match="something is wrong with Dataset._coord_names" + ): + ds._coord_names = {"w", "x", "y", "z"} + len(ds.data_vars) + def test_equals_and_identical(self) -> None: data = create_test_data(seed=42) assert data.equals(data) @@ -1111,7 +1173,12 @@ def get_dask_names(ds): for k, v in new_dask_names.items(): assert v == orig_dask_names[k] - with pytest.raises(ValueError, match=r"some chunks"): + with pytest.raises( + ValueError, + match=re.escape( + "chunks keys ('foo',) not found in data dimensions ('dim2', 'dim3', 'time', 'dim1')" + ), + ): data.chunk({"foo": 10}) @requires_dask @@ -1557,9 +1624,11 @@ def test_sel_dataarray(self) -> None: def test_sel_dataarray_mindex(self) -> None: midx = pd.MultiIndex.from_product([list("abc"), [0, 1]], names=("one", "two")) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + midx_coords["y"] = range(3) + mds = xr.Dataset( - {"var": (("x", "y"), np.random.rand(6, 3))}, - coords={"x": midx, "y": range(3)}, + {"var": (("x", "y"), np.random.rand(6, 3))}, coords=midx_coords ) actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims="x")) @@ -1677,7 +1746,8 @@ def test_sel_drop(self) -> None: def test_sel_drop_mindex(self) -> None: midx = pd.MultiIndex.from_arrays([["a", "a"], [1, 2]], names=("foo", "bar")) - data = Dataset(coords={"x": midx}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + data = Dataset(coords=midx_coords) actual = data.sel(foo="a", drop=True) assert "foo" not in actual.coords @@ -1898,10 +1968,11 @@ def test_loc(self) -> None: data.loc["a"] # type: ignore[index] def test_selection_multiindex(self) -> None: - mindex = pd.MultiIndex.from_product( + midx = pd.MultiIndex.from_product( [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") ) - mdata = Dataset(data_vars={"var": ("x", range(8))}, coords={"x": mindex}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + mdata = Dataset(data_vars={"var": ("x", range(8))}, coords=midx_coords) def test_sel( lab_indexer, pos_indexer, replaced_idx=False, renamed_dim=None @@ -2261,9 +2332,9 @@ def test_align(self) -> None: assert np.isnan(left2["var3"][-2:]).all() with pytest.raises(ValueError, match=r"invalid value for join"): - align(left, right, join="foobar") # type: ignore[arg-type] + align(left, right, join="foobar") # type: ignore[call-overload] with pytest.raises(TypeError): - align(left, right, foo="bar") # type: ignore[call-arg] + align(left, right, foo="bar") # type: ignore[call-overload] def test_align_exact(self) -> None: left = xr.Dataset(coords={"x": [0, 1]}) @@ -2380,10 +2451,10 @@ def test_align_str_dtype(self) -> None: b = Dataset({"foo": ("x", [1, 2])}, coords={"x": ["b", "c"]}) expected_a = Dataset( - {"foo": ("x", [0, 1, np.NaN])}, coords={"x": ["a", "b", "c"]} + {"foo": ("x", [0, 1, np.nan])}, coords={"x": ["a", "b", "c"]} ) expected_b = Dataset( - {"foo": ("x", [np.NaN, 1, 2])}, coords={"x": ["a", "b", "c"]} + {"foo": ("x", [np.nan, 1, 2])}, coords={"x": ["a", "b", "c"]} ) actual_a, actual_b = xr.align(a, b, join="outer") @@ -2733,7 +2804,10 @@ def test_drop_indexes(self) -> None: assert type(actual.x.variable) is Variable assert type(actual.y.variable) is Variable - with pytest.raises(ValueError, match="those coordinates don't exist"): + with pytest.raises( + ValueError, + match=r"The coordinates \('not_a_coord',\) are not found in the dataset coordinates", + ): ds.drop_indexes("not_a_coord") with pytest.raises(ValueError, match="those coordinates do not have an index"): @@ -2743,8 +2817,9 @@ def test_drop_indexes(self) -> None: assert_identical(actual, ds) # test index corrupted - mindex = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) - ds = Dataset(coords={"x": mindex}) + midx = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + ds = Dataset(coords=midx_coords) with pytest.raises(ValueError, match=".*would corrupt the following index.*"): ds.drop_indexes("a") @@ -2884,7 +2959,7 @@ def test_copy_with_data_errors(self) -> None: with pytest.raises(ValueError, match=r"contain all variables in original"): orig.copy(data={"var1": new_var1}) - def test_reset_encoding(self) -> None: + def test_drop_encoding(self) -> None: orig = create_test_data() vencoding = {"scale_factor": 10} orig.encoding = {"foo": "bar"} @@ -2892,7 +2967,7 @@ def test_reset_encoding(self) -> None: for k, v in orig.variables.items(): orig[k].encoding = vencoding - actual = orig.reset_encoding() + actual = orig.drop_encoding() assert actual.encoding == {} for k, v in actual.variables.items(): assert v.encoding == {} @@ -2957,8 +3032,7 @@ def test_rename_old_name(self) -> None: def test_rename_same_name(self) -> None: data = create_test_data() newnames = {"var1": "var1", "dim2": "dim2"} - with pytest.warns(UserWarning, match="does not create an index anymore"): - renamed = data.rename(newnames) + renamed = data.rename(newnames) assert_identical(renamed, data) def test_rename_dims(self) -> None: @@ -3028,10 +3102,23 @@ def test_rename_dimension_coord_warnings(self) -> None: ): ds.rename(x="y") + # No operation should not raise a warning + ds = Dataset( + data_vars={"data": (("x", "y"), np.ones((2, 3)))}, + coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + ds.rename(x="x") + def test_rename_multiindex(self) -> None: - mindex = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) - original = Dataset({}, {"x": mindex}) - expected = Dataset({}, {"x": mindex.rename(["a", "c"])}) + midx = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + original = Dataset({}, midx_coords) + + midx_renamed = midx.rename(["a", "c"]) + midx_coords_renamed = Coordinates.from_pandas_multiindex(midx_renamed, "x") + expected = Dataset({}, midx_coords_renamed) actual = original.rename({"b": "c"}) assert_identical(expected, actual) @@ -3141,9 +3228,14 @@ def test_swap_dims(self) -> None: assert_identical(expected, actual) # handle multiindex case - idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) - original = Dataset({"x": [1, 2, 3], "y": ("x", idx), "z": 42}) - expected = Dataset({"z": 42}, {"x": ("y", [1, 2, 3]), "y": idx}) + midx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) + + original = Dataset({"x": [1, 2, 3], "y": ("x", midx), "z": 42}) + + midx_coords = Coordinates.from_pandas_multiindex(midx, "y") + midx_coords["x"] = ("y", [1, 2, 3]) + expected = Dataset({"z": 42}, midx_coords) + actual = original.swap_dims({"x": "y"}) assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) @@ -3361,6 +3453,12 @@ def test_set_index(self) -> None: with pytest.raises(ValueError, match=r"dimension mismatch.*"): ds.set_index(y="x_var") + ds = Dataset(coords={"x": 1}) + with pytest.raises( + ValueError, match=r".*cannot set a PandasIndex.*scalar variable.*" + ): + ds.set_index(x="x") + def test_set_index_deindexed_coords(self) -> None: # test de-indexed coordinates are converted to base variable # https://github.com/pydata/xarray/issues/6969 @@ -3369,16 +3467,20 @@ def test_set_index_deindexed_coords(self) -> None: three = ["c", "c", "d", "d"] four = [3, 4, 3, 4] - mindex_12 = pd.MultiIndex.from_arrays([one, two], names=["one", "two"]) - mindex_34 = pd.MultiIndex.from_arrays([three, four], names=["three", "four"]) + midx_12 = pd.MultiIndex.from_arrays([one, two], names=["one", "two"]) + midx_34 = pd.MultiIndex.from_arrays([three, four], names=["three", "four"]) - ds = xr.Dataset( - coords={"x": mindex_12, "three": ("x", three), "four": ("x", four)} - ) + coords = Coordinates.from_pandas_multiindex(midx_12, "x") + coords["three"] = ("x", three) + coords["four"] = ("x", four) + ds = xr.Dataset(coords=coords) actual = ds.set_index(x=["three", "four"]) - expected = xr.Dataset( - coords={"x": mindex_34, "one": ("x", one), "two": ("x", two)} - ) + + coords_expected = Coordinates.from_pandas_multiindex(midx_34, "x") + coords_expected["one"] = ("x", one) + coords_expected["two"] = ("x", two) + expected = xr.Dataset(coords=coords_expected) + assert_identical(actual, expected) def test_reset_index(self) -> None: @@ -3409,7 +3511,7 @@ def test_reset_index_drop_dims(self) -> None: assert len(reset.dims) == 0 @pytest.mark.parametrize( - "arg,drop,dropped,converted,renamed", + ["arg", "drop", "dropped", "converted", "renamed"], [ ("foo", False, [], [], {"bar": "x"}), ("foo", True, ["foo"], [], {"bar": "x"}), @@ -3422,14 +3524,20 @@ def test_reset_index_drop_dims(self) -> None: ], ) def test_reset_index_drop_convert( - self, arg, drop, dropped, converted, renamed + self, + arg: str | list[str], + drop: bool, + dropped: list[str], + converted: list[str], + renamed: dict[str, str], ) -> None: # regressions https://github.com/pydata/xarray/issues/6946 and # https://github.com/pydata/xarray/issues/6989 # check that multi-index dimension or level coordinates are dropped, converted # from IndexVariable to Variable or renamed to dimension as expected midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("foo", "bar")) - ds = xr.Dataset(coords={"x": midx}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + ds = xr.Dataset(coords=midx_coords) reset = ds.reset_index(arg, drop=drop) for name in dropped: @@ -3443,7 +3551,8 @@ def test_reorder_levels(self) -> None: ds = create_test_multiindex() mindex = ds["x"].to_index() midx = mindex.reorder_levels(["level_2", "level_1"]) - expected = Dataset({}, coords={"x": midx}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + expected = Dataset({}, coords=midx_coords) # check attrs propagated ds["level_1"].attrs["foo"] = "bar" @@ -3508,10 +3617,12 @@ def test_stack(self) -> None: coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, ) - exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) + midx_expected = pd.MultiIndex.from_product( + [[0, 1], ["a", "b"]], names=["x", "y"] + ) + midx_coords_expected = Coordinates.from_pandas_multiindex(midx_expected, "z") expected = Dataset( - data_vars={"b": ("z", [0, 1, 2, 3])}, - coords={"z": exp_index}, + data_vars={"b": ("z", [0, 1, 2, 3])}, coords=midx_coords_expected ) # check attrs propagated ds["x"].attrs["foo"] = "bar" @@ -3532,10 +3643,12 @@ def test_stack(self) -> None: actual = ds.stack(z=[..., "y"]) assert_identical(expected, actual) - exp_index = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=["y", "x"]) + midx_expected = pd.MultiIndex.from_product( + [["a", "b"], [0, 1]], names=["y", "x"] + ) + midx_coords_expected = Coordinates.from_pandas_multiindex(midx_expected, "z") expected = Dataset( - data_vars={"b": ("z", [0, 2, 1, 3])}, - coords={"z": exp_index}, + data_vars={"b": ("z", [0, 2, 1, 3])}, coords=midx_coords_expected ) expected["x"].attrs["foo"] = "bar" @@ -3566,9 +3679,11 @@ def test_stack_create_index(self, create_index, expected_keys) -> None: def test_stack_multi_index(self) -> None: # multi-index on a dimension to stack is discarded too midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("lvl1", "lvl2")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + coords["y"] = [0, 1] ds = xr.Dataset( data_vars={"b": (("x", "y"), [[0, 1], [2, 3], [4, 5], [6, 7]])}, - coords={"x": midx, "y": [0, 1]}, + coords=coords, ) expected = Dataset( data_vars={"b": ("z", [0, 1, 2, 3, 4, 5, 6, 7])}, @@ -3593,10 +3708,8 @@ def test_stack_non_dim_coords(self) -> None: ).rename_vars(x="xx") exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["xx", "y"]) - expected = Dataset( - data_vars={"b": ("z", [0, 1, 2, 3])}, - coords={"z": exp_index}, - ) + exp_coords = Coordinates.from_pandas_multiindex(exp_index, "z") + expected = Dataset(data_vars={"b": ("z", [0, 1, 2, 3])}, coords=exp_coords) actual = ds.stack(z=["x", "y"]) assert_identical(expected, actual) @@ -3604,7 +3717,8 @@ def test_stack_non_dim_coords(self) -> None: def test_unstack(self) -> None: index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) - ds = Dataset(data_vars={"b": ("z", [0, 1, 2, 3])}, coords={"z": index}) + coords = Coordinates.from_pandas_multiindex(index, "z") + ds = Dataset(data_vars={"b": ("z", [0, 1, 2, 3])}, coords=coords) expected = Dataset( {"b": (("x", "y"), [[0, 1], [2, 3]]), "x": [0, 1], "y": ["a", "b"]} ) @@ -3619,7 +3733,10 @@ def test_unstack(self) -> None: def test_unstack_errors(self) -> None: ds = Dataset({"x": [1, 2, 3]}) - with pytest.raises(ValueError, match=r"does not contain the dimensions"): + with pytest.raises( + ValueError, + match=re.escape("Dimensions ('foo',) not found in data dimensions ('x',)"), + ): ds.unstack("foo") with pytest.raises(ValueError, match=r".*do not have exactly one multi-index"): ds.unstack("x") @@ -3666,12 +3783,12 @@ def test_unstack_sparse(self) -> None: assert actual2.variable._to_dense().equals(expected2.variable) assert actual2.data.density < 1.0 - mindex = pd.MultiIndex.from_arrays( - [np.arange(3), np.arange(3)], names=["a", "b"] - ) + midx = pd.MultiIndex.from_arrays([np.arange(3), np.arange(3)], names=["a", "b"]) + coords = Coordinates.from_pandas_multiindex(midx, "z") + coords["foo"] = np.arange(4) + coords["bar"] = np.arange(5) ds_eye = Dataset( - {"var": (("z", "foo", "bar"), np.ones((3, 4, 5)))}, - coords={"z": mindex, "foo": np.arange(4), "bar": np.arange(5)}, + {"var": (("z", "foo", "bar"), np.ones((3, 4, 5)))}, coords=coords ) actual3 = ds_eye.unstack(sparse=True, fill_value=0) assert isinstance(actual3["var"].data, sparse_array_type) @@ -3728,7 +3845,10 @@ def test_to_stacked_array_invalid_sample_dims(self) -> None: data_vars={"a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), "b": ("x", [6, 7])}, coords={"y": ["u", "v", "w"]}, ) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=r"Variables in the dataset must contain all ``sample_dims`` \(\['y'\]\) but 'b' misses \['y'\]", + ): data.to_stacked_array("features", sample_dims=["y"]) def test_to_stacked_array_name(self) -> None: @@ -4008,7 +4128,8 @@ def test_setitem(self) -> None: data4[{"dim2": [2, 3]}] = data3["var1"][{"dim2": [3, 4]}].values data5 = data4.astype(str) data5["var4"] = data4["var1"] - err_msg = "could not convert string to float: 'a'" + # convert to `np.str_('a')` once `numpy<2.0` has been dropped + err_msg = "could not convert string to float: .*'a'.*" with pytest.raises(ValueError, match=err_msg): data5[{"dim2": 1}] = "a" @@ -4129,6 +4250,29 @@ def test_setitem_align_new_indexes(self) -> None: ) assert_identical(ds, expected) + def test_setitem_vectorized(self) -> None: + # Regression test for GH:7030 + # Positional indexing + da = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + ds = xr.Dataset({"da": da}) + b = xr.DataArray([[0, 0], [1, 0]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + ds[index] = xr.Dataset({"da": w}) + assert (ds[index]["da"] == w).all() + + # Indexing with coordinates + da = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + ds = xr.Dataset({"da": da}) + ds.coords["b"] = [2, 4, 6] + b = xr.DataArray([[2, 2], [4, 2]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + ds.loc[index] = xr.Dataset({"da": w}, coords={"b": ds.coords["b"]}) + assert (ds.loc[index]["da"] == w).all() + @pytest.mark.parametrize("dtype", [str, bytes]) def test_setitem_str_dtype(self, dtype) -> None: ds = xr.Dataset(coords={"x": np.array(["x", "y"], dtype=dtype)}) @@ -4213,22 +4357,58 @@ def test_assign_attrs(self) -> None: def test_assign_multiindex_level(self) -> None: data = create_test_multiindex() - with pytest.raises( - ValueError, match=r"cannot set or update variable.*corrupt.*index " - ): + with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "): data.assign(level_1=range(4)) data.assign_coords(level_1=range(4)) + def test_assign_new_multiindex(self) -> None: + midx = pd.MultiIndex.from_arrays([["a", "a", "b", "b"], [0, 1, 0, 1]]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + + ds = Dataset(coords={"x": [1, 2]}) + expected = Dataset(coords=midx_coords) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + actual = ds.assign(x=midx) + assert_identical(actual, expected) + + @pytest.mark.parametrize("orig_coords", [{}, {"x": range(4)}]) + def test_assign_coords_new_multiindex(self, orig_coords) -> None: + ds = Dataset(coords=orig_coords) + midx = pd.MultiIndex.from_arrays( + [["a", "a", "b", "b"], [0, 1, 0, 1]], names=("one", "two") + ) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + + expected = Dataset(coords=midx_coords) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + actual = ds.assign_coords({"x": midx}) + assert_identical(actual, expected) + + actual = ds.assign_coords(midx_coords) + assert_identical(actual, expected) + def test_assign_coords_existing_multiindex(self) -> None: data = create_test_multiindex() - with pytest.warns(FutureWarning, match=r"Updating MultiIndexed coordinate"): - data.assign_coords(x=range(4)) - - with pytest.warns(FutureWarning, match=r"Updating MultiIndexed coordinate"): - data.assign(x=range(4)) + with pytest.warns( + FutureWarning, match=r"updating coordinate.*MultiIndex.*inconsistent" + ): + updated = data.assign_coords(x=range(4)) + # https://github.com/pydata/xarray/issues/7097 (coord names updated) + assert len(updated.coords) == 1 + with pytest.warns( + FutureWarning, match=r"updating coordinate.*MultiIndex.*inconsistent" + ): + updated = data.assign(x=range(4)) # https://github.com/pydata/xarray/issues/7097 (coord names updated) - updated = data.assign_coords(x=range(4)) assert len(updated.coords) == 1 def test_assign_all_multiindex_coords(self) -> None: @@ -4255,6 +4435,25 @@ class CustomIndex(PandasIndex): actual = ds.assign_coords(y=[4, 5, 6]) assert isinstance(actual.xindexes["x"], CustomIndex) + def test_assign_coords_custom_index(self) -> None: + class CustomIndex(Index): + pass + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + ds = Dataset() + actual = ds.assign_coords(coords) + assert isinstance(actual.xindexes["x"], CustomIndex) + + def test_assign_coords_no_default_index(self) -> None: + coords = Coordinates({"y": [1, 2, 3]}, indexes={}) + ds = Dataset() + actual = ds.assign_coords(coords) + expected = coords.to_dataset() + assert_identical(expected, actual, check_default_indexes=False) + assert "y" not in actual.xindexes + def test_merge_multiindex_level(self) -> None: data = create_test_multiindex() @@ -4590,7 +4789,7 @@ def test_convert_dataframe_with_many_types_and_multiindex(self) -> None: "e": [True, False, True], "f": pd.Categorical(list("abc")), "g": pd.date_range("20130101", periods=3), - "h": pd.date_range("20130101", periods=3, tz="US/Eastern"), + "h": pd.date_range("20130101", periods=3, tz="America/New_York"), } ) df.index = pd.MultiIndex.from_product([["a"], range(3)], names=["one", "two"]) @@ -4854,12 +5053,15 @@ def test_dropna(self) -> None: expected = ds.isel(a=[1, 3]) assert_identical(actual, ds) - with pytest.raises(ValueError, match=r"a single dataset dimension"): + with pytest.raises( + ValueError, + match=r"'foo' not found in data dimensions \('a', 'b'\)", + ): ds.dropna("foo") with pytest.raises(ValueError, match=r"invalid how"): - ds.dropna("a", how="somehow") # type: ignore + ds.dropna("a", how="somehow") # type: ignore[arg-type] with pytest.raises(TypeError, match=r"must specify how or thresh"): - ds.dropna("a", how=None) # type: ignore + ds.dropna("a", how=None) # type: ignore[arg-type] def test_fillna(self) -> None: ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) @@ -5172,7 +5374,10 @@ def test_mean_uint_dtype(self) -> None: def test_reduce_bad_dim(self) -> None: data = create_test_data() - with pytest.raises(ValueError, match=r"Dataset does not contain"): + with pytest.raises( + ValueError, + match=r"Dimensions \('bad_dim',\) not found in data dimensions", + ): data.mean(dim="bad_dim") def test_reduce_cumsum(self) -> None: @@ -5198,7 +5403,10 @@ def test_reduce_cumsum(self) -> None: @pytest.mark.parametrize("func", ["cumsum", "cumprod"]) def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None: data = create_test_data() - with pytest.raises(ValueError, match=r"Dataset does not contain"): + with pytest.raises( + ValueError, + match=r"Dimensions \('bad_dim',\) not found in data dimensions", + ): getattr(data, func)(dim="bad_dim") # ensure dimensions are correct @@ -5374,7 +5582,7 @@ def test_reduce_keepdims(self) -> None: @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) def test_quantile(self, q, skipna) -> None: ds = create_test_data(seed=123) - ds.var1.data[0, 0] = np.NaN + ds.var1.data[0, 0] = np.nan for dim in [None, "dim1", ["dim1"]]: ds_quantile = ds.quantile(q, dim=dim, skipna=skipna) @@ -5446,7 +5654,12 @@ def test_rank(self) -> None: assert list(z.coords) == list(ds.coords) assert list(x.coords) == list(y.coords) # invalid dim - with pytest.raises(ValueError, match=r"does not contain"): + with pytest.raises( + ValueError, + match=re.escape( + "Dimension 'invalid_dim' not found in data dimensions ('dim3', 'dim1')" + ), + ): x.rank("invalid_dim") def test_rank_use_bottleneck(self) -> None: @@ -6201,6 +6414,13 @@ def test_ipython_key_completion(self) -> None: ds["var3"].coords[item] # should not raise assert sorted(actual) == sorted(expected) + coords = Coordinates(ds.coords) + actual = coords._ipython_key_completions_() + expected = ["time", "dim2", "dim3", "numbers"] + for item in actual: + coords[item] # should not raise + assert sorted(actual) == sorted(expected) + # data_vars actual = ds.data_vars._ipython_key_completions_() expected = ["var1", "var2", "var3", "dim1"] @@ -6235,7 +6455,7 @@ def test_polyfit_warnings(self) -> None: with warnings.catch_warnings(record=True) as ws: ds.var1.polyfit("dim2", 10, full=False) assert len(ws) == 1 - assert ws[0].category == np.RankWarning + assert ws[0].category == RankWarning ds.var1.polyfit("dim2", 10, full=True) assert len(ws) == 1 @@ -6562,7 +6782,7 @@ def test_dir_unicode(ds) -> None: def test_raise_no_warning_for_nan_in_binary_ops() -> None: with assert_no_warnings(): - Dataset(data_vars={"x": ("y", [1, 2, np.NaN])}) > 0 + Dataset(data_vars={"x": ("y", [1, 2, np.nan])}) > 0 @pytest.mark.filterwarnings("error") @@ -6972,7 +7192,12 @@ def test_drop_duplicates_1d(self, keep) -> None: result = ds.drop_duplicates("time", keep=keep) assert_equal(expected, result) - with pytest.raises(ValueError, match="['space'] not found"): + with pytest.raises( + ValueError, + match=re.escape( + "Dimensions ('space',) not found in data dimensions ('time',)" + ), + ): ds.drop_duplicates("space", keep=keep) diff --git a/xarray/tests/test_deprecation_helpers.py b/xarray/tests/test_deprecation_helpers.py index 35128829073..f21c8097060 100644 --- a/xarray/tests/test_deprecation_helpers.py +++ b/xarray/tests/test_deprecation_helpers.py @@ -15,15 +15,15 @@ def f1(a, b, *, c="c", d="d"): assert result == (1, 2, 3, 4) with pytest.warns(FutureWarning, match=r".*v0.1"): - result = f1(1, 2, 3) + result = f1(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = f1(1, 2, 3) + result = f1(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = f1(1, 2, 3, 4) + result = f1(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) @_deprecate_positional_args("v0.1") @@ -31,7 +31,7 @@ def f2(a="a", *, b="b", c="c", d="d"): return a, b, c, d with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f2(1, 2) + result = f2(1, 2) # type: ignore[misc] assert result == (1, 2, "c", "d") @_deprecate_positional_args("v0.1") @@ -39,11 +39,11 @@ def f3(a, *, b="b", **kwargs): return a, b, kwargs with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f3(1, 2) + result = f3(1, 2) # type: ignore[misc] assert result == (1, 2, {}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f3(1, 2, f="f") + result = f3(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) @_deprecate_positional_args("v0.1") @@ -57,7 +57,7 @@ def f4(a, /, *, b="b", **kwargs): assert result == (1, 2, {"f": "f"}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f4(1, 2, f="f") + result = f4(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) with pytest.raises(TypeError, match=r"Keyword-only param without default"): @@ -80,15 +80,15 @@ def method(self, a, b, *, c="c", d="d"): assert result == (1, 2, 3, 4) with pytest.warns(FutureWarning, match=r".*v0.1"): - result = A1().method(1, 2, 3) + result = A1().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = A1().method(1, 2, 3) + result = A1().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = A1().method(1, 2, 3, 4) + result = A1().method(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) class A2: @@ -97,11 +97,11 @@ def method(self, a=1, b=1, *, c="c", d="d"): return a, b, c, d with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = A2().method(1, 2, 3) + result = A2().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = A2().method(1, 2, 3, 4) + result = A2().method(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) class A3: @@ -110,11 +110,11 @@ def method(self, a, *, b="b", **kwargs): return a, b, kwargs with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A3().method(1, 2) + result = A3().method(1, 2) # type: ignore[misc] assert result == (1, 2, {}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A3().method(1, 2, f="f") + result = A3().method(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) class A4: @@ -129,7 +129,7 @@ def method(self, a, /, *, b="b", **kwargs): assert result == (1, 2, {"f": "f"}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A4().method(1, 2, f="f") + result = A4().method(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) with pytest.raises(TypeError, match=r"Keyword-only param without default"): diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index a57f8728ef0..bfc37121597 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -6,7 +6,6 @@ import numpy as np import pytest -from packaging.version import Version if TYPE_CHECKING: import dask @@ -127,7 +126,8 @@ def test_dask_distributed_write_netcdf_with_dimensionless_variables( @requires_cftime @requires_netCDF4 -def test_open_mfdataset_can_open_files_with_cftime_index(tmp_path): +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_can_open_files_with_cftime_index(parallel, tmp_path): T = xr.cftime_range("20010101", "20010501", calendar="360_day") Lon = np.arange(100) data = np.random.random((T.size, Lon.size)) @@ -136,9 +136,59 @@ def test_open_mfdataset_can_open_files_with_cftime_index(tmp_path): da.to_netcdf(file_path) with cluster() as (s, [a, b]): with Client(s["address"]): - for parallel in (False, True): - with xr.open_mfdataset(file_path, parallel=parallel) as tf: - assert_identical(tf["test"], da) + with xr.open_mfdataset(file_path, parallel=parallel) as tf: + assert_identical(tf["test"], da) + + +@requires_cftime +@requires_netCDF4 +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_multiple_files_parallel_distributed(parallel, tmp_path): + lon = np.arange(100) + time = xr.cftime_range("20010101", periods=100, calendar="360_day") + data = np.random.random((time.size, lon.size)) + da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test") + + fnames = [] + for i in range(0, 100, 10): + fname = tmp_path / f"test_{i}.nc" + da.isel(time=slice(i, i + 10)).to_netcdf(fname) + fnames.append(fname) + + with cluster() as (s, [a, b]): + with Client(s["address"]): + with xr.open_mfdataset( + fnames, parallel=parallel, concat_dim="time", combine="nested" + ) as tf: + assert_identical(tf["test"], da) + + +# TODO: move this to test_backends.py +@requires_cftime +@requires_netCDF4 +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_multiple_files_parallel(parallel, tmp_path): + if parallel: + pytest.skip( + "Flaky in CI. Would be a welcome contribution to make a similar test reliable." + ) + lon = np.arange(100) + time = xr.cftime_range("20010101", periods=100, calendar="360_day") + data = np.random.random((time.size, lon.size)) + da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test") + + fnames = [] + for i in range(0, 100, 10): + fname = tmp_path / f"test_{i}.nc" + da.isel(time=slice(i, i + 10)).to_netcdf(fname) + fnames.append(fname) + + for get in [dask.threaded.get, dask.multiprocessing.get, dask.local.get_sync, None]: + with dask.config.set(scheduler=get): + with xr.open_mfdataset( + fnames, parallel=parallel, concat_dim="time", combine="nested" + ) as tf: + assert_identical(tf["test"], da) @pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS) @@ -194,10 +244,6 @@ def test_dask_distributed_zarr_integration_test( assert_allclose(original, computed) -@pytest.mark.xfail( - condition=Version(distributed.__version__) < Version("2022.02.0"), - reason="https://github.com/dask/distributed/pull/5739", -) @gen_cluster(client=True) async def test_async(c, s, a, b) -> None: x = create_test_data() @@ -230,10 +276,6 @@ def test_hdf5_lock() -> None: assert isinstance(HDF5_LOCK, dask.utils.SerializableLock) -@pytest.mark.xfail( - condition=Version(distributed.__version__) < Version("2022.02.0"), - reason="https://github.com/dask/distributed/pull/5739", -) @gen_cluster(client=True) async def test_serializable_locks(c, s, a, b) -> None: def f(x, lock=None): diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index 490520c8f54..3c2ee5e8f6f 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -10,12 +10,12 @@ "args, expected", [ ([bool], bool), - ([bool, np.string_], np.object_), + ([bool, np.bytes_], np.object_), ([np.float32, np.float64], np.float64), - ([np.float32, np.string_], np.object_), - ([np.unicode_, np.int64], np.object_), - ([np.unicode_, np.unicode_], np.unicode_), - ([np.bytes_, np.unicode_], np.object_), + ([np.float32, np.bytes_], np.object_), + ([np.str_, np.int64], np.object_), + ([np.str_, np.str_], np.str_), + ([np.bytes_, np.str_], np.object_), ], ) def test_result_type(args, expected) -> None: diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index bf5f7d0bdc5..d5c8e0c0d0a 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -218,31 +218,70 @@ def test_attribute_repr(self) -> None: assert "\n" not in newlines assert "\t" not in tabs - def test_index_repr(self): + def test_index_repr(self) -> None: from xarray.core.indexes import Index class CustomIndex(Index): - def __init__(self, names): + names: tuple[str, ...] + + def __init__(self, names: tuple[str, ...]): self.names = names def __repr__(self): return f"CustomIndex(coords={self.names})" - coord_names = ["x", "y"] + coord_names = ("x", "y") index = CustomIndex(coord_names) - name = "x" + names = ("x",) - normal = formatting.summarize_index(name, index, col_width=20) - assert name in normal + normal = formatting.summarize_index(names, index, col_width=20) + assert names[0] in normal + assert len(normal.splitlines()) == len(names) assert "CustomIndex" in normal - CustomIndex._repr_inline_ = ( - lambda self, max_width: f"CustomIndex[{', '.join(self.names)}]" - ) - inline = formatting.summarize_index(name, index, col_width=20) - assert name in inline + class IndexWithInlineRepr(CustomIndex): + def _repr_inline_(self, max_width: int): + return f"CustomIndex[{', '.join(self.names)}]" + + index = IndexWithInlineRepr(coord_names) + inline = formatting.summarize_index(names, index, col_width=20) + assert names[0] in inline assert index._repr_inline_(max_width=40) in inline + @pytest.mark.parametrize( + "names", + ( + ("x",), + ("x", "y"), + ("x", "y", "z"), + ("x", "y", "z", "a"), + ), + ) + def test_index_repr_grouping(self, names) -> None: + from xarray.core.indexes import Index + + class CustomIndex(Index): + def __init__(self, names): + self.names = names + + def __repr__(self): + return f"CustomIndex(coords={self.names})" + + index = CustomIndex(names) + + normal = formatting.summarize_index(names, index, col_width=20) + assert all(name in normal for name in names) + assert len(normal.splitlines()) == len(names) + assert "CustomIndex" in normal + + hint_chars = [line[2] for line in normal.splitlines()] + + if len(names) <= 1: + assert hint_chars == [" "] + else: + assert hint_chars[0] == "┌" and hint_chars[-1] == "└" + assert len(names) == 2 or hint_chars[1:-1] == ["│"] * (len(names) - 2) + def test_diff_array_repr(self) -> None: da_a = xr.DataArray( np.array([[1, 2, 3], [4, 5, 6]], dtype="int64"), @@ -458,7 +497,7 @@ def test_array_repr_variable(self) -> None: def test_array_repr_recursive(self) -> None: # GH:issue:7111 - # direct recurion + # direct recursion var = xr.Variable("x", [0, 1]) var.attrs["x"] = var formatting.array_repr(var) @@ -510,7 +549,7 @@ def _repr_inline_(self, width): return formatted - def __array_function__(self, *args, **kwargs): + def __array_namespace__(self, *args, **kwargs): return NotImplemented @property diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index 7ea5c19019b..6540406e914 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -6,29 +6,31 @@ import xarray as xr from xarray.core import formatting_html as fh +from xarray.core.coordinates import Coordinates @pytest.fixture -def dataarray(): +def dataarray() -> xr.DataArray: return xr.DataArray(np.random.RandomState(0).randn(4, 6)) @pytest.fixture -def dask_dataarray(dataarray): +def dask_dataarray(dataarray: xr.DataArray) -> xr.DataArray: pytest.importorskip("dask") return dataarray.chunk() @pytest.fixture -def multiindex(): - mindex = pd.MultiIndex.from_product( +def multiindex() -> xr.Dataset: + midx = pd.MultiIndex.from_product( [["a", "b"], [1, 2]], names=("level_1", "level_2") ) - return xr.Dataset({}, {"x": mindex}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + return xr.Dataset({}, midx_coords) @pytest.fixture -def dataset(): +def dataset() -> xr.Dataset: times = pd.date_range("2000-01-01", "2001-12-31", name="time") annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28)) @@ -46,17 +48,17 @@ def dataset(): ) -def test_short_data_repr_html(dataarray) -> None: +def test_short_data_repr_html(dataarray: xr.DataArray) -> None: data_repr = fh.short_data_repr_html(dataarray) assert data_repr.startswith("
    array")
     
     
    -def test_short_data_repr_html_non_str_keys(dataset) -> None:
    +def test_short_data_repr_html_non_str_keys(dataset: xr.Dataset) -> None:
         ds = dataset.assign({2: lambda x: x["tmin"]})
         fh.dataset_repr(ds)
     
     
    -def test_short_data_repr_html_dask(dask_dataarray) -> None:
    +def test_short_data_repr_html_dask(dask_dataarray: xr.DataArray) -> None:
         assert hasattr(dask_dataarray.data, "_repr_html_")
         data_repr = fh.short_data_repr_html(dask_dataarray)
         assert data_repr == dask_dataarray.data._repr_html_()
    @@ -97,7 +99,7 @@ def test_summarize_attrs_with_unsafe_attr_name_and_value() -> None:
         assert "
    <pd.DataFrame>
    " in formatted -def test_repr_of_dataarray(dataarray) -> None: +def test_repr_of_dataarray(dataarray: xr.DataArray) -> None: formatted = fh.array_repr(dataarray) assert "dim_0" in formatted # has an expanded data section @@ -119,12 +121,12 @@ def test_repr_of_dataarray(dataarray) -> None: ) -def test_repr_of_multiindex(multiindex) -> None: +def test_repr_of_multiindex(multiindex: xr.Dataset) -> None: formatted = fh.dataset_repr(multiindex) assert "(x)" in formatted -def test_repr_of_dataset(dataset) -> None: +def test_repr_of_dataset(dataset: xr.Dataset) -> None: formatted = fh.dataset_repr(dataset) # coords, attrs, and data_vars are expanded assert ( @@ -151,7 +153,7 @@ def test_repr_of_dataset(dataset) -> None: assert "<IA>" in formatted -def test_repr_text_fallback(dataset) -> None: +def test_repr_text_fallback(dataset: xr.Dataset) -> None: formatted = fh.dataset_repr(dataset) # Just test that the "pre" block used for fallback to plain text is present. @@ -170,7 +172,7 @@ def test_variable_repr_html() -> None: assert "xarray.Variable" in html -def test_repr_of_nonstr_dataset(dataset) -> None: +def test_repr_of_nonstr_dataset(dataset: xr.Dataset) -> None: ds = dataset.copy() ds.attrs[1] = "Test value" ds[2] = ds["tmin"] @@ -179,7 +181,7 @@ def test_repr_of_nonstr_dataset(dataset) -> None: assert "
    2" in formatted -def test_repr_of_nonstr_dataarray(dataarray) -> None: +def test_repr_of_nonstr_dataarray(dataarray: xr.DataArray) -> None: da = dataarray.rename(dim_0=15) da.attrs[1] = "value" formatted = fh.array_repr(da) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 79b18cf1f5e..320ba999318 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -12,6 +12,7 @@ from xarray import DataArray, Dataset, Variable from xarray.core.groupby import _consolidate_slices from xarray.tests import ( + InaccessibleArray, assert_allclose, assert_array_equal, assert_equal, @@ -231,11 +232,11 @@ def test_da_groupby_quantile() -> None: assert_identical(expected, actual) array = xr.DataArray( - data=[np.NaN, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" + data=[np.nan, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" ) for skipna in (True, False, None): - e = [np.NaN, 5] if skipna is False else [2.5, 5] + e = [np.nan, 5] if skipna is False else [2.5, 5] expected = xr.DataArray(data=e, coords={"x": [1, 2], "quantile": 0.5}, dims="x") actual = array.groupby("x").quantile(0.5, skipna=skipna) @@ -345,12 +346,12 @@ def test_ds_groupby_quantile() -> None: assert_identical(expected, actual) ds = xr.Dataset( - data_vars={"a": ("x", [np.NaN, 2, 3, 4, 5, 6])}, + data_vars={"a": ("x", [np.nan, 2, 3, 4, 5, 6])}, coords={"x": [1, 1, 1, 2, 2, 2]}, ) for skipna in (True, False, None): - e = [np.NaN, 5] if skipna is False else [2.5, 5] + e = [np.nan, 5] if skipna is False else [2.5, 5] expected = xr.Dataset( data_vars={"a": ("x", e)}, coords={"quantile": 0.5, "x": [1, 2]} @@ -727,7 +728,7 @@ def test_groupby_dataset_iter() -> None: def test_groupby_dataset_errors() -> None: data = create_test_data() with pytest.raises(TypeError, match=r"`group` must be"): - data.groupby(np.arange(10)) + data.groupby(np.arange(10)) # type: ignore[arg-type,unused-ignore] with pytest.raises(ValueError, match=r"length does not match"): data.groupby(data["dim1"][:3]) with pytest.raises(TypeError, match=r"`group` must be"): @@ -809,9 +810,9 @@ def test_groupby_math_more() -> None: with pytest.raises(TypeError, match=r"only support binary ops"): grouped + 1 # type: ignore[operator] with pytest.raises(TypeError, match=r"only support binary ops"): - grouped + grouped + grouped + grouped # type: ignore[operator] with pytest.raises(TypeError, match=r"in-place operations"): - ds += grouped + ds += grouped # type: ignore[arg-type] ds = Dataset( { @@ -2392,3 +2393,17 @@ def test_min_count_error(use_flox: bool) -> None: with xr.set_options(use_flox=use_flox): with pytest.raises(TypeError): da.groupby("labels").mean(min_count=1) + + +@requires_dask +def test_groupby_math_auto_chunk(): + da = xr.DataArray( + [[1, 2, 3], [1, 2, 3], [1, 2, 3]], + dims=("y", "x"), + coords={"label": ("x", [2, 2, 1])}, + ) + sub = xr.DataArray( + InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]} + ) + actual = da.chunk(x=1, y=2).groupby("label") - sub + assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)} diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 27b5cf2119c..866c2ef7e85 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -145,6 +145,11 @@ def test_from_variables(self) -> None: with pytest.raises(ValueError, match=r".*only accepts one variable.*"): PandasIndex.from_variables({"x": var, "foo": var2}, options={}) + with pytest.raises( + ValueError, match=r".*cannot set a PandasIndex.*scalar variable.*" + ): + PandasIndex.from_variables({"foo": xr.Variable((), 1)}, options={}) + with pytest.raises( ValueError, match=r".*only accepts a 1-dimensional variable.*" ): @@ -482,7 +487,10 @@ def test_sel(self) -> None: index.sel({"x": 0}) with pytest.raises(ValueError, match=r"cannot provide labels for both.*"): index.sel({"one": 0, "x": "a"}) - with pytest.raises(ValueError, match=r"invalid multi-index level names"): + with pytest.raises( + ValueError, + match=r"multi-index level names \('three',\) not found in indexes", + ): index.sel({"x": {"three": 0}}) with pytest.raises(IndexError): index.sel({"x": (slice(None), 1, "no_level")}) @@ -582,7 +590,12 @@ def indexes( _, variables = indexes_and_vars - return Indexes(indexes, variables) + if isinstance(x_idx, Index): + index_type = Index + else: + index_type = pd.Index + + return Indexes(indexes, variables, index_type=index_type) def test_interface(self, unique_indexes, indexes) -> None: x_idx = unique_indexes[0] diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 9f57b3b9056..10192b6587a 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -307,6 +307,7 @@ def test_lazily_indexed_array(self) -> None: assert expected.shape == actual.shape assert_array_equal(expected, actual) assert isinstance(actual._data, indexing.LazilyIndexedArray) + assert isinstance(v_lazy._data, indexing.LazilyIndexedArray) # make sure actual.key is appropriate type if all( diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 8957f9c829a..c6597d5abb0 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -235,6 +235,13 @@ def test_merge_dicts_dims(self): expected = xr.Dataset({"x": [12], "y": ("x", [13])}) assert_identical(actual, expected) + def test_merge_coordinates(self): + coords1 = xr.Coordinates({"x": ("x", [0, 1, 2])}) + coords2 = xr.Coordinates({"y": ("y", [3, 4, 5])}) + expected = xr.Dataset(coords={"x": [0, 1, 2], "y": [3, 4, 5]}) + actual = xr.merge([coords1, coords2]) + assert_identical(actual, expected) + def test_merge_error(self): ds = xr.Dataset({"x": 0}) with pytest.raises(xr.MergeError): @@ -375,6 +382,16 @@ def test_merge_compat(self): assert ds1.identical(ds1.merge(ds2, compat="override")) + def test_merge_compat_minimal(self) -> None: + # https://github.com/pydata/xarray/issues/7405 + # https://github.com/pydata/xarray/issues/7588 + ds1 = xr.Dataset(coords={"foo": [1, 2, 3], "bar": 4}) + ds2 = xr.Dataset(coords={"foo": [1, 2, 3], "bar": 5}) + + actual = xr.merge([ds1, ds2], compat="minimal") + expected = xr.Dataset(coords={"foo": [1, 2, 3]}) + assert_identical(actual, expected) + def test_merge_auto_align(self): ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) ds2 = xr.Dataset({"b": ("x", [3, 4]), "x": [1, 2]}) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index a6b6b1f80ce..c57d84c927d 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -92,32 +92,43 @@ def make_interpolate_example_data(shape, frac_nan, seed=12345, non_uniform=False return da, df +@pytest.mark.parametrize("fill_value", [None, np.nan, 47.11]) +@pytest.mark.parametrize( + "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] +) @requires_scipy -def test_interpolate_pd_compat(): +def test_interpolate_pd_compat(method, fill_value) -> None: shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] frac_nans = [0, 0.5, 1] - methods = ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] - for shape, frac_nan, method in itertools.product(shapes, frac_nans, methods): + for shape, frac_nan in itertools.product(shapes, frac_nans): da, df = make_interpolate_example_data(shape, frac_nan) for dim in ["time", "x"]: - actual = da.interpolate_na(method=method, dim=dim, fill_value=np.nan) + actual = da.interpolate_na(method=method, dim=dim, fill_value=fill_value) + # need limit_direction="both" here, to let pandas fill + # in both directions instead of default forward direction only expected = df.interpolate( - method=method, axis=da.get_axis_num(dim), fill_value=(np.nan, np.nan) + method=method, + axis=da.get_axis_num(dim), + limit_direction="both", + fill_value=fill_value, ) - # Note, Pandas does some odd things with the left/right fill_value - # for the linear methods. This next line inforces the xarray - # fill_value convention on the pandas output. Therefore, this test - # only checks that interpolated values are the same (not nans) - expected.values[pd.isnull(actual.values)] = np.nan + + if method == "linear": + # Note, Pandas does not take left/right fill_value into account + # for the numpy linear methods. + # see https://github.com/pandas-dev/pandas/issues/55144 + # This aligns the pandas output with the xarray output + expected.values[pd.isnull(actual.values)] = np.nan + expected.values[actual.values == fill_value] = fill_value np.testing.assert_allclose(actual.values, expected.values) @requires_scipy -@pytest.mark.parametrize("method", ["barycentric", "krog", "pchip", "spline", "akima"]) -def test_scipy_methods_function(method): +@pytest.mark.parametrize("method", ["barycentric", "krogh", "pchip", "spline", "akima"]) +def test_scipy_methods_function(method) -> None: # Note: Pandas does some wacky things with these methods and the full # integration tests won't work. da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) @@ -140,7 +151,8 @@ def test_interpolate_pd_compat_non_uniform_index(): method="linear", dim=dim, use_coordinate=True, fill_value=np.nan ) expected = df.interpolate( - method=method, axis=da.get_axis_num(dim), fill_value=np.nan + method=method, + axis=da.get_axis_num(dim), ) # Note, Pandas does some odd things with the left/right fill_value diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py new file mode 100644 index 00000000000..448e8cf819a --- /dev/null +++ b/xarray/tests/test_namedarray.py @@ -0,0 +1,439 @@ +from __future__ import annotations + +import copy +import warnings +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Generic, cast, overload + +import numpy as np +import pytest + +from xarray.core.indexing import ExplicitlyIndexed +from xarray.namedarray._typing import _arrayfunction_or_api, _DType_co, _ShapeType_co +from xarray.namedarray.core import NamedArray, from_array +from xarray.namedarray.utils import _default + +if TYPE_CHECKING: + from types import ModuleType + + from numpy.typing import ArrayLike, DTypeLike, NDArray + + from xarray.namedarray._typing import ( + _AttrsLike, + _DimsLike, + _DType, + _Shape, + duckarray, + ) + from xarray.namedarray.utils import Default + + +class CustomArrayBase(Generic[_ShapeType_co, _DType_co]): + def __init__(self, array: duckarray[Any, _DType_co]) -> None: + self.array: duckarray[Any, _DType_co] = array + + @property + def dtype(self) -> _DType_co: + return self.array.dtype + + @property + def shape(self) -> _Shape: + return self.array.shape + + +class CustomArray( + CustomArrayBase[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] +): + def __array__(self) -> np.ndarray[Any, np.dtype[np.generic]]: + return np.array(self.array) + + +class CustomArrayIndexable( + CustomArrayBase[_ShapeType_co, _DType_co], + ExplicitlyIndexed, + Generic[_ShapeType_co, _DType_co], +): + def __array_namespace__(self) -> ModuleType: + return np + + +@pytest.fixture +def random_inputs() -> np.ndarray[Any, np.dtype[np.float32]]: + return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + + +def test_namedarray_init() -> None: + dtype = np.dtype(np.int8) + expected = np.array([1, 2], dtype=dtype) + actual: NamedArray[Any, np.dtype[np.int8]] + actual = NamedArray(("x",), expected) + assert np.array_equal(np.asarray(actual.data), expected) + + with pytest.raises(AttributeError): + expected2 = [1, 2] + actual2: NamedArray[Any, Any] + actual2 = NamedArray(("x",), expected2) # type: ignore[arg-type] + assert np.array_equal(np.asarray(actual2.data), expected2) + + +@pytest.mark.parametrize( + "dims, data, expected, raise_error", + [ + (("x",), [1, 2, 3], np.array([1, 2, 3]), False), + ((1,), np.array([4, 5, 6]), np.array([4, 5, 6]), False), + ((), 2, np.array(2), False), + # Fail: + (("x",), NamedArray("time", np.array([1, 2, 3])), np.array([1, 2, 3]), True), + ], +) +def test_from_array( + dims: _DimsLike, + data: ArrayLike, + expected: np.ndarray[Any, Any], + raise_error: bool, +) -> None: + actual: NamedArray[Any, Any] + if raise_error: + with pytest.raises(TypeError, match="already a Named array"): + actual = from_array(dims, data) + + # Named arrays are not allowed: + from_array(actual) # type: ignore[call-overload] + else: + actual = from_array(dims, data) + + assert np.array_equal(np.asarray(actual.data), expected) + + +def test_from_array_with_masked_array() -> None: + masked_array: np.ndarray[Any, np.dtype[np.generic]] + masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) # type: ignore[no-untyped-call] + with pytest.raises(NotImplementedError): + from_array(("x",), masked_array) + + +def test_from_array_with_0d_object() -> None: + data = np.empty((), dtype=object) + data[()] = (10, 12, 12) + narr = from_array((), data) + np.array_equal(np.asarray(narr.data), data) + + +# TODO: Make xr.core.indexing.ExplicitlyIndexed pass as a subclass of_arrayfunction_or_api +# and remove this test. +def test_from_array_with_explicitly_indexed( + random_inputs: np.ndarray[Any, Any] +) -> None: + array: CustomArray[Any, Any] + array = CustomArray(random_inputs) + output: NamedArray[Any, Any] + output = from_array(("x", "y", "z"), array) + assert isinstance(output.data, np.ndarray) + + array2: CustomArrayIndexable[Any, Any] + array2 = CustomArrayIndexable(random_inputs) + output2: NamedArray[Any, Any] + output2 = from_array(("x", "y", "z"), array2) + assert isinstance(output2.data, CustomArrayIndexable) + + +def test_properties() -> None: + data = 0.5 * np.arange(10).reshape(2, 5) + named_array: NamedArray[Any, Any] + named_array = NamedArray(["x", "y"], data, {"key": "value"}) + assert named_array.dims == ("x", "y") + assert np.array_equal(np.asarray(named_array.data), data) + assert named_array.attrs == {"key": "value"} + assert named_array.ndim == 2 + assert named_array.sizes == {"x": 2, "y": 5} + assert named_array.size == 10 + assert named_array.nbytes == 80 + assert len(named_array) == 2 + + +def test_attrs() -> None: + named_array: NamedArray[Any, Any] + named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5)) + assert named_array.attrs == {} + named_array.attrs["key"] = "value" + assert named_array.attrs == {"key": "value"} + named_array.attrs = {"key": "value2"} + assert named_array.attrs == {"key": "value2"} + + +def test_data(random_inputs: np.ndarray[Any, Any]) -> None: + named_array: NamedArray[Any, Any] + named_array = NamedArray(["x", "y", "z"], random_inputs) + assert np.array_equal(np.asarray(named_array.data), random_inputs) + with pytest.raises(ValueError): + named_array.data = np.random.random((3, 4)).astype(np.float64) + + +def test_real_and_imag() -> None: + expected_real: np.ndarray[Any, np.dtype[np.float64]] + expected_real = np.arange(3, dtype=np.float64) + + expected_imag: np.ndarray[Any, np.dtype[np.float64]] + expected_imag = -np.arange(3, dtype=np.float64) + + arr: np.ndarray[Any, np.dtype[np.complex128]] + arr = expected_real + 1j * expected_imag + + named_array: NamedArray[Any, np.dtype[np.complex128]] + named_array = NamedArray(["x"], arr) + + actual_real: duckarray[Any, np.dtype[np.float64]] = named_array.real.data + assert np.array_equal(np.asarray(actual_real), expected_real) + assert actual_real.dtype == expected_real.dtype + + actual_imag: duckarray[Any, np.dtype[np.float64]] = named_array.imag.data + assert np.array_equal(np.asarray(actual_imag), expected_imag) + assert actual_imag.dtype == expected_imag.dtype + + +# Additional tests as per your original class-based code +@pytest.mark.parametrize( + "data, dtype", + [ + ("foo", np.dtype("U3")), + (b"foo", np.dtype("S3")), + ], +) +def test_0d_string(data: Any, dtype: DTypeLike) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], data) + assert named_array.data == data + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == dtype + + +def test_0d_object() -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], (10, 12, 12)) + expected_data = np.empty((), dtype=object) + expected_data[()] = (10, 12, 12) + assert np.array_equal(np.asarray(named_array.data), expected_data) + + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == np.dtype("O") + + +def test_0d_datetime() -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], np.datetime64("2000-01-01")) + assert named_array.dtype == np.dtype("datetime64[D]") + + +@pytest.mark.parametrize( + "timedelta, expected_dtype", + [ + (np.timedelta64(1, "D"), np.dtype("timedelta64[D]")), + (np.timedelta64(1, "s"), np.dtype("timedelta64[s]")), + (np.timedelta64(1, "m"), np.dtype("timedelta64[m]")), + (np.timedelta64(1, "h"), np.dtype("timedelta64[h]")), + (np.timedelta64(1, "us"), np.dtype("timedelta64[us]")), + (np.timedelta64(1, "ns"), np.dtype("timedelta64[ns]")), + (np.timedelta64(1, "ps"), np.dtype("timedelta64[ps]")), + (np.timedelta64(1, "fs"), np.dtype("timedelta64[fs]")), + (np.timedelta64(1, "as"), np.dtype("timedelta64[as]")), + ], +) +def test_0d_timedelta( + timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64] +) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], timedelta) + assert named_array.dtype == expected_dtype + assert named_array.data == timedelta + + +@pytest.mark.parametrize( + "dims, data_shape, new_dims, raises", + [ + (["x", "y", "z"], (2, 3, 4), ["a", "b", "c"], False), + (["x", "y", "z"], (2, 3, 4), ["a", "b"], True), + (["x", "y", "z"], (2, 4, 5), ["a", "b", "c", "d"], True), + ([], [], (), False), + ([], [], ("x",), True), + ], +) +def test_dims_setter(dims: Any, data_shape: Any, new_dims: Any, raises: bool) -> None: + named_array: NamedArray[Any, Any] + named_array = NamedArray(dims, np.asarray(np.random.random(data_shape))) + assert named_array.dims == tuple(dims) + if raises: + with pytest.raises(ValueError): + named_array.dims = new_dims + else: + named_array.dims = new_dims + assert named_array.dims == tuple(new_dims) + + +def test_duck_array_class() -> None: + def test_duck_array_typevar(a: duckarray[Any, _DType]) -> duckarray[Any, _DType]: + # Mypy checks a is valid: + b: duckarray[Any, _DType] = a + + # Runtime check if valid: + if isinstance(b, _arrayfunction_or_api): + return b + else: + raise TypeError(f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi") + + numpy_a: NDArray[np.int64] + numpy_a = np.array([2.1, 4], dtype=np.dtype(np.int64)) + custom_a: CustomArrayIndexable[Any, np.dtype[np.int64]] + custom_a = CustomArrayIndexable(numpy_a) + + test_duck_array_typevar(numpy_a) + test_duck_array_typevar(custom_a) + + # Test numpy's array api: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + r"The numpy.array_api submodule is still experimental", + category=UserWarning, + ) + import numpy.array_api as nxp + + # TODO: nxp doesn't use dtype typevars, so can only use Any for the moment: + arrayapi_a: duckarray[Any, Any] # duckarray[Any, np.dtype[np.int64]] + arrayapi_a = nxp.asarray([2.1, 4], dtype=np.dtype(np.int64)) + test_duck_array_typevar(arrayapi_a) + + +def test_new_namedarray() -> None: + dtype_float = np.dtype(np.float32) + narr_float: NamedArray[Any, np.dtype[np.float32]] + narr_float = NamedArray(("x",), np.array([1.5, 3.2], dtype=dtype_float)) + assert narr_float.dtype == dtype_float + + dtype_int = np.dtype(np.int8) + narr_int: NamedArray[Any, np.dtype[np.int8]] + narr_int = narr_float._new(("x",), np.array([1, 3], dtype=dtype_int)) + assert narr_int.dtype == dtype_int + + # Test with a subclass: + class Variable( + NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] + ): + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: duckarray[Any, _DType] = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[Any, _DType]: + ... + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[_ShapeType_co, _DType_co]: + ... + + def _new( + self, + dims: _DimsLike | Default = _default, + data: duckarray[Any, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]: + dims_ = copy.copy(self._dims) if dims is _default else dims + + attrs_: Mapping[Any, Any] | None + if attrs is _default: + attrs_ = None if self._attrs is None else self._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(self)(dims_, copy.copy(self._data), attrs_) + else: + cls_ = cast("type[Variable[Any, _DType]]", type(self)) + return cls_(dims_, data, attrs_) + + var_float: Variable[Any, np.dtype[np.float32]] + var_float = Variable(("x",), np.array([1.5, 3.2], dtype=dtype_float)) + assert var_float.dtype == dtype_float + + var_int: Variable[Any, np.dtype[np.int8]] + var_int = var_float._new(("x",), np.array([1, 3], dtype=dtype_int)) + assert var_int.dtype == dtype_int + + +def test_replace_namedarray() -> None: + dtype_float = np.dtype(np.float32) + np_val: np.ndarray[Any, np.dtype[np.float32]] + np_val = np.array([1.5, 3.2], dtype=dtype_float) + np_val2: np.ndarray[Any, np.dtype[np.float32]] + np_val2 = 2 * np_val + + narr_float: NamedArray[Any, np.dtype[np.float32]] + narr_float = NamedArray(("x",), np_val) + assert narr_float.dtype == dtype_float + + narr_float2: NamedArray[Any, np.dtype[np.float32]] + narr_float2 = NamedArray(("x",), np_val2) + assert narr_float2.dtype == dtype_float + + # Test with a subclass: + class Variable( + NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] + ): + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: duckarray[Any, _DType] = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[Any, _DType]: + ... + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[_ShapeType_co, _DType_co]: + ... + + def _new( + self, + dims: _DimsLike | Default = _default, + data: duckarray[Any, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]: + dims_ = copy.copy(self._dims) if dims is _default else dims + + attrs_: Mapping[Any, Any] | None + if attrs is _default: + attrs_ = None if self._attrs is None else self._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(self)(dims_, copy.copy(self._data), attrs_) + else: + cls_ = cast("type[Variable[Any, _DType]]", type(self)) + return cls_(dims_, data, attrs_) + + var_float: Variable[Any, np.dtype[np.float32]] + var_float = Variable(("x",), np_val) + assert var_float.dtype == dtype_float + + var_float2: Variable[Any, np.dtype[np.float32]] + var_float2 = var_float._replace(("x",), np_val2) + assert var_float2.dtype == dtype_float diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index 3cecf1b52ec..8ad1cbe11be 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -165,7 +165,6 @@ def test_concat_attr_retention(self) -> None: result = concat([ds1, ds2], dim="dim1") assert result.attrs == original_attrs - @pytest.mark.xfail def test_merge_attr_retention(self) -> None: da1 = create_test_dataarray_attrs(var="var1") da2 = create_test_dataarray_attrs(var="var2") diff --git a/xarray/tests/test_parallelcompat.py b/xarray/tests/test_parallelcompat.py index 2c3378a2816..ea324cafb76 100644 --- a/xarray/tests/test_parallelcompat.py +++ b/xarray/tests/test_parallelcompat.py @@ -12,7 +12,7 @@ guess_chunkmanager, list_chunkmanagers, ) -from xarray.core.types import T_Chunks, T_NormalizedChunks +from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks from xarray.tests import has_dask, requires_dask @@ -76,7 +76,7 @@ def normalize_chunks( return normalize_chunks(chunks, shape, limit, dtype, previous_chunks) def from_array( - self, data: np.ndarray, chunks: T_Chunks, **kwargs + self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs ) -> DummyChunkedArray: from dask import array as da diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 8b2dfbdec41..31c23955b02 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -43,6 +43,7 @@ # import mpl and change the backend before other mpl imports try: import matplotlib as mpl + import matplotlib.dates import matplotlib.pyplot as plt import mpl_toolkits except ImportError: @@ -167,7 +168,14 @@ def imshow_called(self, plotmethod): def contourf_called(self, plotmethod): plotmethod() - paths = plt.gca().findobj(mpl.collections.PathCollection) + + # Compatible with mpl before (PathCollection) and after (QuadContourSet) 3.8 + def matchfunc(x): + return isinstance( + x, (mpl.collections.PathCollection, mpl.contour.QuadContourSet) + ) + + paths = plt.gca().findobj(matchfunc) return len(paths) > 0 @@ -421,6 +429,7 @@ def test2d_1d_2d_coordinates_pcolormesh(self) -> None: ]: p = a.plot.pcolormesh(x=x, y=y) v = p.get_paths()[0].vertices + assert isinstance(v, np.ndarray) # Check all vertices are different, except last vertex which should be the # same as the first @@ -440,7 +449,7 @@ def test_str_coordinates_pcolormesh(self) -> None: def test_contourf_cmap_set(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"]) - cmap = mpl.cm.viridis + cmap_expected = mpl.colormaps["viridis"] # use copy to ensure cmap is not changed by contourf() # Set vmin and vmax so that _build_discrete_colormap is called with @@ -450,55 +459,59 @@ def test_contourf_cmap_set(self) -> None: # extend='neither' (but if extend='neither' the under and over values # would not be used because the data would all be within the plotted # range) - pl = a.plot.contourf(cmap=copy(cmap), vmin=0.1, vmax=0.9) + pl = a.plot.contourf(cmap=copy(cmap_expected), vmin=0.1, vmax=0.9) # check the set_bad color + cmap = pl.cmap + assert cmap is not None assert_array_equal( - pl.cmap(np.ma.masked_invalid([np.nan]))[0], cmap(np.ma.masked_invalid([np.nan]))[0], + cmap_expected(np.ma.masked_invalid([np.nan]))[0], ) # check the set_under color - assert pl.cmap(-np.inf) == cmap(-np.inf) + assert cmap(-np.inf) == cmap_expected(-np.inf) # check the set_over color - assert pl.cmap(np.inf) == cmap(np.inf) + assert cmap(np.inf) == cmap_expected(np.inf) def test_contourf_cmap_set_with_bad_under_over(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"]) # make a copy here because we want a local cmap that we will modify. - cmap = copy(mpl.cm.viridis) + cmap_expected = copy(mpl.colormaps["viridis"]) - cmap.set_bad("w") + cmap_expected.set_bad("w") # check we actually changed the set_bad color assert np.all( - cmap(np.ma.masked_invalid([np.nan]))[0] - != mpl.cm.viridis(np.ma.masked_invalid([np.nan]))[0] + cmap_expected(np.ma.masked_invalid([np.nan]))[0] + != mpl.colormaps["viridis"](np.ma.masked_invalid([np.nan]))[0] ) - cmap.set_under("r") + cmap_expected.set_under("r") # check we actually changed the set_under color - assert cmap(-np.inf) != mpl.cm.viridis(-np.inf) + assert cmap_expected(-np.inf) != mpl.colormaps["viridis"](-np.inf) - cmap.set_over("g") + cmap_expected.set_over("g") # check we actually changed the set_over color - assert cmap(np.inf) != mpl.cm.viridis(-np.inf) + assert cmap_expected(np.inf) != mpl.colormaps["viridis"](-np.inf) # copy to ensure cmap is not changed by contourf() - pl = a.plot.contourf(cmap=copy(cmap)) + pl = a.plot.contourf(cmap=copy(cmap_expected)) + cmap = pl.cmap + assert cmap is not None # check the set_bad color has been kept assert_array_equal( - pl.cmap(np.ma.masked_invalid([np.nan]))[0], cmap(np.ma.masked_invalid([np.nan]))[0], + cmap_expected(np.ma.masked_invalid([np.nan]))[0], ) # check the set_under color has been kept - assert pl.cmap(-np.inf) == cmap(-np.inf) + assert cmap(-np.inf) == cmap_expected(-np.inf) # check the set_over color has been kept - assert pl.cmap(np.inf) == cmap(np.inf) + assert cmap(np.inf) == cmap_expected(np.inf) def test3d(self) -> None: self.darray.plot() @@ -716,6 +729,13 @@ def test_labels_with_units_with_interval(self, dim) -> None: expected = "dim_0_bins_center [m]" assert actual == expected + def test_multiplot_over_length_one_dim(self) -> None: + a = easy_array((3, 1, 1, 1)) + d = DataArray(a, dims=("x", "col", "row", "hue")) + d.plot(col="col") + d.plot(row="row") + d.plot(hue="hue") + class TestPlot1D(PlotTestCase): @pytest.fixture(autouse=True) @@ -831,19 +851,25 @@ def test_coord_with_interval_step(self) -> None: """Test step plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step() - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_x(self) -> None: """Test step plot with intervals explicitly on x axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins") - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_y(self) -> None: """Test step plot with intervals explicitly on y axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins") - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_x_and_y_raises_valueeerror(self) -> None: """Test that step plot with intervals both on x and y axes raises an error.""" @@ -883,8 +909,11 @@ def test_can_pass_in_axis(self) -> None: self.pass_in_axis(self.darray.plot.hist) def test_primitive_returned(self) -> None: - h = self.darray.plot.hist() - assert isinstance(h[-1][0], mpl.patches.Rectangle) + n, bins, patches = self.darray.plot.hist() + assert isinstance(n, np.ndarray) + assert isinstance(bins, np.ndarray) + assert isinstance(patches, mpl.container.BarContainer) + assert isinstance(patches[0], mpl.patches.Rectangle) @pytest.mark.slow def test_plot_nans(self) -> None: @@ -928,9 +957,9 @@ def test_cmap_sequential_option(self) -> None: assert cmap_params["cmap"] == "magma" def test_cmap_sequential_explicit_option(self) -> None: - with xr.set_options(cmap_sequential=mpl.cm.magma): + with xr.set_options(cmap_sequential=mpl.colormaps["magma"]): cmap_params = _determine_cmap_params(self.data) - assert cmap_params["cmap"] == mpl.cm.magma + assert cmap_params["cmap"] == mpl.colormaps["magma"] def test_cmap_divergent_option(self) -> None: with xr.set_options(cmap_divergent="magma"): @@ -1170,7 +1199,7 @@ def test_discrete_colormap_list_of_levels(self) -> None: def test_discrete_colormap_int_levels(self) -> None: for extend, levels, vmin, vmax, cmap in [ ("neither", 7, None, None, None), - ("neither", 7, None, 20, mpl.cm.RdBu), + ("neither", 7, None, 20, mpl.colormaps["RdBu"]), ("both", 7, 4, 8, None), ("min", 10, 4, 15, None), ]: @@ -1720,8 +1749,8 @@ class TestContour(Common2dMixin, PlotTestCase): # matplotlib cmap.colors gives an rgbA ndarray # when seaborn is used, instead we get an rgb tuple @staticmethod - def _color_as_tuple(c): - return tuple(c[:3]) + def _color_as_tuple(c: Any) -> tuple[Any, Any, Any]: + return c[0], c[1], c[2] def test_colors(self) -> None: # with single color, we don't want rgb array @@ -1743,10 +1772,16 @@ def test_colors_np_levels(self) -> None: # https://github.com/pydata/xarray/issues/3284 levels = np.array([-0.5, 0.0, 0.5, 1.0]) artist = self.darray.plot.contour(levels=levels, colors=["k", "r", "w", "b"]) - assert self._color_as_tuple(artist.cmap.colors[1]) == (1.0, 0.0, 0.0) - assert self._color_as_tuple(artist.cmap.colors[2]) == (1.0, 1.0, 1.0) + cmap = artist.cmap + assert isinstance(cmap, mpl.colors.ListedColormap) + colors = cmap.colors + assert isinstance(colors, list) + + assert self._color_as_tuple(colors[1]) == (1.0, 0.0, 0.0) + assert self._color_as_tuple(colors[2]) == (1.0, 1.0, 1.0) # the last color is now under "over" - assert self._color_as_tuple(artist.cmap._rgba_over) == (0.0, 0.0, 1.0) + assert hasattr(cmap, "_rgba_over") + assert self._color_as_tuple(cmap._rgba_over) == (0.0, 0.0, 1.0) def test_cmap_and_color_both(self) -> None: with pytest.raises(ValueError): @@ -1798,7 +1833,9 @@ def test_dont_infer_interval_breaks_for_cartopy(self) -> None: artist = self.plotmethod(x="x2d", y="y2d", ax=ax) assert isinstance(artist, mpl.collections.QuadMesh) # Let cartopy handle the axis limits and artist size - assert artist.get_array().size <= self.darray.size + arr = artist.get_array() + assert arr is not None + assert arr.size <= self.darray.size class TestPcolormeshLogscale(PlotTestCase): @@ -1949,6 +1986,7 @@ def test_normalize_rgb_imshow( ) -> None: da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) arr = da.plot.imshow(vmin=vmin, vmax=vmax, robust=robust).get_array() + assert arr is not None assert 0 <= arr.min() <= arr.max() <= 1 def test_normalize_rgb_one_arg_error(self) -> None: @@ -1965,7 +2003,10 @@ def test_imshow_rgb_values_in_valid_range(self) -> None: da = DataArray(np.arange(75, dtype="uint8").reshape((5, 5, 3))) _, ax = plt.subplots() out = da.plot.imshow(ax=ax).get_array() - assert out.dtype == np.uint8 + assert out is not None + dtype = out.dtype + assert dtype is not None + assert dtype == np.uint8 assert (out[..., :3] == da.values).all() # Compare without added alpha @pytest.mark.filterwarnings("ignore:Several dimensions of this array") @@ -2000,6 +2041,7 @@ def test_2d_coord_names(self) -> None: self.plotmethod(x="x2d", y="y2d") # make sure labels came out ok ax = plt.gca() + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) assert "x2d" == ax.get_xlabel() assert "y2d" == ax.get_ylabel() assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() @@ -2122,6 +2164,7 @@ def test_colorbar(self) -> None: self.g.map_dataarray(xplt.imshow, "x", "y") for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) clim = np.array(image.get_clim()) assert np.allclose(expected, clim) @@ -2132,7 +2175,9 @@ def test_colorbar_scatter(self) -> None: fg: xplt.FacetGrid = ds.plot.scatter(x="a", y="a", row="x", hue="a") cbar = fg.cbar assert cbar is not None + assert hasattr(cbar, "vmin") assert cbar.vmin == 0 + assert hasattr(cbar, "vmax") assert cbar.vmax == 3 @pytest.mark.slow @@ -2199,6 +2244,7 @@ def test_can_set_vmin_vmax(self) -> None: self.g.map_dataarray(xplt.imshow, "x", "y", vmin=vmin, vmax=vmax) for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) clim = np.array(image.get_clim()) assert np.allclose(expected, clim) @@ -2215,6 +2261,7 @@ def test_can_set_norm(self) -> None: norm = mpl.colors.SymLogNorm(0.1) self.g.map_dataarray(xplt.imshow, "x", "y", norm=norm) for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) assert image.norm is norm @pytest.mark.slow @@ -2721,7 +2768,7 @@ def test_datetime_hue(self) -> None: def test_facetgrid_hue_style(self) -> None: ds2 = self.ds.copy() - # Numbers plots as continous: + # Numbers plots as continuous: g = ds2.plot.scatter(x="A", y="B", row="row", col="col", hue="hue") assert isinstance(g._mappables[-1], mpl.collections.PathCollection) @@ -2752,15 +2799,20 @@ def test_non_numeric_legend(self) -> None: ds2 = self.ds.copy() ds2["hue"] = ["a", "b", "c", "d"] pc = ds2.plot.scatter(x="A", y="B", markersize="hue") + axes = pc.axes + assert axes is not None # should make a discrete legend - assert pc.axes.legend_ is not None + assert hasattr(axes, "legend_") + assert axes.legend_ is not None def test_legend_labels(self) -> None: # regression test for #4126: incorrect legend labels ds2 = self.ds.copy() ds2["hue"] = ["a", "a", "b", "b"] pc = ds2.plot.scatter(x="A", y="B", markersize="hue") - actual = [t.get_text() for t in pc.axes.get_legend().texts] + axes = pc.axes + assert axes is not None + actual = [t.get_text() for t in axes.get_legend().texts] expected = ["hue", "a", "b"] assert actual == expected @@ -2781,7 +2833,9 @@ def test_legend_labels_facetgrid(self) -> None: def test_add_legend_by_default(self) -> None: sc = self.ds.plot.scatter(x="A", y="B", hue="hue") - assert len(sc.figure.axes) == 2 + fig = sc.figure + assert fig is not None + assert len(fig.axes) == 2 class TestDatetimePlot(PlotTestCase): @@ -2834,6 +2888,7 @@ def test_datetime_plot2d(self) -> None: p = da.plot.pcolormesh() ax = p.axes + assert ax is not None # Make sure only mpl converters are used, use type() so only # mpl.dates.AutoDateLocator passes and no other subclasses: diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index ddc193712ae..3b213db0b88 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd import pytest -from packaging.version import Version import xarray as xr from xarray import DataArray, Dataset, set_options @@ -78,6 +77,12 @@ def test_rolling_properties(self, da) -> None: with pytest.raises(ValueError, match="min_periods must be greater than zero"): da.rolling(time=2, min_periods=0) + with pytest.raises( + KeyError, + match=r"\('foo',\) not found in DataArray dimensions", + ): + da.rolling(foo=2) + @pytest.mark.parametrize("name", ("sum", "mean", "std", "min", "max", "median")) @pytest.mark.parametrize("center", (True, False, None)) @pytest.mark.parametrize("min_periods", (1, None)) @@ -170,7 +175,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None: @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("window", (1, 2, 3, 4)) - def test_rolling_construct(self, center, window) -> None: + def test_rolling_construct(self, center: bool, window: int) -> None: s = pd.Series(np.arange(10)) da = DataArray.from_series(s) @@ -389,16 +394,8 @@ class TestDataArrayRollingExp: [["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]], ) @pytest.mark.parametrize("backend", ["numpy"], indirect=True) - @pytest.mark.parametrize("func", ["mean", "sum"]) + @pytest.mark.parametrize("func", ["mean", "sum", "var", "std"]) def test_rolling_exp_runs(self, da, dim, window_type, window, func) -> None: - import numbagg - - if ( - Version(getattr(numbagg, "__version__", "0.1.0")) < Version("0.2.1") - and func == "sum" - ): - pytest.skip("rolling_exp.sum requires numbagg 0.2.1") - da = da.where(da > 0.2) rolling_exp = da.rolling_exp(window_type=window_type, **{dim: window}) @@ -430,14 +427,6 @@ def test_rolling_exp_mean_pandas(self, da, dim, window_type, window) -> None: @pytest.mark.parametrize("backend", ["numpy"], indirect=True) @pytest.mark.parametrize("func", ["mean", "sum"]) def test_rolling_exp_keep_attrs(self, da, func) -> None: - import numbagg - - if ( - Version(getattr(numbagg, "__version__", "0.1.0")) < Version("0.2.1") - and func == "sum" - ): - pytest.skip("rolling_exp.sum requires numbagg 0.2.1") - attrs = {"attrs": "da"} da.attrs = attrs @@ -557,6 +546,11 @@ def test_rolling_properties(self, ds) -> None: ds.rolling(time=2, min_periods=0) with pytest.raises(KeyError, match="time2"): ds.rolling(time2=2) + with pytest.raises( + KeyError, + match=r"\('foo',\) not found in Dataset dimensions", + ): + ds.rolling(foo=2) @pytest.mark.parametrize( "name", ("sum", "mean", "std", "var", "min", "max", "median") @@ -616,7 +610,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None: @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("window", (1, 2, 3, 4)) - def test_rolling_construct(self, center, window) -> None: + def test_rolling_construct(self, center: bool, window: int) -> None: df = pd.DataFrame( { "x": np.random.randn(20), @@ -633,12 +627,6 @@ def test_rolling_construct(self, center, window) -> None: np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values) np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"]) - # with stride - ds_rolling_mean = ds_rolling.construct("window", stride=2).mean("window") - np.testing.assert_allclose( - df_rolling["x"][::2].values, ds_rolling_mean["x"].values - ) - np.testing.assert_allclose(df_rolling.index[::2], ds_rolling_mean["index"]) # with fill_value ds_rolling_mean = ds_rolling.construct("window", stride=2, fill_value=0.0).mean( "window" @@ -646,6 +634,51 @@ def test_rolling_construct(self, center, window) -> None: assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all() assert (ds_rolling_mean["x"] == 0.0).sum() >= 0 + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + def test_rolling_construct_stride(self, center: bool, window: int) -> None: + df = pd.DataFrame( + { + "x": np.random.randn(20), + "y": np.random.randn(20), + "time": np.linspace(0, 1, 20), + } + ) + ds = Dataset.from_dataframe(df) + df_rolling_mean = df.rolling(window, center=center, min_periods=1).mean() + + # With an index (dimension coordinate) + ds_rolling = ds.rolling(index=window, center=center) + ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w") + np.testing.assert_allclose( + df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values + ) + np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"]) + + # Without index (https://github.com/pydata/xarray/issues/7021) + ds2 = ds.drop_vars("index") + ds2_rolling = ds2.rolling(index=window, center=center) + ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w") + np.testing.assert_allclose( + df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values + ) + + # Mixed coordinates, indexes and 2D coordinates + ds3 = xr.Dataset( + {"x": ("t", range(20)), "x2": ("y", range(5))}, + { + "t": range(20), + "y": ("y", range(5)), + "t2": ("t", range(20)), + "y2": ("y", range(5)), + "yt": (["t", "y"], np.ones((20, 5))), + }, + ) + ds3_rolling = ds3.rolling(t=window, center=center) + ds3_rolling_mean = ds3_rolling.construct("w", stride=2).mean("w") + for coord in ds3.coords: + assert coord in ds3_rolling_mean.coords + @pytest.mark.slow @pytest.mark.parametrize("ds", (1, 2), indirect=True) @pytest.mark.parametrize("center", (True, False)) @@ -733,9 +766,7 @@ def test_ndrolling_construct(self, center, fill_value, dask) -> None: ) assert_allclose(actual, expected) - @pytest.mark.xfail( - reason="See https://github.com/pydata/xarray/pull/4369 or docstring" - ) + @requires_dask @pytest.mark.filterwarnings("error") @pytest.mark.parametrize("ds", (2,), indirect=True) @pytest.mark.parametrize("name", ("mean", "max")) @@ -757,7 +788,9 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None: @requires_numbagg class TestDatasetRollingExp: - @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + @pytest.mark.parametrize( + "backend", ["numpy", pytest.param("dask", marks=requires_dask)], indirect=True + ) def test_rolling_exp(self, ds) -> None: result = ds.rolling_exp(time=10, window_type="span").mean() assert isinstance(result, Dataset) @@ -778,13 +811,18 @@ def test_rolling_exp_keep_attrs(self, ds) -> None: # discard attrs result = ds.rolling_exp(time=10).mean(keep_attrs=False) assert result.attrs == {} - assert result.z1.attrs == {} + # TODO: from #8114 — this arguably should be empty, but `apply_ufunc` doesn't do + # that at the moment. We should change in `apply_func` rather than + # special-case it here. + # + # assert result.z1.attrs == {} # test discard attrs using global option with set_options(keep_attrs=False): result = ds.rolling_exp(time=10).mean() assert result.attrs == {} - assert result.z1.attrs == {} + # See above + # assert result.z1.attrs == {} # keyword takes precedence over global option with set_options(keep_attrs=False): @@ -795,7 +833,8 @@ def test_rolling_exp_keep_attrs(self, ds) -> None: with set_options(keep_attrs=True): result = ds.rolling_exp(time=10).mean(keep_attrs=False) assert result.attrs == {} - assert result.z1.attrs == {} + # See above + # assert result.z1.attrs == {} with pytest.warns( UserWarning, diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py index f64ce9338d7..489836b70fd 100644 --- a/xarray/tests/test_sparse.py +++ b/xarray/tests/test_sparse.py @@ -147,7 +147,6 @@ def test_variable_property(prop): ], ), True, - marks=xfail(reason="Coercion to dense"), ), param( do("conjugate"), @@ -201,7 +200,6 @@ def test_variable_property(prop): param( do("reduce", func="sum", dim="x"), True, - marks=xfail(reason="Coercion to dense"), ), param( do("rolling_window", dim="x", window=2, window_dim="x_win"), @@ -218,7 +216,7 @@ def test_variable_property(prop): param( do("var"), False, marks=xfail(reason="Missing implementation for np.nanvar") ), - param(do("to_dict"), False, marks=xfail(reason="Coercion to dense")), + param(do("to_dict"), False), (do("where", cond=make_xrvar({"x": 10, "y": 5}) > 0.5), True), ], ids=repr, @@ -237,7 +235,14 @@ def test_variable_method(func, sparse_output): assert isinstance(ret_s.data, sparse.SparseArray) assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True) else: - assert np.allclose(ret_s, ret_d, equal_nan=True) + if func.meth != "to_dict": + assert np.allclose(ret_s, ret_d) + else: + # pop the arrays from the dict + arr_s, arr_d = ret_s.pop("data"), ret_d.pop("data") + + assert np.allclose(arr_s, arr_d) + assert ret_s == ret_d @pytest.mark.parametrize( diff --git a/xarray/tests/test_typed_ops.py b/xarray/tests/test_typed_ops.py new file mode 100644 index 00000000000..1d4ef89ae29 --- /dev/null +++ b/xarray/tests/test_typed_ops.py @@ -0,0 +1,246 @@ +import numpy as np + +from xarray import DataArray, Dataset, Variable + + +def test_variable_typed_ops() -> None: + """Tests for type checking of typed_ops on Variable""" + + var = Variable(dims=["t"], data=[1, 2, 3]) + + def _test(var: Variable) -> None: + # mypy checks the input type + assert isinstance(var, Variable) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + + # __add__ as an example of binary ops + _test(var + _int) + _test(var + _list) + _test(var + _ndarray) + _test(var + var) + + # __radd__ as an example of reflexive binary ops + _test(_int + var) + _test(_list + var) + _test(_ndarray + var) # type: ignore[arg-type] # numpy problem + + # __eq__ as an example of cmp ops + _test(var == _int) + _test(var == _list) + _test(var == _ndarray) + _test(_int == var) # type: ignore[arg-type] # typeshed problem + _test(_list == var) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == var) + + # __lt__ as another example of cmp ops + _test(var < _int) + _test(var < _list) + _test(var < _ndarray) + _test(_int > var) + _test(_list > var) + _test(_ndarray > var) # type: ignore[arg-type] # numpy problem + + # __iadd__ as an example of inplace binary ops + var += _int + var += _list + var += _ndarray + + # __neg__ as an example of unary ops + _test(-var) + + +def test_dataarray_typed_ops() -> None: + """Tests for type checking of typed_ops on DataArray""" + + da = DataArray([1, 2, 3], dims=["t"]) + + def _test(da: DataArray) -> None: + # mypy checks the input type + assert isinstance(da, DataArray) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + _var = Variable(dims=["t"], data=[1, 2, 3]) + + # __add__ as an example of binary ops + _test(da + _int) + _test(da + _list) + _test(da + _ndarray) + _test(da + _var) + _test(da + da) + + # __radd__ as an example of reflexive binary ops + _test(_int + da) + _test(_list + da) + _test(_ndarray + da) # type: ignore[arg-type] # numpy problem + _test(_var + da) + + # __eq__ as an example of cmp ops + _test(da == _int) + _test(da == _list) + _test(da == _ndarray) + _test(da == _var) + _test(_int == da) # type: ignore[arg-type] # typeshed problem + _test(_list == da) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == da) + _test(_var == da) + + # __lt__ as another example of cmp ops + _test(da < _int) + _test(da < _list) + _test(da < _ndarray) + _test(da < _var) + _test(_int > da) + _test(_list > da) + _test(_ndarray > da) # type: ignore[arg-type] # numpy problem + _test(_var > da) + + # __iadd__ as an example of inplace binary ops + da += _int + da += _list + da += _ndarray + da += _var + + # __neg__ as an example of unary ops + _test(-da) + + +def test_dataset_typed_ops() -> None: + """Tests for type checking of typed_ops on Dataset""" + + ds = Dataset({"a": ("t", [1, 2, 3])}) + + def _test(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + _var = Variable(dims=["t"], data=[1, 2, 3]) + _da = DataArray([1, 2, 3], dims=["t"]) + + # __add__ as an example of binary ops + _test(ds + _int) + _test(ds + _list) + _test(ds + _ndarray) + _test(ds + _var) + _test(ds + _da) + _test(ds + ds) + + # __radd__ as an example of reflexive binary ops + _test(_int + ds) + _test(_list + ds) + _test(_ndarray + ds) + _test(_var + ds) + _test(_da + ds) + + # __eq__ as an example of cmp ops + _test(ds == _int) + _test(ds == _list) + _test(ds == _ndarray) + _test(ds == _var) + _test(ds == _da) + _test(_int == ds) # type: ignore[arg-type] # typeshed problem + _test(_list == ds) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == ds) + _test(_var == ds) + _test(_da == ds) + + # __lt__ as another example of cmp ops + _test(ds < _int) + _test(ds < _list) + _test(ds < _ndarray) + _test(ds < _var) + _test(ds < _da) + _test(_int > ds) + _test(_list > ds) + _test(_ndarray > ds) # type: ignore[arg-type] # numpy problem + _test(_var > ds) + _test(_da > ds) + + # __iadd__ as an example of inplace binary ops + ds += _int + ds += _list + ds += _ndarray + ds += _var + ds += _da + + # __neg__ as an example of unary ops + _test(-ds) + + +def test_dataarray_groupy_typed_ops() -> None: + """Tests for type checking of typed_ops on DataArrayGroupBy""" + + da = DataArray([1, 2, 3], coords={"x": ("t", [1, 2, 2])}, dims=["t"]) + grp = da.groupby("x") + + def _testda(da: DataArray) -> None: + # mypy checks the input type + assert isinstance(da, DataArray) + + def _testds(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x") + _ds = _da.to_dataset(name="a") + + # __add__ as an example of binary ops + _testda(grp + _da) + _testds(grp + _ds) + + # __radd__ as an example of reflexive binary ops + _testda(_da + grp) + _testds(_ds + grp) + + # __eq__ as an example of cmp ops + _testda(grp == _da) + _testda(_da == grp) + _testds(grp == _ds) + _testds(_ds == grp) + + # __lt__ as another example of cmp ops + _testda(grp < _da) + _testda(_da > grp) + _testds(grp < _ds) + _testds(_ds > grp) + + +def test_dataset_groupy_typed_ops() -> None: + """Tests for type checking of typed_ops on DatasetGroupBy""" + + ds = Dataset({"a": ("t", [1, 2, 3])}, coords={"x": ("t", [1, 2, 2])}) + grp = ds.groupby("x") + + def _test(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x") + _ds = _da.to_dataset(name="a") + + # __add__ as an example of binary ops + _test(grp + _da) + _test(grp + _ds) + + # __radd__ as an example of reflexive binary ops + _test(_da + grp) + _test(_ds + grp) + + # __eq__ as an example of cmp ops + _test(grp == _da) + _test(_da == grp) + _test(grp == _ds) + _test(_ds == grp) + + # __lt__ as another example of cmp ops + _test(grp < _da) + _test(_da > grp) + _test(grp < _ds) + _test(_ds > grp) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index addd7587544..14a7a10f734 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -18,6 +18,7 @@ assert_identical, requires_dask, requires_matplotlib, + requires_numbagg, ) from xarray.tests.test_plot import PlotTestCase from xarray.tests.test_variable import _PAD_XR_NP_ARGS @@ -304,11 +305,13 @@ def __call__(self, obj, *args, **kwargs): all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} + from xarray.core.groupby import GroupBy + xarray_classes = ( xr.Variable, xr.DataArray, xr.Dataset, - xr.core.groupby.GroupBy, + GroupBy, ) if not isinstance(obj, xarray_classes): @@ -1499,10 +1502,11 @@ def test_dot_dataarray(dtype): data_array = xr.DataArray(data=array1, dims=("x", "y")) other = xr.DataArray(data=array2, dims=("y", "z")) - expected = attach_units( - xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m} - ) - actual = xr.dot(data_array, other) + with xr.set_options(use_opt_einsum=False): + expected = attach_units( + xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m} + ) + actual = xr.dot(data_array, other) assert_units_equal(expected, actual) assert_identical(expected, actual) @@ -2462,8 +2466,9 @@ def test_binary_operations(self, func, dtype): data_array = xr.DataArray(data=array) units = extract_units(func(array)) - expected = attach_units(func(strip_units(data_array)), units) - actual = func(data_array) + with xr.set_options(use_opt_einsum=False): + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) assert_units_equal(expected, actual) assert_identical(expected, actual) @@ -2548,7 +2553,6 @@ def test_univariate_ufunc(self, units, error, dtype): assert_units_equal(expected, actual) assert_identical(expected, actual) - @pytest.mark.xfail(reason="needs the type register system for __array_ufunc__") @pytest.mark.parametrize( "unit,error", ( @@ -3827,8 +3831,9 @@ def test_computation(self, func, variant, dtype): if not isinstance(func, (function, method)): units.update(extract_units(func(array.reshape(-1)))) - expected = attach_units(func(strip_units(data_array)), units) - actual = func(data_array) + with xr.set_options(use_opt_einsum=False): + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) assert_units_equal(expected, actual) assert_identical(expected, actual) @@ -3849,23 +3854,21 @@ def test_computation(self, func, variant, dtype): method("groupby", "x"), method("groupby_bins", "y", bins=4), method("coarsen", y=2), - pytest.param( - method("rolling", y=3), - marks=pytest.mark.xfail( - reason="numpy.lib.stride_tricks.as_strided converts to ndarray" - ), - ), - pytest.param( - method("rolling_exp", y=3), - marks=pytest.mark.xfail( - reason="numbagg functions are not supported by pint" - ), - ), + method("rolling", y=3), + pytest.param(method("rolling_exp", y=3), marks=requires_numbagg), method("weighted", xr.DataArray(data=np.linspace(0, 1, 10), dims="y")), ), ids=repr, ) def test_computation_objects(self, func, variant, dtype): + if variant == "data": + if func.name == "rolling_exp": + pytest.xfail(reason="numbagg functions are not supported by pint") + elif func.name == "rolling": + pytest.xfail( + reason="numpy.lib.stride_tricks.as_strided converts to ndarray" + ) + unit = unit_registry.m variants = { diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 9b70dcb5464..8a73e435977 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -5,14 +5,14 @@ from copy import copy, deepcopy from datetime import datetime, timedelta from textwrap import dedent +from typing import Generic import numpy as np import pandas as pd import pytest import pytz -from packaging.version import Version -from xarray import Coordinate, DataArray, Dataset, IndexVariable, Variable, set_options +from xarray import DataArray, Dataset, IndexVariable, Variable, set_options from xarray.core import dtypes, duck_array_ops, indexing from xarray.core.common import full_like, ones_like, zeros_like from xarray.core.indexing import ( @@ -27,6 +27,7 @@ VectorizedIndexer, ) from xarray.core.pycompat import array_type +from xarray.core.types import T_DuckArray from xarray.core.utils import NDArrayMixin from xarray.core.variable import as_compatible_data, as_variable from xarray.tests import ( @@ -197,7 +198,7 @@ def test_index_0d_int(self): self._assertIndexedLikeNDArray(x, value, dtype) def test_index_0d_float(self): - for value, dtype in [(0.5, np.float_), (np.float32(0.5), np.float32)]: + for value, dtype in [(0.5, float), (np.float32(0.5), np.float32)]: x = self.cls(["x"], [value]) self._assertIndexedLikeNDArray(x, value, dtype) @@ -339,10 +340,10 @@ def test_pandas_data(self): assert v[0].values == v.values[0] def test_pandas_period_index(self): - v = self.cls(["x"], pd.period_range(start="2000", periods=20, freq="B")) + v = self.cls(["x"], pd.period_range(start="2000", periods=20, freq="D")) v = v.load() # for dask-based Variable - assert v[0] == pd.Period("2000", freq="B") - assert "Period('2000-01-03', 'B')" in repr(v) + assert v[0] == pd.Period("2000", freq="D") + assert "Period('2000-01-01', 'D')" in repr(v) @pytest.mark.parametrize("dtype", [float, int]) def test_1d_math(self, dtype: np.typing.DTypeLike) -> None: @@ -472,12 +473,12 @@ def test_encoding_preserved(self): assert_identical(expected.to_base_variable(), actual.to_base_variable()) assert expected.encoding == actual.encoding - def test_reset_encoding(self) -> None: + def test_drop_encoding(self) -> None: encoding1 = {"scale_factor": 1} # encoding set via cls constructor v1 = self.cls(["a"], [0, 1, 2], encoding=encoding1) assert v1.encoding == encoding1 - v2 = v1.reset_encoding() + v2 = v1.drop_encoding() assert v1.encoding == encoding1 assert v2.encoding == {} @@ -485,7 +486,7 @@ def test_reset_encoding(self) -> None: encoding3 = {"scale_factor": 10} v3 = self.cls(["a"], [0, 1, 2], encoding=encoding3) assert v3.encoding == encoding3 - v4 = v3.reset_encoding() + v4 = v3.drop_encoding() assert v3.encoding == encoding3 assert v4.encoding == {} @@ -616,7 +617,7 @@ def test_copy_with_data_errors(self) -> None: orig = Variable(("x", "y"), [[1.5, 2.0], [3.1, 4.3]], {"foo": "bar"}) new_data = [2.5, 5.0] with pytest.raises(ValueError, match=r"must match shape of object"): - orig.copy(data=new_data) + orig.copy(data=new_data) # type: ignore[arg-type] def test_copy_index_with_data(self) -> None: orig = IndexVariable("x", np.arange(5)) @@ -884,20 +885,10 @@ def test_getitem_error(self): "mode", [ "mean", - pytest.param( - "median", - marks=pytest.mark.xfail(reason="median is not implemented by Dask"), - ), - pytest.param( - "reflect", marks=pytest.mark.xfail(reason="dask.array.pad bug") - ), + "median", + "reflect", "edge", - pytest.param( - "linear_ramp", - marks=pytest.mark.xfail( - reason="pint bug: https://github.com/hgrecco/pint/issues/1026" - ), - ), + "linear_ramp", "maximum", "minimum", "symmetric", @@ -925,7 +916,7 @@ def test_pad_constant_values(self, xr_arg, np_arg): actual = v.pad(**xr_arg) expected = np.pad( - np.array(v.data.astype(float)), + np.array(duck_array_ops.astype(v.data, float)), np_arg, mode="constant", constant_values=np.nan, @@ -1092,7 +1083,7 @@ def test_data_and_values(self): def test_numpy_same_methods(self): v = Variable([], np.float32(0.0)) assert v.item() == 0 - assert type(v.item()) is float + assert type(v.item()) is float # noqa: E721 v = IndexVariable("x", np.arange(5)) assert 2 == v.searchsorted(2) @@ -1128,9 +1119,9 @@ def test_0d_str(self): assert v.dtype == np.dtype("U3") assert v.values == "foo" - v = Variable([], np.string_("foo")) + v = Variable([], np.bytes_("foo")) assert v.dtype == np.dtype("S3") - assert v.values == bytes("foo", "ascii") + assert v.values == "foo".encode("ascii") def test_0d_datetime(self): v = Variable([], pd.Timestamp("2000-01-01")) @@ -1467,10 +1458,10 @@ def test_isel(self): def test_index_0d_numpy_string(self): # regression test to verify our work around for indexing 0d strings - v = Variable([], np.string_("asdf")) + v = Variable([], np.bytes_("asdf")) assert_identical(v[()], v) - v = Variable([], np.unicode_("asdf")) + v = Variable([], np.str_("asdf")) assert_identical(v[()], v) def test_indexing_0d_unicode(self): @@ -1708,6 +1699,15 @@ def test_stack_unstack_consistency(self): actual = v.stack(z=("x", "y")).unstack(z={"x": 2, "y": 2}) assert_identical(actual, v) + @pytest.mark.filterwarnings("error::RuntimeWarning") + def test_unstack_without_missing(self): + v = Variable(["z"], [0, 1, 2, 3]) + expected = Variable(["x", "y"], [[0, 1], [2, 3]]) + + actual = v.unstack(z={"x": 2, "y": 2}) + + assert_identical(actual, expected) + def test_broadcasting_math(self): x = np.random.randn(2, 3) v = Variable(["a", "b"], x) @@ -1802,7 +1802,7 @@ def raise_if_called(*args, **kwargs): ) def test_quantile(self, q, axis, dim, skipna): d = self.d.copy() - d[0, 0] = np.NaN + d[0, 0] = np.nan v = Variable(["x", "y"], d) actual = v.quantile(q, dim=dim, skipna=skipna) @@ -1832,10 +1832,7 @@ def test_quantile_method(self, method, use_dask) -> None: q = np.array([0.25, 0.5, 0.75]) actual = v.quantile(q, dim="y", method=method) - if Version(np.__version__) >= Version("1.22"): - expected = np.nanquantile(self.d, q, axis=1, method=method) - else: - expected = np.nanquantile(self.d, q, axis=1, interpolation=method) + expected = np.nanquantile(self.d, q, axis=1, method=method) if use_dask: assert isinstance(actual.data, dask_array_type) @@ -2338,12 +2335,35 @@ def test_dask_rolling(self, dim, window, center): assert actual.shape == expected.shape assert_equal(actual, expected) - @pytest.mark.xfail( - reason="https://github.com/pydata/xarray/issues/6209#issuecomment-1025116203" - ) def test_multiindex(self): super().test_multiindex() + @pytest.mark.parametrize( + "mode", + [ + "mean", + pytest.param( + "median", + marks=pytest.mark.xfail(reason="median is not implemented by Dask"), + ), + pytest.param( + "reflect", marks=pytest.mark.xfail(reason="dask.array.pad bug") + ), + "edge", + "linear_ramp", + "maximum", + "minimum", + "symmetric", + "wrap", + ], + ) + @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS) + @pytest.mark.filterwarnings( + r"ignore:dask.array.pad.+? converts integers to floats." + ) + def test_pad(self, mode, xr_arg, np_arg): + super().test_pad(mode, xr_arg, np_arg) + @requires_sparse class TestVariableWithSparse: @@ -2447,11 +2467,6 @@ def test_concat_str_dtype(self, dtype): assert actual.identical(expected) assert np.issubdtype(actual.dtype, dtype) - def test_coordinate_alias(self): - with pytest.warns(Warning, match="deprecated"): - x = Coordinate("x", [1, 2, 3]) - assert isinstance(x, IndexVariable) - def test_datetime64(self): # GH:1932 Make sure indexing keeps precision t = np.array([1518418799999986560, 1518418799999996560], dtype="datetime64[ns]") @@ -2529,7 +2544,7 @@ def test_to_index_variable_copy(self) -> None: assert a.dims == ("x",) -class TestAsCompatibleData: +class TestAsCompatibleData(Generic[T_DuckArray]): def test_unchanged_types(self): types = (np.asarray, PandasIndexingAdapter, LazilyIndexedArray) for t in types: @@ -2604,23 +2619,23 @@ def test_datetime(self): @requires_pandas_version_two def test_tz_datetime(self) -> None: - tz = pytz.timezone("US/Eastern") + tz = pytz.timezone("America/New_York") times_ns = pd.date_range("2000", periods=1, tz=tz) times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz)) with warnings.catch_warnings(): warnings.simplefilter("ignore") - actual = as_compatible_data(times_s) + actual: T_DuckArray = as_compatible_data(times_s) assert actual.array == times_s assert actual.array.dtype == pd.DatetimeTZDtype("ns", tz) series = pd.Series(times_s) with warnings.catch_warnings(): warnings.simplefilter("ignore") - actual = as_compatible_data(series) + actual2: T_DuckArray = as_compatible_data(series) - np.testing.assert_array_equal(actual, series.values) - assert actual.dtype == np.dtype("datetime64[ns]") + np.testing.assert_array_equal(actual2, series.values) + assert actual2.dtype == np.dtype("datetime64[ns]") def test_full_like(self) -> None: # For more thorough tests, see test_variable.py @@ -2648,7 +2663,7 @@ def test_full_like(self) -> None: def test_full_like_dask(self) -> None: orig = Variable( dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} - ).chunk(((1, 1), (2,))) + ).chunk(dict(x=(1, 1), y=(2,))) def check(actual, expect_dtype, expect_values): assert actual.dtype == expect_dtype @@ -2719,7 +2734,7 @@ def __init__(self, array): def test_raise_no_warning_for_nan_in_binary_ops(): with assert_no_warnings(): - Variable("x", [1, 2, np.NaN]) > 0 + Variable("x", [1, 2, np.nan]) > 0 class TestBackendIndexing: @@ -2904,9 +2919,11 @@ def test_from_pint_wrapping_dask(self, Var): (pd.date_range("2000", periods=1), False), (datetime(2000, 1, 1), False), (np.array([datetime(2000, 1, 1)]), False), - (pd.date_range("2000", periods=1, tz=pytz.timezone("US/Eastern")), False), + (pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")), False), ( - pd.Series(pd.date_range("2000", periods=1, tz=pytz.timezone("US/Eastern"))), + pd.Series( + pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")) + ), False, ), ], @@ -2928,8 +2945,9 @@ def test_datetime_conversion_warning(values, warns_under_pandas_version_two) -> # The only case where a non-datetime64 dtype can occur currently is in # the case that the variable is backed by a timezone-aware # DatetimeIndex, and thus is hidden within the PandasIndexingAdapter class. + assert isinstance(var._data, PandasIndexingAdapter) assert var._data.array.dtype == pd.DatetimeTZDtype( - "ns", pytz.timezone("US/Eastern") + "ns", pytz.timezone("America/New_York") ) @@ -2941,12 +2959,14 @@ def test_pandas_two_only_datetime_conversion_warnings() -> None: (pd.date_range("2000", periods=1), "datetime64[s]"), (pd.Series(pd.date_range("2000", periods=1)), "datetime64[s]"), ( - pd.date_range("2000", periods=1, tz=pytz.timezone("US/Eastern")), - pd.DatetimeTZDtype("s", pytz.timezone("US/Eastern")), + pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")), + pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")), ), ( - pd.Series(pd.date_range("2000", periods=1, tz=pytz.timezone("US/Eastern"))), - pd.DatetimeTZDtype("s", pytz.timezone("US/Eastern")), + pd.Series( + pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")) + ), + pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")), ), ] for data, dtype in cases: @@ -2959,8 +2979,9 @@ def test_pandas_two_only_datetime_conversion_warnings() -> None: # The only case where a non-datetime64 dtype can occur currently is in # the case that the variable is backed by a timezone-aware # DatetimeIndex, and thus is hidden within the PandasIndexingAdapter class. + assert isinstance(var._data, PandasIndexingAdapter) assert var._data.array.dtype == pd.DatetimeTZDtype( - "ns", pytz.timezone("US/Eastern") + "ns", pytz.timezone("America/New_York") ) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index e2530d41fbe..95fda3fac62 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -608,7 +608,7 @@ def test_weighted_operations_3D(dim, add_nans, skipna): # add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700) if add_nans: c = int(data.size * 0.25) - data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan data = DataArray(data, dims=dims, coords=coords) @@ -631,7 +631,7 @@ def test_weighted_quantile_3D(dim, q, add_nans, skipna): # add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700) if add_nans: c = int(data.size * 0.25) - data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan da = DataArray(data, dims=dims, coords=coords) @@ -709,7 +709,7 @@ def test_weighted_operations_different_shapes( # add approximately 25 % NaNs if add_nans: c = int(data.size * 0.25) - data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan data = DataArray(data) @@ -782,9 +782,12 @@ def test_weighted_bad_dim(operation, as_dataset): if operation == "quantile": kwargs["q"] = 0.5 - error_msg = ( - f"{data.__class__.__name__}Weighted" - " does not contain the dimensions: {'bad_dim'}" - ) - with pytest.raises(ValueError, match=error_msg): + with pytest.raises( + ValueError, + match=( + f"Dimensions \\('bad_dim',\\) not found in {data.__class__.__name__}Weighted " + # the order of (dim_0, dim_1) varies + "dimensions \\(('dim_0', 'dim_1'|'dim_1', 'dim_0')\\)" + ), + ): getattr(data.weighted(weights), operation)(**kwargs) diff --git a/xarray/util/deprecation_helpers.py b/xarray/util/deprecation_helpers.py index e9681bdf398..7b4cf901aa1 100644 --- a/xarray/util/deprecation_helpers.py +++ b/xarray/util/deprecation_helpers.py @@ -34,6 +34,9 @@ import inspect import warnings from functools import wraps +from typing import Callable, TypeVar + +T = TypeVar("T", bound=Callable) POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY @@ -41,7 +44,7 @@ EMPTY = inspect.Parameter.empty -def _deprecate_positional_args(version): +def _deprecate_positional_args(version) -> Callable[[T], T]: """Decorator for methods that issues warnings for positional arguments Using the keyword-only argument syntax in pep 3102, arguments after the diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index 6c50474030d..a1233ea0291 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -14,7 +14,7 @@ """ import collections import textwrap -from dataclasses import dataclass +from dataclasses import dataclass, field MODULE_PREAMBLE = '''\ """Mixin classes with reduction operations.""" @@ -27,16 +27,30 @@ from xarray.core import duck_array_ops from xarray.core.options import OPTIONS -from xarray.core.types import Dims +from xarray.core.types import Dims, Self from xarray.core.utils import contains_only_chunked_or_numpy, module_available if TYPE_CHECKING: from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset -flox_available = module_available("flox")''' +flox_available = module_available("flox") +''' -DEFAULT_PREAMBLE = """ +NAMED_ARRAY_MODULE_PREAMBLE = '''\ +"""Mixin classes with reduction operations.""" +# This file was generated using xarray.util.generate_aggregations. Do not edit manually. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, Callable + +from xarray.core import duck_array_ops +from xarray.core.types import Dims, Self +''' + +AGGREGATIONS_PREAMBLE = """ class {obj}{cls}Aggregations: __slots__ = () @@ -50,9 +64,26 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> {obj}: + ) -> Self: + raise NotImplementedError()""" + +NAMED_ARRAY_AGGREGATIONS_PREAMBLE = """ + +class {obj}{cls}Aggregations: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: raise NotImplementedError()""" + GROUPBY_PREAMBLE = """ class {obj}{cls}Aggregations: @@ -104,11 +135,9 @@ def _flox_reduce( TEMPLATE_REDUCTION_SIGNATURE = ''' def {method}( self, - dim: Dims = None, - *,{extra_kwargs} - keep_attrs: bool | None = None, + dim: Dims = None,{kw_only}{extra_kwargs}{keep_attrs} **kwargs: Any, - ) -> {obj}: + ) -> Self: """ Reduce this {obj}'s data by applying ``{method}`` along some dimension(s). @@ -139,9 +168,7 @@ def {method}( TEMPLATE_SEE_ALSO = """ See Also -------- - numpy.{method} - dask.array.{method} - {see_also_obj}.{method} +{see_also_methods} :ref:`{docref}` User guide on {docref_description}.""" @@ -229,6 +256,15 @@ def {method}( ) +@dataclass +class DataStructure: + name: str + create_example: str + example_var_name: str + numeric_only: bool = False + see_also_modules: tuple[str] = tuple + + class Method: def __init__( self, @@ -236,11 +272,12 @@ def __init__( bool_reduce=False, extra_kwargs=tuple(), numeric_only=False, + see_also_modules=("numpy", "dask.array"), ): self.name = name self.extra_kwargs = extra_kwargs self.numeric_only = numeric_only - + self.see_also_modules = see_also_modules if bool_reduce: self.array_method = f"array_{name}" self.np_example_array = """ @@ -249,37 +286,29 @@ def __init__( else: self.array_method = name self.np_example_array = """ - ... np.array([1, 2, 3, 1, 2, np.nan])""" + ... np.array([1, 2, 3, 0, 2, np.nan])""" +@dataclass class AggregationGenerator: _dim_docstring = _DIM_DOCSTRING _template_signature = TEMPLATE_REDUCTION_SIGNATURE - def __init__( - self, - cls, - datastructure, - methods, - docref, - docref_description, - example_call_preamble, - definition_preamble, - see_also_obj=None, - notes=None, - ): - self.datastructure = datastructure - self.cls = cls - self.methods = methods - self.docref = docref - self.docref_description = docref_description - self.example_call_preamble = example_call_preamble - self.preamble = definition_preamble.format(obj=datastructure.name, cls=cls) - self.notes = "" if notes is None else notes - if not see_also_obj: - self.see_also_obj = self.datastructure.name - else: - self.see_also_obj = see_also_obj + cls: str + datastructure: DataStructure + methods: tuple[Method, ...] + docref: str + docref_description: str + example_call_preamble: str + definition_preamble: str + has_keep_attrs: bool = True + notes: str = "" + preamble: str = field(init=False) + + def __post_init__(self): + self.preamble = self.definition_preamble.format( + obj=self.datastructure.name, cls=self.cls + ) def generate_methods(self): yield [self.preamble] @@ -287,7 +316,16 @@ def generate_methods(self): yield self.generate_method(method) def generate_method(self, method): - template_kwargs = dict(obj=self.datastructure.name, method=method.name) + has_kw_only = method.extra_kwargs or self.has_keep_attrs + + template_kwargs = dict( + obj=self.datastructure.name, + method=method.name, + keep_attrs="\n keep_attrs: bool | None = None," + if self.has_keep_attrs + else "", + kw_only="\n *," if has_kw_only else "", + ) if method.extra_kwargs: extra_kwargs = "\n " + "\n ".join( @@ -304,7 +342,7 @@ def generate_method(self, method): for text in [ self._dim_docstring.format(method=method.name, cls=self.cls), *(kwarg.docs for kwarg in method.extra_kwargs if kwarg.docs), - _KEEP_ATTRS_DOCSTRING, + _KEEP_ATTRS_DOCSTRING if self.has_keep_attrs else None, _KWARGS_DOCSTRING.format(method=method.name), ]: if text: @@ -312,11 +350,24 @@ def generate_method(self, method): yield TEMPLATE_RETURNS.format(**template_kwargs) + # we want Datset.count to refer to DataArray.count + # but we also want DatasetGroupBy.count to refer to Dataset.count + # The generic aggregations have self.cls == '' + others = ( + self.datastructure.see_also_modules + if self.cls == "" + else (self.datastructure.name,) + ) + see_also_methods = "\n".join( + " " * 8 + f"{mod}.{method.name}" + for mod in (method.see_also_modules + others) + ) + # Fixes broken links mentioned in #8055 yield TEMPLATE_SEE_ALSO.format( **template_kwargs, docref=self.docref, docref_description=self.docref_description, - see_also_obj=self.see_also_obj, + see_also_methods=see_also_methods, ) notes = self.notes @@ -331,18 +382,12 @@ def generate_method(self, method): yield textwrap.indent(self.generate_example(method=method), "") yield ' """' - yield self.generate_code(method) + yield self.generate_code(method, self.has_keep_attrs) def generate_example(self, method): - create_da = f""" - >>> da = xr.DataArray({method.np_example_array}, - ... dims="time", - ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), - ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), - ... ), - ... )""" - + created = self.datastructure.create_example.format( + example_array=method.np_example_array + ) calculation = f"{self.datastructure.example_var_name}{self.example_call_preamble}.{method.name}" if method.extra_kwargs: extra_examples = "".join( @@ -353,7 +398,8 @@ def generate_example(self, method): return f""" Examples - --------{create_da}{self.datastructure.docstring_create} + --------{created} + >>> {self.datastructure.example_var_name} >>> {calculation}(){extra_examples}""" @@ -362,7 +408,7 @@ class GroupByAggregationGenerator(AggregationGenerator): _dim_docstring = _DIM_DOCSTRING_GROUPBY _template_signature = TEMPLATE_REDUCTION_SIGNATURE_GROUPBY - def generate_code(self, method): + def generate_code(self, method, has_keep_attrs): extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] if self.datastructure.numeric_only: @@ -414,7 +460,7 @@ def generate_code(self, method): class GenericAggregationGenerator(AggregationGenerator): - def generate_code(self, method): + def generate_code(self, method, has_keep_attrs): extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] if self.datastructure.numeric_only: @@ -424,18 +470,20 @@ def generate_code(self, method): extra_kwargs = textwrap.indent("\n" + "\n".join(extra_kwargs), 12 * " ") else: extra_kwargs = "" + keep_attrs = ( + "\n" + 12 * " " + "keep_attrs=keep_attrs," if has_keep_attrs else "" + ) return f"""\ return self.reduce( duck_array_ops.{method.array_method}, - dim=dim,{extra_kwargs} - keep_attrs=keep_attrs, + dim=dim,{extra_kwargs}{keep_attrs} **kwargs, )""" AGGREGATION_METHODS = ( # Reductions: - Method("count"), + Method("count", see_also_modules=("pandas.DataFrame", "dask.dataframe.DataFrame")), Method("all", bool_reduce=True), Method("any", bool_reduce=True), Method("max", extra_kwargs=(skipna,)), @@ -452,28 +500,34 @@ def generate_code(self, method): ) -@dataclass -class DataStructure: - name: str - docstring_create: str - example_var_name: str - numeric_only: bool = False - - DATASET_OBJECT = DataStructure( name="Dataset", - docstring_create=""" - >>> ds = xr.Dataset(dict(da=da)) - >>> ds""", + create_example=""" + >>> da = xr.DataArray({example_array}, + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da))""", example_var_name="ds", numeric_only=True, + see_also_modules=("DataArray",), ) DATAARRAY_OBJECT = DataStructure( name="DataArray", - docstring_create=""" - >>> da""", + create_example=""" + >>> da = xr.DataArray({example_array}, + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... )""", example_var_name="da", numeric_only=False, + see_also_modules=("Dataset",), ) DATASET_GENERATOR = GenericAggregationGenerator( cls="", @@ -482,8 +536,7 @@ class DataStructure: docref="agg", docref_description="reduction or aggregation operations", example_call_preamble="", - see_also_obj="DataArray", - definition_preamble=DEFAULT_PREAMBLE, + definition_preamble=AGGREGATIONS_PREAMBLE, ) DATAARRAY_GENERATOR = GenericAggregationGenerator( cls="", @@ -492,8 +545,7 @@ class DataStructure: docref="agg", docref_description="reduction or aggregation operations", example_call_preamble="", - see_also_obj="Dataset", - definition_preamble=DEFAULT_PREAMBLE, + definition_preamble=AGGREGATIONS_PREAMBLE, ) DATAARRAY_GROUPBY_GENERATOR = GroupByAggregationGenerator( cls="GroupBy", @@ -536,24 +588,59 @@ class DataStructure: notes=_FLOX_RESAMPLE_NOTES, ) +NAMED_ARRAY_OBJECT = DataStructure( + name="NamedArray", + create_example=""" + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x",{example_array}, + ... )""", + example_var_name="na", + numeric_only=False, + see_also_modules=("Dataset", "DataArray"), +) + +NAMED_ARRAY_GENERATOR = GenericAggregationGenerator( + cls="", + datastructure=NAMED_ARRAY_OBJECT, + methods=AGGREGATION_METHODS, + docref="agg", + docref_description="reduction or aggregation operations", + example_call_preamble="", + definition_preamble=NAMED_ARRAY_AGGREGATIONS_PREAMBLE, + has_keep_attrs=False, +) + + +def write_methods(filepath, generators, preamble): + with open(filepath, mode="w", encoding="utf-8") as f: + f.write(preamble) + for gen in generators: + for lines in gen.generate_methods(): + for line in lines: + f.write(line + "\n") + if __name__ == "__main__": import os from pathlib import Path p = Path(os.getcwd()) - filepath = p.parent / "xarray" / "xarray" / "core" / "_aggregations.py" - # filepath = p.parent / "core" / "_aggregations.py" # Run from script location - with open(filepath, mode="w", encoding="utf-8") as f: - f.write(MODULE_PREAMBLE + "\n") - for gen in [ + write_methods( + filepath=p.parent / "xarray" / "xarray" / "core" / "_aggregations.py", + generators=[ DATASET_GENERATOR, DATAARRAY_GENERATOR, DATASET_GROUPBY_GENERATOR, DATASET_RESAMPLE_GENERATOR, DATAARRAY_GROUPBY_GENERATOR, DATAARRAY_RESAMPLE_GENERATOR, - ]: - for lines in gen.generate_methods(): - for line in lines: - f.write(line + "\n") + ], + preamble=MODULE_PREAMBLE, + ) + write_methods( + filepath=p.parent / "xarray" / "xarray" / "namedarray" / "_aggregations.py", + generators=[NAMED_ARRAY_GENERATOR], + preamble=NAMED_ARRAY_MODULE_PREAMBLE, + ) + # filepath = p.parent / "core" / "_aggregations.py" # Run from script location diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index cf0673e7cca..5859934f646 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -3,14 +3,16 @@ For internal xarray development use only. Usage: - python xarray/util/generate_ops.py --module > xarray/core/_typed_ops.py - python xarray/util/generate_ops.py --stubs > xarray/core/_typed_ops.pyi + python xarray/util/generate_ops.py > xarray/core/_typed_ops.py """ # Note: the comments in https://github.com/pydata/xarray/pull/4904 provide some # background to some of the design choices made here. -import sys +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from typing import Optional BINOPS_EQNE = (("__eq__", "nputils.array_eq"), ("__ne__", "nputils.array_ne")) BINOPS_CMP = ( @@ -74,155 +76,180 @@ ("conjugate", "ops.conjugate"), ) + +required_method_binary = """ + def _binary_op( + self, other: {other_type}, f: Callable, reflexive: bool = False + ) -> {return_type}: + raise NotImplementedError""" template_binop = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> {return_type}:{type_ignore} + return self._binary_op(other, {func})""" +template_binop_overload = """ + @overload{overload_type_ignore} + def {method}(self, other: {overload_type}) -> {overload_type}: + ... + + @overload + def {method}(self, other: {other_type}) -> {return_type}: + ... + + def {method}(self, other: {other_type}) -> {return_type} | {overload_type}:{type_ignore} return self._binary_op(other, {func})""" template_reflexive = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> {return_type}: return self._binary_op(other, {func}, reflexive=True)""" + +required_method_inplace = """ + def _inplace_binary_op(self, other: {other_type}, f: Callable) -> Self: + raise NotImplementedError""" template_inplace = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> Self:{type_ignore} return self._inplace_binary_op(other, {func})""" + +required_method_unary = """ + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: + raise NotImplementedError""" template_unary = """ - def {method}(self): + def {method}(self) -> Self: return self._unary_op({func})""" template_other_unary = """ - def {method}(self, *args, **kwargs): + def {method}(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op({func}, *args, **kwargs)""" -required_method_unary = """ - def _unary_op(self, f, *args, **kwargs): - raise NotImplementedError""" -required_method_binary = """ - def _binary_op(self, other, f, reflexive=False): - raise NotImplementedError""" -required_method_inplace = """ - def _inplace_binary_op(self, other, f): - raise NotImplementedError""" # For some methods we override return type `bool` defined by base class `object`. -OVERRIDE_TYPESHED = {"override": " # type: ignore[override]"} -NO_OVERRIDE = {"override": ""} - -# Note: in some of the overloads below the return value in reality is NotImplemented, -# which cannot accurately be expressed with type hints,e.g. Literal[NotImplemented] -# or type(NotImplemented) are not allowed and NoReturn has a different meaning. -# In such cases we are lending the type checkers a hand by specifying the return type -# of the corresponding reflexive method on `other` which will be called instead. -stub_ds = """\ - def {method}(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...{override}""" -stub_da = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def {method}(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...""" -stub_var = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: T_DataArray) -> T_DataArray: ... - @overload - def {method}(self: T_Variable, other: VarCompatible) -> T_Variable: ...""" -stub_dsgb = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: "DataArray") -> "Dataset": ... - @overload - def {method}(self, other: GroupByIncompatible) -> NoReturn: ...""" -stub_dagb = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: T_DataArray) -> T_DataArray: ... - @overload - def {method}(self, other: GroupByIncompatible) -> NoReturn: ...""" -stub_unary = """\ - def {method}(self: {self_type}) -> {self_type}: ...""" -stub_other_unary = """\ - def {method}(self: {self_type}, *args, **kwargs) -> {self_type}: ...""" -stub_required_unary = """\ - def _unary_op(self, f, *args, **kwargs): ...""" -stub_required_binary = """\ - def _binary_op(self, other, f, reflexive=...): ...""" -stub_required_inplace = """\ - def _inplace_binary_op(self, other, f): ...""" - - -def unops(self_type): - extra_context = {"self_type": self_type} +# We need to add "# type: ignore[override]" +# Keep an eye out for: +# https://discuss.python.org/t/make-type-hints-for-eq-of-primitives-less-strict/34240 +# The type ignores might not be necessary anymore at some point. +# +# We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray +# In reality this returns NotImplementes, but this is not a valid type in python 3.9. +# Therefore, we return DataArray. In reality this would call DataArray.__add__(Variable) +# TODO: change once python 3.10 is the minimum. +# +# Mypy seems to require that __iadd__ and __add__ have the same signature. +# This requires some extra type: ignores[misc] in the inplace methods :/ + + +def _type_ignore(ignore: str) -> str: + return f" # type:ignore[{ignore}]" if ignore else "" + + +FuncType = Sequence[tuple[Optional[str], Optional[str]]] +OpsType = tuple[FuncType, str, dict[str, str]] + + +def binops( + other_type: str, return_type: str = "Self", type_ignore_eq: str = "override" +) -> list[OpsType]: + extras = {"other_type": other_type, "return_type": return_type} + return [ + ([(None, None)], required_method_binary, extras), + (BINOPS_NUM + BINOPS_CMP, template_binop, extras | {"type_ignore": ""}), + ( + BINOPS_EQNE, + template_binop, + extras | {"type_ignore": _type_ignore(type_ignore_eq)}, + ), + (BINOPS_REFLEXIVE, template_reflexive, extras), + ] + + +def binops_overload( + other_type: str, + overload_type: str, + return_type: str = "Self", + type_ignore_eq: str = "override", +) -> list[OpsType]: + extras = {"other_type": other_type, "return_type": return_type} return [ - ([(None, None)], required_method_unary, stub_required_unary, {}), - (UNARY_OPS, template_unary, stub_unary, extra_context), - (OTHER_UNARY_METHODS, template_other_unary, stub_other_unary, extra_context), + ([(None, None)], required_method_binary, extras), + ( + BINOPS_NUM + BINOPS_CMP, + template_binop_overload, + extras + | { + "overload_type": overload_type, + "type_ignore": "", + "overload_type_ignore": "", + }, + ), + ( + BINOPS_EQNE, + template_binop_overload, + extras + | { + "overload_type": overload_type, + "type_ignore": "", + "overload_type_ignore": _type_ignore(type_ignore_eq), + }, + ), + (BINOPS_REFLEXIVE, template_reflexive, extras), ] -def binops(stub=""): +def inplace(other_type: str, type_ignore: str = "") -> list[OpsType]: + extras = {"other_type": other_type} return [ - ([(None, None)], required_method_binary, stub_required_binary, {}), - (BINOPS_NUM + BINOPS_CMP, template_binop, stub, NO_OVERRIDE), - (BINOPS_EQNE, template_binop, stub, OVERRIDE_TYPESHED), - (BINOPS_REFLEXIVE, template_reflexive, stub, NO_OVERRIDE), + ([(None, None)], required_method_inplace, extras), + ( + BINOPS_INPLACE, + template_inplace, + extras | {"type_ignore": _type_ignore(type_ignore)}, + ), ] -def inplace(): +def unops() -> list[OpsType]: return [ - ([(None, None)], required_method_inplace, stub_required_inplace, {}), - (BINOPS_INPLACE, template_inplace, "", {}), + ([(None, None)], required_method_unary, {}), + (UNARY_OPS, template_unary, {}), + (OTHER_UNARY_METHODS, template_other_unary, {}), ] ops_info = {} -ops_info["DatasetOpsMixin"] = binops(stub_ds) + inplace() + unops("T_Dataset") -ops_info["DataArrayOpsMixin"] = binops(stub_da) + inplace() + unops("T_DataArray") -ops_info["VariableOpsMixin"] = binops(stub_var) + inplace() + unops("T_Variable") -ops_info["DatasetGroupByOpsMixin"] = binops(stub_dsgb) -ops_info["DataArrayGroupByOpsMixin"] = binops(stub_dagb) +ops_info["DatasetOpsMixin"] = ( + binops(other_type="DsCompatible") + inplace(other_type="DsCompatible") + unops() +) +ops_info["DataArrayOpsMixin"] = ( + binops(other_type="DaCompatible") + inplace(other_type="DaCompatible") + unops() +) +ops_info["VariableOpsMixin"] = ( + binops_overload(other_type="VarCompatible", overload_type="T_DataArray") + + inplace(other_type="VarCompatible", type_ignore="misc") + + unops() +) +ops_info["DatasetGroupByOpsMixin"] = binops( + other_type="GroupByCompatible", return_type="Dataset" +) +ops_info["DataArrayGroupByOpsMixin"] = binops( + other_type="T_Xarray", return_type="T_Xarray" +) MODULE_PREAMBLE = '''\ """Mixin classes with arithmetic operators.""" # This file was generated using xarray.util.generate_ops. Do not edit manually. -import operator - -from . import nputils, ops''' +from __future__ import annotations -STUBFILE_PREAMBLE = '''\ -"""Stub file for mixin classes with arithmetic operators.""" -# This file was generated using xarray.util.generate_ops. Do not edit manually. - -from typing import NoReturn, TypeVar, overload - -import numpy as np -from numpy.typing import ArrayLike +import operator +from typing import TYPE_CHECKING, Any, Callable, overload -from .dataarray import DataArray -from .dataset import Dataset -from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy -from .types import ( +from xarray.core import nputils, ops +from xarray.core.types import ( DaCompatible, DsCompatible, - GroupByIncompatible, - ScalarOrArray, + GroupByCompatible, + Self, + T_DataArray, + T_Xarray, VarCompatible, ) -from .variable import Variable -try: - from dask.array import Array as DaskArray -except ImportError: - DaskArray = np.ndarray # type: ignore - -# DatasetOpsMixin etc. are parent classes of Dataset etc. -# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally -# we use the ones in `types`. (We're open to refining this, and potentially integrating -# the `py` & `pyi` files to simplify them.) -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")''' +if TYPE_CHECKING: + from xarray.core.dataset import Dataset''' CLASS_PREAMBLE = """{newline} @@ -233,35 +260,28 @@ class {cls_name}: {method}.__doc__ = {func}.__doc__""" -def render(ops_info, is_module): +def render(ops_info: dict[str, list[OpsType]]) -> Iterator[str]: """Render the module or stub file.""" - yield MODULE_PREAMBLE if is_module else STUBFILE_PREAMBLE + yield MODULE_PREAMBLE for cls_name, method_blocks in ops_info.items(): - yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n" * is_module) - yield from _render_classbody(method_blocks, is_module) + yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n") + yield from _render_classbody(method_blocks) -def _render_classbody(method_blocks, is_module): - for method_func_pairs, method_template, stub_template, extra in method_blocks: - template = method_template if is_module else stub_template +def _render_classbody(method_blocks: list[OpsType]) -> Iterator[str]: + for method_func_pairs, template, extra in method_blocks: if template: for method, func in method_func_pairs: yield template.format(method=method, func=func, **extra) - if is_module: - yield "" - for method_func_pairs, *_ in method_blocks: - for method, func in method_func_pairs: - if method and func: - yield COPY_DOCSTRING.format(method=method, func=func) + yield "" + for method_func_pairs, *_ in method_blocks: + for method, func in method_func_pairs: + if method and func: + yield COPY_DOCSTRING.format(method=method, func=func) if __name__ == "__main__": - option = sys.argv[1].lower() if len(sys.argv) == 2 else None - if option not in {"--module", "--stubs"}: - raise SystemExit(f"Usage: {sys.argv[0]} --module | --stubs") - is_module = option == "--module" - - for line in render(ops_info, is_module): + for line in render(ops_info): print(line)