diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 1f8b0879bb9..ce1e074db2e 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -273,6 +273,20 @@ def test_save_state_dict_with_cpu_shards(self):
     self.assertTrue(
         isinstance(planner.sharded_state_dict['fc1.weight'], _CpuShards))
 
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for sharded test")
+  def test_cpu_state_dict_flattening(self):
+    # In the case of a nested state_dict with fully sharded parameters,
+    # _CpuShards should be treated as terminal nodes.
+    t = torch.randn(128, 128).to(xm.xla_device())
+    mesh = self._get_mesh((self.n_devices, 1))
+    xs.mark_sharding(t, mesh, (0, 1))
+    state_dict = _sharded_cpu_state_dict({'model': {'weight': t}})
+    planner = SPMDSavePlanner()
+    planner.set_up_planner(state_dict, True)
+    # model.weight should be flattened and tracked in the sharded state dict.
+    self.assertCountEqual(planner.sharded_state_dict, ["model.weight"])
+
   def test_local_save_plan(self):
 
     def _write_item_assertions(plan, n_devices, parameter_count):
diff --git a/torch_xla/experimental/distributed_checkpoint/_helpers.py b/torch_xla/experimental/distributed_checkpoint/_helpers.py
index b49e7419dcd..9be506cc5a8 100644
--- a/torch_xla/experimental/distributed_checkpoint/_helpers.py
+++ b/torch_xla/experimental/distributed_checkpoint/_helpers.py
@@ -34,8 +34,13 @@
 CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]
 
 
+# TODO(jonbolin): Logic here is modified from the upstream to enable async
+# checkpointing. If the state_dict is comprised entirely of _CpuShards,
+# flatten_state_dict will not actually flatten the dict.
+# Once we can represent XLAShardedTensor on CPU, either directly or through
+# DistributedTensor, we can reuse the upstream logic.
 def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
-  return isinstance(value, torch.Tensor)
+  return isinstance(value, torch.Tensor) or isinstance(value, _CpuShards)
 
 
 def _traverse_state_dict(
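
For context, the Python sketch below illustrates why the predicate change matters for flattening. It is a simplified, standalone approximation, not the torch_xla or upstream implementation: _CpuShardsStub, _keep_visiting, _is_terminal and _flatten are hypothetical stand-ins for _CpuShards, the patched _keep_visiting_tensors and the upstream flatten/traverse helpers. The traversal only recurses into containers that hold values the predicate accepts, so once the CPU-shard wrapper is accepted, a nested dict like {'model': {'weight': <_CpuShards>}} is flattened to the dot-joined key "model.weight", matching the new test's assertion.

# Standalone sketch (simplified; names here are stand-ins, not the real APIs).
from collections.abc import Mapping
from typing import Any, Callable


class _CpuShardsStub:
  """Stand-in for torch_xla's _CpuShards wrapper (hypothetical)."""


def _keep_visiting(value: Any) -> bool:
  # Mirrors the patched predicate: shard wrappers (like tensors) are worth
  # traversing toward, so containers holding them are not treated as leaves.
  return isinstance(value, _CpuShardsStub)


def _is_terminal(value: Any, keep: Callable[[Any], bool]) -> bool:
  # A container is a leaf only if none of its entries satisfies ``keep`` and
  # it holds no non-terminal sub-containers.
  if not isinstance(value, Mapping):
    return True
  for entry in value.values():
    if isinstance(entry, Mapping) and not _is_terminal(entry, keep):
      return False
    if keep(entry):
      return False
  return True


def _flatten(state_dict: Mapping, keep: Callable[[Any], bool]) -> dict:
  # Flatten nested keys into dot-joined paths, stopping at terminal values.
  flat: dict = {}

  def visit(path: str, value: Any) -> None:
    if isinstance(value, Mapping) and not _is_terminal(value, keep):
      for k, v in value.items():
        visit(f"{path}.{k}", v)
    else:
      flat[path] = value

  for key, value in state_dict.items():
    visit(str(key), value)
  return flat


if __name__ == "__main__":
  nested = {"model": {"weight": _CpuShardsStub()}}
  # With the shard wrapper in the predicate, the nested dict is flattened.
  print(list(_flatten(nested, _keep_visiting)))   # ['model.weight']
  # A tensors-only predicate sees no reason to recurse, so the inner dict
  # stays a single terminal value under 'model'.
  print(list(_flatten(nested, lambda v: False)))  # ['model']

The two print statements show the behavioral difference the patch targets: without accepting the shard wrapper, the inner dict is treated as one opaque leaf and never receives a flattened key.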