From 9e11ee727a4097960e86a252eda2bd2bb8b0cd6f Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 10 Apr 2024 05:02:51 +0400 Subject: [PATCH] Update function names --- README.md | 4 ++-- src/einshard/__init__.py | 4 ++-- src/einshard/{shard.py => einshard.py} | 14 +++++++------- tests/test_einshard.py | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) rename src/einshard/{shard.py => einshard.py} (91%) diff --git a/README.md b/README.md index 02869ed..8989eac 100644 --- a/README.md +++ b/README.md @@ -34,12 +34,12 @@ os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_ Code: ```python -import einshard +from einshard import einshard import jax import jax.numpy as jnp a = jnp.zeros((4, 8)) -a = einshard.shard(a, 'a b -> * a* b2*') +a = einshard(a, 'a b -> * a* b2*') jax.debug.visualize_array_sharding(a) ``` diff --git a/src/einshard/__init__.py b/src/einshard/__init__.py index 09a30a6..eb36319 100644 --- a/src/einshard/__init__.py +++ b/src/einshard/__init__.py @@ -1,3 +1,3 @@ -from .shard import shard, sharding +from .einshard import einshard, make_sharding -__version__ = '0.1.1' +__version__ = '0.2.0' diff --git a/src/einshard/shard.py b/src/einshard/einshard.py similarity index 91% rename from src/einshard/shard.py rename to src/einshard/einshard.py index 5420e22..11e902f 100644 --- a/src/einshard/shard.py +++ b/src/einshard/einshard.py @@ -13,9 +13,9 @@ def _partition_at_ellipsis(lst: list) -> tuple[list, list]: r = lst[idx + 1:] return l, r -def sharding(expression: str, *, n_dims: int | None = None) -> NamedSharding: +def make_sharding(expression: str, *, n_dims: int | None = None) -> NamedSharding: ''' - Get sharding from einshard expression. + Make sharding from einshard expression. Args: expression (str): The einshard expression. @@ -84,10 +84,10 @@ def sharding(expression: str, *, n_dims: int | None = None) -> NamedSharding: devices = mesh_utils.create_device_mesh(mesh_shape) mesh = Mesh(devices, axis_names=axis_names) - sharding_ = NamedSharding(mesh, P(*partition_spec)) - return sharding_ + sharding = NamedSharding(mesh, P(*partition_spec)) + return sharding -def shard(arr: Array, expression: str) -> Array: +def einshard(arr: Array, expression: str) -> Array: ''' Shards a :class:`jax.Array` according to the given einshard expression. @@ -98,6 +98,6 @@ def shard(arr: Array, expression: str) -> Array: Returns: jax.Array: The sharded array. ''' - sharding_ = sharding(expression, n_dims=len(arr.shape)) - arr = jax.make_array_from_callback(arr.shape, sharding_, lambda idx: arr[idx]) + sharding = make_sharding(expression, n_dims=len(arr.shape)) + arr = jax.make_array_from_callback(arr.shape, sharding, lambda idx: arr[idx]) return arr diff --git a/tests/test_einshard.py b/tests/test_einshard.py index e448a5f..19810e1 100644 --- a/tests/test_einshard.py +++ b/tests/test_einshard.py @@ -6,7 +6,7 @@ from jax import Array import jax.numpy as jnp -import einshard +from einshard import einshard def set_device_count(n: int) -> None: os.environ['JAX_PLATFORMS'] = 'cpu' @@ -61,7 +61,7 @@ def assert_equal(a, b): def invoke_test(spec) -> None: set_device_count(spec['n_devices']) a = jnp.zeros(spec['shape']) - a = einshard.shard(a, spec['expr']) + a = einshard(a, spec['expr']) assert_equal(get_shard_shape(a), spec['ans']) def main() -> None: