diff --git a/jax/experimental/gda_serialization/serialization.py b/jax/experimental/gda_serialization/serialization.py index 438f10829..f39905f85 100644 --- a/jax/experimental/gda_serialization/serialization.py +++ b/jax/experimental/gda_serialization/serialization.py @@ -108,18 +108,26 @@ async def _run_serializer(): asyncio.run(_run_serializer()) -async def async_deserialize(mesh, mesh_axes, tensorstore_spec): +async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None): t = ts.open(ts.Spec(tensorstore_spec), open=True).result() + shape = t.shape if global_shape is None else global_shape + new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes) async def cb(index): - return await t[index].read() + out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype) + requested_domain = ts.IndexTransform(input_shape=shape)[index].domain + restricted_domain = t.domain.intersect(requested_domain) + await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(t[restricted_domain]) + return out - return await create_async_gda_from_callback(t.shape, mesh, mesh_axes, cb) + return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb) -def run_deserialization(global_meshes, mesh_axes, tensorstore_specs): +def run_deserialization(global_meshes, mesh_axes, tensorstore_specs, + global_shapes=None): async def _run_deserializer(): - future_gdas = jax.tree_map(async_deserialize, global_meshes, mesh_axes, - tensorstore_specs) + future_gdas = jax.tree_map( + async_deserialize, global_meshes, mesh_axes, tensorstore_specs, + [None] * len(tensorstore_specs) if global_shapes is None else global_shapes) return await asyncio.gather(*future_gdas) return asyncio.run(_run_deserializer()) diff --git a/jax/experimental/gda_serialization/serialization_test.py b/jax/experimental/gda_serialization/serialization_test.py index a45a8a4d8..1813c1a61 100644 --- a/jax/experimental/gda_serialization/serialization_test.py +++ b/jax/experimental/gda_serialization/serialization_test.py @@ -99,6 +99,44 @@ def cb3(index): self.assertArraysEqual(s.data.to_py(), np.array([])) self.assertEqual(m3.dtype, np.float32) + def test_checkpointing_with_bigger_shape(self): + global_mesh = create_global_mesh((2, 2), ('x', 'y')) + global_input_shape = (8, 2) + num = util.prod(global_input_shape) + + # First GDA + global_input_data1 = np.arange(num).reshape(global_input_shape) + def cb1(index): + return global_input_data1[index] + gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, + ['x', 'y'], cb1) + ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path) + + ckpt_paths = [str(ckpt_dir1)] + tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + + serialization.run_serialization([gda1], tspecs) + + m1, = serialization.run_deserialization( + [create_global_mesh((4, 2), ('x', 'y'))], + [['x', 'y']], + tspecs, + [(12, 2)], + ) + + expected_data = { + 0: np.array([[0], [2], [4]]), + 1: np.array([[1], [3], [5]]), + 2: np.array([[6], [8], [10]]), + 3: np.array([[7], [9], [11]]), + 4: np.array([[12], [14], [0]]), + 5: np.array([[13], [15], [0]]), + 6: np.array([[0], [0], [0]]), + 7: np.array([[0], [0], [0]]), + } + + for l in m1.local_shards: + self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id]) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())