Skip to content

Commit

Permalink
Update function names
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Apr 10, 2024
1 parent 0ae6f82 commit 9e11ee7
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_
Code:

```python
import einshard
from einshard import einshard
import jax
import jax.numpy as jnp

a = jnp.zeros((4, 8))
a = einshard.shard(a, 'a b -> * a* b2*')
a = einshard(a, 'a b -> * a* b2*')
jax.debug.visualize_array_sharding(a)
```

Expand Down
4 changes: 2 additions & 2 deletions src/einshard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .shard import shard, sharding
from .einshard import einshard, make_sharding

__version__ = '0.1.1'
__version__ = '0.2.0'
14 changes: 7 additions & 7 deletions src/einshard/shard.py → src/einshard/einshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ def _partition_at_ellipsis(lst: list) -> tuple[list, list]:
r = lst[idx + 1:]
return l, r

def sharding(expression: str, *, n_dims: int | None = None) -> NamedSharding:
def make_sharding(expression: str, *, n_dims: int | None = None) -> NamedSharding:
'''
Get sharding from einshard expression.
Make sharding from einshard expression.
Args:
expression (str): The einshard expression.
Expand Down Expand Up @@ -84,10 +84,10 @@ def sharding(expression: str, *, n_dims: int | None = None) -> NamedSharding:

devices = mesh_utils.create_device_mesh(mesh_shape)
mesh = Mesh(devices, axis_names=axis_names)
sharding_ = NamedSharding(mesh, P(*partition_spec))
return sharding_
sharding = NamedSharding(mesh, P(*partition_spec))
return sharding

def shard(arr: Array, expression: str) -> Array:
def einshard(arr: Array, expression: str) -> Array:
'''
Shards a :class:`jax.Array` according to the given einshard expression.
Expand All @@ -98,6 +98,6 @@ def shard(arr: Array, expression: str) -> Array:
Returns:
jax.Array: The sharded array.
'''
sharding_ = sharding(expression, n_dims=len(arr.shape))
arr = jax.make_array_from_callback(arr.shape, sharding_, lambda idx: arr[idx])
sharding = make_sharding(expression, n_dims=len(arr.shape))
arr = jax.make_array_from_callback(arr.shape, sharding, lambda idx: arr[idx])
return arr
4 changes: 2 additions & 2 deletions tests/test_einshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jax import Array
import jax.numpy as jnp

import einshard
from einshard import einshard

def set_device_count(n: int) -> None:
os.environ['JAX_PLATFORMS'] = 'cpu'
Expand Down Expand Up @@ -61,7 +61,7 @@ def assert_equal(a, b):
def invoke_test(spec) -> None:
set_device_count(spec['n_devices'])
a = jnp.zeros(spec['shape'])
a = einshard.shard(a, spec['expr'])
a = einshard(a, spec['expr'])
assert_equal(get_shard_shape(a), spec['ans'])

def main() -> None:
Expand Down

0 comments on commit 9e11ee7

Please sign in to comment.