Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add atleast_nd #3

Merged
merged 11 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ jobs:
with:
pixi-version: v0.30.0
cache: true
- name: Run Pylint
run: pixi run -e lint pylint
- name: Run Pylint & Mypy
run: |
pixi run -e lint pylint
pixi run -e lint mypy

checks:
name: Check ${{ matrix.environment }}
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ instance/

# Sphinx documentation
docs/_build/
docs/generated/

# PyBuilder
.pybuilder/
Expand Down
9 changes: 0 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@ repos:
args: ["--fix", "--show-fixes"]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.11.1"
hooks:
- id: mypy
files: src|tests
args: []
additional_dependencies:
- pytest

- repo: https://github.com/codespell-project/codespell
rev: "v2.3.0"
hooks:
Expand Down
10 changes: 10 additions & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# API Reference

```{eval-rst}
.. currentmodule:: array_api_extra
.. autosummary::
:nosignatures:
:toctree: generated

atleast_nd
```
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
extensions = [
"myst_parser",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.mathjax",
"sphinx.ext.napoleon",
Expand Down
8 changes: 1 addition & 7 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,9 @@
```{toctree}
:maxdepth: 2
:hidden:

api-reference.md
```

```{include} ../README.md
:start-after: <!-- SPHINX-START -->
```

## Indices and tables

- {ref}`genindex`
- {ref}`modindex`
- {ref}`search`
2,090 changes: 1,876 additions & 214 deletions pixi.lock

Large diffs are not rendered by default.

39 changes: 23 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,8 @@ dependencies = []
test = [
"pytest >=6",
"pytest-cov >=3",
]
dev = [
"pytest >=6",
"pytest-cov >=3",
"pylint",
"array-api-strict",
"numpy",
]
docs = [
"sphinx>=7.0",
Expand Down Expand Up @@ -68,29 +65,30 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
[tool.pixi.pypi-dependencies]
array-api-extra = { path = ".", editable = true }

[tool.pixi.tasks]
pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show-diff-on-failure" }

[tool.pixi.feature.lint.dependencies]
pre-commit = "*"
mypy = "*"
pylint = "*"
# import dependencies for mypy:
array-api-strict = "*"
numpy = "*"

[tool.pixi.feature.lint.tasks]
pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show-diff-on-failure" }
mypy = { cmd = "mypy", cwd = "." }
pylint = { cmd = ["pylint", "array_api_extra"], cwd = "src" }
lint = { depends-on = ["pre-commit", "pylint"] }
lint = { depends-on = ["pre-commit", "pylint", "mypy"] }

[tool.pixi.feature.test.dependencies]
pytest = ">=6"
pytest-cov = ">=3"
array-api-strict = "*"
numpy = "*"

[tool.pixi.feature.test.tasks]
test = { cmd = "pytest" }
test-ci = { cmd = "pytest -ra --cov --cov-report=xml --cov-report=term --durations=20" }

[tool.pixi.feature.dev.dependencies]
pytest = ">=6"
pytest-cov = ">=3"
pylint = "*"

[tool.pixi.feature.docs.dependencies]
sphinx = ">=7.0"
furo = ">=2023.08.17"
Expand All @@ -100,6 +98,15 @@ myst_parser = ">=0.13"
sphinx_copybutton = "*"
sphinx_autodoc_typehints = "*"

[tool.pixi.feature.docs.tasks]
docs = { cmd = ["sphinx-build", ".", "build/"], cwd = "docs" }

[tool.pixi.feature.dev.dependencies]
ipython = "*"

[tool.pixi.feature.dev.tasks]
ipython = { cmd = "ipython" }

[tool.pixi.feature.py309.dependencies]
python = "~=3.9.0"

Expand All @@ -109,9 +116,9 @@ python = "~=3.12.0"
[tool.pixi.environments]
default = { solve-group = "default" }
lint = { features = ["lint"], solve-group = "default" }
docs = { features = ["docs"], solve-group = "default" }
test = { features = ["test"], solve-group = "default" }
dev = { features = ["dev", "docs"], solve-group = "default" }
docs = { features = ["docs"], solve-group = "default" }
dev = { features = ["lint", "test", "docs", "dev"], solve-group = "default" }
ci-py309 = ["py309", "test"]
ci-py312 = ["py312", "test"]

Expand Down
4 changes: 3 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from ._funcs import atleast_nd

__version__ = "0.1.dev0"

__all__ = ["__version__"]
__all__ = ["__version__", "atleast_nd"]
48 changes: 48 additions & 0 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ._typing import Array, ModuleType

__all__ = ["atleast_nd"]


def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
"""
Recursively expand the dimension of an array to at least `ndim`.

Parameters
----------
x : array
ndim : int
The minimum number of dimensions for the result.
xp : array_namespace
The standard-compatible namespace for `x`.

Returns
-------
res : array
An array with ``res.ndim`` >= `ndim`.
If ``x.ndim`` >= `ndim`, `x` is returned.
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
until ``res.ndim`` equals `ndim`.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([1])
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
Array([[[1]]], dtype=array_api_strict.int64)

>>> x = xp.asarray([[[1, 2],
... [3, 4]]])
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
True

"""
if x.ndim < ndim:
x = xp.expand_dims(x, axis=0)
x = atleast_nd(x, ndim=ndim, xp=xp)
return x
9 changes: 9 additions & 0 deletions src/array_api_extra/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from types import ModuleType
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
Array = Any # To be changed to a Protocol later (see array-api#589)

__all__ = ["Array", "ModuleType"]
69 changes: 69 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

# array-api-strict#6
import array_api_strict as xp # type: ignore[import-untyped]
from numpy.testing import assert_array_equal

from array_api_extra import atleast_nd


class TestAtLeastND:
def test_0D(self):
x = xp.asarray(1)

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=1, xp=xp)
assert_array_equal(y, xp.ones((1,)))

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1)))

def test_1D(self):
x = xp.asarray([0, 1])

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=1, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=2, xp=xp)
assert_array_equal(y, xp.asarray([[0, 1]]))

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, xp.reshape(xp.arange(2), (1, 1, 1, 1, 2)))

def test_2D(self):
x = xp.asarray([[3]])

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=2, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=3, xp=xp)
assert_array_equal(y, 3 * xp.ones((1, 1, 1)))

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, 3 * xp.ones((1, 1, 1, 1, 1)))

def test_5D(self):
x = xp.ones((1, 1, 1, 1, 1))

y = atleast_nd(x, ndim=0, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=4, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=5, xp=xp)
assert_array_equal(y, x)

y = atleast_nd(x, ndim=6, xp=xp)
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1)))

y = atleast_nd(x, ndim=9, xp=xp)
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))