diff --git a/README.md b/README.md index 6199c18..27912bd 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,14 @@ High-level array sharding API for JAX +## Introduction + +TODO: Add detailed introduction. + +This project originated as a part of the [Mistral 7B v0.2 JAX](https://github.com/yixiaoer/mistral-v0.2-jax) project and has since evolved into an independent project. + +This project is supported by Cloud TPUs from Google's [TPU Research Cloud](https://sites.research.google/trc/about/) (TRC). + ## Installation This library requires at least Python 3.12. @@ -12,20 +20,24 @@ pip install einshard ## Usage +For testing purposes, we initialise the JAX CPU backend with 16 devices. This should be run before the actual code (e.g. placed at the top of the script): + ```python -# initialising JAX CPU backend with 16 devices n_devices = 16 import os os.environ['JAX_PLATFORMS'] = 'cpu' os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_platform_device_count={n_devices}' +``` -# actual code starts here -from einshard import einshard +Code: + +```python +import einshard import jax import jax.numpy as jnp a = jnp.zeros((4, 8)) -a = einshard(a, 'a b -> * a* b2*') +a = einshard.shard(a, 'a b -> * a* b2*') jax.debug.visualize_array_sharding(a) ``` @@ -71,3 +83,15 @@ Build package: pip install build python -m build ``` + +Build docs: + +```sh +cd docs +make html +``` + +```sh +cd docs/_build/html +python -m http.server -b 127.0.0.1 +``` diff --git a/docs/requirements.txt b/docs/requirements.txt index 0378c0f..7e42a6d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ +jax[cpu] + sphinx sphinx-autobuild sphinx-autodoc-typehints sphinx-book-theme - -jax[cpu] diff --git a/pyproject.toml b/pyproject.toml index 7347792..2c3bbe3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,6 @@ where = ["src"] [project] name = "einshard" -version = "0.0.1" authors = [ { name="Ayaka", email="ayaka@mail.shn.hk" }, { name="Shin", 
email="shin@yixiaoer.sg" }, @@ -25,7 +24,11 @@ dependencies = [ "jax", "mypy", ] +dynamic = ["version"] [project.urls] Homepage = "https://github.com/ayaka14732/einshard" Issues = "https://github.com/ayaka14732/einshard/issues" + +[tool.setuptools.dynamic] +version = {attr = "einshard.__version__"} diff --git a/src/einshard/__init__.py b/src/einshard/__init__.py index 0cf6db6..5fe99dc 100644 --- a/src/einshard/__init__.py +++ b/src/einshard/__init__.py @@ -1 +1,3 @@ -from .einshard import einshard +from .einshard import shard, sharding + +__version__ = '0.1.0' diff --git a/src/einshard/einshard.py b/src/einshard/einshard.py index 18c30d2..5828489 100644 --- a/src/einshard/einshard.py +++ b/src/einshard/einshard.py @@ -13,17 +13,17 @@ def _partition_at_ellipsis(lst: list) -> tuple[list, list]: r = lst[idx + 1:] return l, r -def einshard(arr: Array, expression: str) -> Array: - """ - Shards a :class:`jax.Array` according to a specified pattern, using a human-readable expression similar to that used in einsum notation. +def sharding(expression: str, *, n_dims: int | None = None) -> NamedSharding: + ''' + Get sharding from einshard expression. Args: - arr (jax.Array): The Array to be processed with tensor parallelism. - expression (str): A human-readable expression similar to einsum notation that specifies the sharding pattern. + expression (str): The einshard expression. + n_dims (int | None): The number of dimensions of the array to be sharded. This argument must be provided if ellipsis is used in the einshard expression. Returns: - Array: The sharded array. - """ + jax.sharding.NamedSharding: The :class:`jax.sharding.Sharding` object corresponding to the given expression. + ''' n_devices = jax.device_count() res = parse_expression(expression, 0) @@ -36,8 +36,10 @@ def einshard(arr: Array, expression: str) -> Array: n_right_ellipses = sum(element_right is ... 
for element_right in elements_right) assert n_left_ellipses == n_right_ellipses and n_left_ellipses <= 1 - if n_left_ellipses > 0: # == 1 - n_dims = len(arr.shape) + if n_left_ellipses == 0: + assert n_dims == len(elements_left) + else: # == 1 + assert n_dims is not None n_dims_elided = n_dims - len(elements_left) + 1 axis_names_for_left_augmented = [f'?{i}' for i in range(n_dims_elided)] axis_names_for_right_augmented = [(identifier, 1, False) for identifier in axis_names_for_left_augmented] # 1: `sharding_number`, False: `is_proportional` @@ -81,5 +83,20 @@ def einshard(arr: Array, expression: str) -> Array: devices = mesh_utils.create_device_mesh(mesh_shape) mesh = Mesh(devices, axis_names=axis_names) - arr = jax.make_array_from_callback(arr.shape, NamedSharding(mesh, P(*partition_spec)), lambda idx: arr[idx]) + sharding_ = NamedSharding(mesh, P(*partition_spec)) + return sharding_ + +def shard(arr: Array, expression: str) -> Array: + ''' + Shards a :class:`jax.Array` according to the given einshard expression. + + Args: + arr (jax.Array): The array to be sharded. + expression (str): The einshard expression. + + Returns: + jax.Array: The sharded array. + ''' + sharding_ = sharding(expression, n_dims=len(arr.shape)) + arr = jax.make_array_from_callback(arr.shape, sharding_, lambda idx: arr[idx]) return arr diff --git a/tests/test_einshard.py b/tests/test_einshard.py index 19810e1..e448a5f 100644 --- a/tests/test_einshard.py +++ b/tests/test_einshard.py @@ -6,7 +6,7 @@ from jax import Array import jax.numpy as jnp -from einshard import einshard +import einshard def set_device_count(n: int) -> None: os.environ['JAX_PLATFORMS'] = 'cpu' @@ -61,7 +61,7 @@ def assert_equal(a, b): def invoke_test(spec) -> None: set_device_count(spec['n_devices']) a = jnp.zeros(spec['shape']) - a = einshard(a, spec['expr']) + a = einshard.shard(a, spec['expr']) assert_equal(get_shard_shape(a), spec['ans']) def main() -> None: