Skip to content

Commit

Permalink
Fix bug and update documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Apr 10, 2024
1 parent 2d557c1 commit 0ae6f82
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__pycache__
/venv
/build
/dist
*.egg-info
/docs/_build
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ This library requires at least Python 3.12.
pip install einshard
```

You need to have JAX installed by [choosing the correct installation method](https://jax.readthedocs.io/en/latest/installation.html) before installing Einshard.

## Usage

For testing purpose, we initialise the JAX CPU backend with 16 devices. This should be run before the actual code (e.g. placed at the top of the script):
Expand Down Expand Up @@ -59,11 +61,15 @@ Output:

## Development

Crente venv:

```sh
python3.12 -m venv venv
. venv/bin/activate
```

Install dependencies:

```sh
pip install -U pip
pip install -U wheel
Expand Down
4 changes: 3 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
contain the root `toctree` directive.
Einshard Documentation
====================================
======================

**Useful links**: `Source Repository <https://github.com/ayaka14732/einshard>`_ | `Issue Tracker <https://github.com/ayaka14732/einshard/issues>`_

API
---
Expand Down
4 changes: 2 additions & 2 deletions src/einshard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .einshard import shard, sharding
from .shard import shard, sharding

__version__ = '0.1.0'
__version__ = '0.1.1'
3 changes: 2 additions & 1 deletion src/einshard/einshard.py → src/einshard/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def sharding(expression: str, *, n_dims: int | None = None) -> NamedSharding:
assert n_left_ellipses == n_right_ellipses and n_left_ellipses <= 1

if n_left_ellipses == 0:
assert n_dims == len(elements_left)
if n_dims is not None:
assert n_dims == len(elements_left)
else: # == 1
assert n_dims is not None
n_dims_elided = n_dims - len(elements_left) + 1
Expand Down

0 comments on commit 0ae6f82

Please sign in to comment.