Skip to content

Commit

Permalink
Get sharding without actually sharding an array
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Apr 9, 2024
1 parent 08063ed commit 2d557c1
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 20 deletions.
32 changes: 28 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

High-level array sharding API for JAX

## Introduction

TODO: Add detailed introduction.

This project originated as a part of the [Mistral 7B v0.2 JAX](https://github.com/yixiaoer/mistral-v0.2-jax) project and has since evolved into an independent project.

This project is supported by Cloud TPUs from Google's [TPU Research Cloud](https://sites.research.google/trc/about/) (TRC).

## Installation

This library requires at least Python 3.12.
Expand All @@ -12,20 +20,24 @@ pip install einshard

## Usage

For testing purposes, we initialise the JAX CPU backend with 16 devices. This should be run before the actual code (e.g. placed at the top of the script):

```python
# initialising JAX CPU backend with 16 devices
n_devices = 16
import os
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_platform_device_count={n_devices}'
```

# actual code starts here
from einshard import einshard
Code:

```python
import einshard
import jax
import jax.numpy as jnp

a = jnp.zeros((4, 8))
a = einshard(a, 'a b -> * a* b2*')
a = einshard.shard(a, 'a b -> * a* b2*')
jax.debug.visualize_array_sharding(a)
```

Expand Down Expand Up @@ -71,3 +83,15 @@ Build package:
pip install build
python -m build
```

Build docs:

```sh
cd docs
make html
```

```sh
cd docs/_build/html
python -m http.server -b 127.0.0.1
```
4 changes: 2 additions & 2 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
jax[cpu]

sphinx
sphinx-autobuild
sphinx-autodoc-typehints
sphinx-book-theme

jax[cpu]
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ where = ["src"]

[project]
name = "einshard"
version = "0.0.1"
authors = [
{ name="Ayaka", email="[email protected]" },
{ name="Shin", email="[email protected]" },
Expand All @@ -25,7 +24,11 @@ dependencies = [
"jax",
"mypy",
]
dynamic = ["version"]

[project.urls]
Homepage = "https://github.com/ayaka14732/einshard"
Issues = "https://github.com/ayaka14732/einshard/issues"

[tool.setuptools.dynamic]
version = {attr = "einshard.__version__"}
4 changes: 3 additions & 1 deletion src/einshard/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .einshard import einshard
from .einshard import shard, sharding

__version__ = '0.1.0'
37 changes: 27 additions & 10 deletions src/einshard/einshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ def _partition_at_ellipsis(lst: list) -> tuple[list, list]:
r = lst[idx + 1:]
return l, r

def einshard(arr: Array, expression: str) -> Array:
"""
Shards a :class:`jax.Array` according to a specified pattern, using a human-readable expression similar to that used in einsum notation.
def sharding(expression: str, *, n_dims: int | None = None) -> NamedSharding:
'''
Get sharding from einshard expression.
Args:
arr (jax.Array): The Array to be processed with tensor parallelism.
expression (str): A human-readable expression similar to einsum notation that specifies the sharding pattern.
expression (str): The einshard expression.
n_dims (int | None): The number of dimensions of the array to be sharded. This argument must be provided if ellipsis is used in the einshard expression.
Returns:
Array: The sharded array.
"""
jax.sharding.NamedSharding: The :class:`jax.sharding.Sharding` object corresponding to the given expression.
'''
n_devices = jax.device_count()

res = parse_expression(expression, 0)
Expand All @@ -36,8 +36,10 @@ def einshard(arr: Array, expression: str) -> Array:
n_right_ellipses = sum(element_right is ... for element_right in elements_right)
assert n_left_ellipses == n_right_ellipses and n_left_ellipses <= 1

if n_left_ellipses > 0: # == 1
n_dims = len(arr.shape)
if n_left_ellipses == 0:
assert n_dims == len(elements_left)
else: # == 1
assert n_dims is not None
n_dims_elided = n_dims - len(elements_left) + 1
axis_names_for_left_augmented = [f'?{i}' for i in range(n_dims_elided)]
axis_names_for_right_augmented = [(identifier, 1, False) for identifier in axis_names_for_left_augmented] # 1: `sharding_number`, False: `is_proportional`
Expand Down Expand Up @@ -81,5 +83,20 @@ def einshard(arr: Array, expression: str) -> Array:

devices = mesh_utils.create_device_mesh(mesh_shape)
mesh = Mesh(devices, axis_names=axis_names)
arr = jax.make_array_from_callback(arr.shape, NamedSharding(mesh, P(*partition_spec)), lambda idx: arr[idx])
sharding_ = NamedSharding(mesh, P(*partition_spec))
return sharding_

def shard(arr: Array, expression: str) -> Array:
    '''
    Shards a :class:`jax.Array` according to the given einshard expression.

    Args:
        arr (jax.Array): The array to be sharded.
        expression (str): The einshard expression.

    Returns:
        jax.Array: The sharded array.
    '''
    # Derive the NamedSharding first; pass the array's rank explicitly so the
    # expression is allowed to contain an ellipsis.
    sharding_ = sharding(expression, n_dims=len(arr.shape))
    # Bind the source array under a distinct name before building the result.
    # The original code rebound `arr` on the same statement that passed
    # `lambda idx: arr[idx]`, which only works because
    # jax.make_array_from_callback invokes the callback eagerly; a separate
    # name removes that late-binding fragility.
    source = arr
    return jax.make_array_from_callback(
        source.shape, sharding_, lambda idx: source[idx]
    )
4 changes: 2 additions & 2 deletions tests/test_einshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jax import Array
import jax.numpy as jnp

from einshard import einshard
import einshard

def set_device_count(n: int) -> None:
os.environ['JAX_PLATFORMS'] = 'cpu'
Expand Down Expand Up @@ -61,7 +61,7 @@ def assert_equal(a, b):
def invoke_test(spec) -> None:
set_device_count(spec['n_devices'])
a = jnp.zeros(spec['shape'])
a = einshard(a, spec['expr'])
a = einshard.shard(a, spec['expr'])
assert_equal(get_shard_shape(a), spec['ans'])

def main() -> None:
Expand Down

0 comments on commit 2d557c1

Please sign in to comment.