Bug fix for async checkpointing fully sharded state dicts
jonb377 committed Nov 10, 2023
1 parent d05a6a8 commit bd2c05f
Showing 2 changed files with 20 additions and 1 deletion.
14 changes: 14 additions & 0 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -273,6 +273,20 @@ def test_save_state_dict_with_cpu_shards(self):
     self.assertTrue(
         isinstance(planner.sharded_state_dict['fc1.weight'], _CpuShards))
 
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for sharded test")
+  def test_cpu_state_dict_flattening(self):
+    # In the case of a nested state_dict with fully sharded parameters,
+    # _CpuShards should be treated as terminal nodes.
+    t = torch.randn(128, 128).to(xm.xla_device())
+    mesh = self._get_mesh((self.n_devices, 1))
+    xs.mark_sharding(t, mesh, (0, 1))
+    state_dict = _sharded_cpu_state_dict({'model': {'weight': t}})
+    planner = SPMDSavePlanner()
+    planner.set_up_planner(state_dict, True)
+    # model.weight should be flattened and tracked in the sharded state dict.
+    self.assertCountEqual(planner.sharded_state_dict, ["model.weight"])
+
   def test_local_save_plan(self):
 
     def _write_item_assertions(plan, n_devices, parameter_count):
7 changes: 6 additions & 1 deletion torch_xla/experimental/distributed_checkpoint/_helpers.py
@@ -34,8 +34,13 @@
 CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]
 
 
+# TODO(jonbolin): Logic here is modified from the upstream to enable async
+# checkpointing. If the state_dict is comprised entirely of _CpuShards,
+# flatten_state_dict will not actually flatten the dict.
+# Once we can represent XLAShardedTensor on CPU, either directly or through
+# DistributedTensor, we can reuse the upstream logic.
 def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
-  return isinstance(value, torch.Tensor)
+  return isinstance(value, torch.Tensor) or isinstance(value, _CpuShards)
 
 
 def _traverse_state_dict(
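
To make the effect of the one-line change concrete, here is a rough, self-contained sketch of the flattening behavior. It is not the upstream torch.distributed.checkpoint traversal; the _CpuShards stand-in and the flatten_state_dict helper below are simplified assumptions used only to illustrate why treating _CpuShards as a terminal leaf lets a fully sharded, nested state_dict flatten to dotted keys such as model.weight.

# Simplified sketch (not the upstream implementation): a nested dict is only
# traversed further when it contains something the predicate considers worth
# visiting. Before the fix, a dict holding only _CpuShards looked "empty" and
# was kept as a single nested value instead of being flattened.
from dataclasses import dataclass, field
from typing import Any, Dict, Mapping

import torch


@dataclass
class _CpuShards:
  # Stand-in for torch_xla's _CpuShards wrapper; the field names here are
  # illustrative, not the real attributes.
  shards: list = field(default_factory=list)
  global_shape: tuple = ()


def _keep_visiting_tensors(value: Any) -> bool:
  # The patched predicate: _CpuShards now counts as a leaf worth tracking,
  # exactly like a plain torch.Tensor.
  return isinstance(value, (torch.Tensor, _CpuShards))


def flatten_state_dict(state_dict: Mapping[str, Any]) -> Dict[str, Any]:
  flat: Dict[str, Any] = {}

  def _visit(path: str, value: Any) -> None:
    # Keep descending into a mapping only if it holds another mapping or a
    # value the predicate accepts; otherwise record it under its dotted path.
    is_container = isinstance(value, Mapping) and any(
        isinstance(v, Mapping) or _keep_visiting_tensors(v)
        for v in value.values())
    if is_container:
      for k, v in value.items():
        _visit(f"{path}.{k}" if path else str(k), v)
    else:
      flat[path] = value

  _visit("", state_dict)
  return flat


state_dict = {'model': {'weight': _CpuShards()}}
# With _CpuShards treated as a leaf, flattening yields the dotted key
# 'model.weight'; with the old tensor-only predicate the whole inner dict
# would have stayed under the single key 'model', which is the bug the
# commit fixes for async checkpointing of fully sharded state dicts.
print(list(flatten_state_dict(state_dict)))  # ['model.weight']

As the TODO in the diff notes, once XLAShardedTensor can be represented on CPU, either directly or through DistributedTensor, the predicate can revert to the upstream tensor-only check.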
