Skip to content
This repository has been archived by the owner on Mar 6, 2023. It is now read-only.

Commit

Permalink
Make resharding of GDA work if the shape is larger than what it was serialized with.
Browse files Browse the repository at this point in the history

For example, you can serialize a GDA with shape (8, 2) and then deserialize it with a larger global shape of (12, 2).

PiperOrigin-RevId: 429680502
  • Loading branch information
yashk2810 authored and jax authors committed Feb 19, 2022
1 parent c161c62 commit 3290dd3
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
20 changes: 14 additions & 6 deletions jax/experimental/gda_serialization/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,26 @@ async def _run_serializer():
asyncio.run(_run_serializer())


async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None):
  """Reads a checkpointed array from TensorStore and reshards it onto `mesh`.

  Args:
    mesh: Mesh the resulting array is laid out on.
    mesh_axes: Partitioning of the array's dimensions over the mesh axes.
    tensorstore_spec: Spec of the TensorStore to open (must already exist).
    global_shape: Optional global shape of the result. When `None`, the shape
      stored in the checkpoint is used. It may be larger than the stored
      shape; positions that fall outside the stored domain are left as zeros.

  Returns:
    Whatever `create_async_gda_from_callback` returns for the chosen shape
    (presumably a GlobalDeviceArray -- confirm against that helper).
  """
  # Await the open instead of blocking on `.result()`, so that several
  # deserializations gathered on one event loop can overlap their I/O.
  t = await ts.open(ts.Spec(tensorstore_spec), open=True)
  shape = t.shape if global_shape is None else global_shape
  new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes)

  async def cb(index):
    # Start from zeros: anything the checkpoint does not cover stays zero.
    out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
    # Domain this shard asks for, expressed in the (possibly larger) shape.
    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
    # Only the part of the request that actually exists in the checkpoint.
    restricted_domain = t.domain.intersect(requested_domain)
    # View `out` with its origin moved to the requested origin so that global
    # indices address it directly, then copy in the stored overlap.
    await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][
        restricted_domain].write(t[restricted_domain])
    return out

  return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)


def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
                        global_shapes=None):
  """Deserializes several checkpoints concurrently and blocks until all finish.

  Args:
    global_meshes: One mesh per array to restore.
    mesh_axes: One partitioning per array, matching `global_meshes`.
    tensorstore_specs: One TensorStore spec per array.
    global_shapes: Optional per-array global shapes. `None` (the default)
      means every array is restored with the shape stored in its checkpoint.

  Returns:
    The list of results gathered from `async_deserialize`, in input order.
  """
  if global_shapes is None:
    # One `None` per spec makes `async_deserialize` fall back to the stored
    # shape for every array.
    global_shapes = [None] * len(tensorstore_specs)

  async def _run_deserializer():
    future_gdas = jax.tree_map(
        async_deserialize, global_meshes, mesh_axes, tensorstore_specs,
        global_shapes)
    return await asyncio.gather(*future_gdas)
  return asyncio.run(_run_deserializer())
38 changes: 38 additions & 0 deletions jax/experimental/gda_serialization/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,44 @@ def cb3(index):
self.assertArraysEqual(s.data.to_py(), np.array([]))
self.assertEqual(m3.dtype, np.float32)

def test_checkpointing_with_bigger_shape(self):
  """Restoring with a larger global shape zero-pads the missing region."""
  global_mesh = create_global_mesh((2, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  num = util.prod(global_input_shape)

  # First GDA
  global_input_data1 = np.arange(num).reshape(global_input_shape)
  def cb1(index):
    return global_input_data1[index]
  gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                         ['x', 'y'], cb1)
  ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

  ckpt_paths = [str(ckpt_dir1)]
  tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

  serialization.run_serialization([gda1], tspecs)

  # Restore onto a different mesh and with a larger global shape, (12, 2),
  # than the (8, 2) that was written.
  m1, = serialization.run_deserialization(
      [create_global_mesh((4, 2), ('x', 'y'))],
      [['x', 'y']],
      tspecs,
      [(12, 2)],
  )

  # Each device holds a (3, 1) shard. Rows 8..11 do not exist in the
  # checkpoint, so they come back as zeros.
  expected_data = {
      0: np.array([[0], [2], [4]]),
      1: np.array([[1], [3], [5]]),
      2: np.array([[6], [8], [10]]),
      3: np.array([[7], [9], [11]]),
      4: np.array([[12], [14], [0]]),
      5: np.array([[13], [15], [0]]),
      6: np.array([[0], [0], [0]]),
      7: np.array([[0], [0], [0]]),
  }

  for shard in m1.local_shards:
    self.assertArraysEqual(shard.data.to_py(), expected_data[shard.device.id])

if __name__ == '__main__':
  # Run the tests only when this file is executed as a script.
  absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 3290dd3

Please sign in to comment.