diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 956a4e5f5ec..06d754d2fdb 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -388,6 +388,15 @@ def load_sharded_checkpoint( from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner + def _get_num_ranks_that_saved_rng(metadata: Metadata): + rng_inds = [] + for field_name, field_value in metadata.planner_data.items(): + if 'rng' in field_name: + _, rng_rank_index, _ = field_value + rng_inds.append(rng_rank_index) + rng_inds = set(rng_inds) + return len(rng_inds) + class FileSystemReaderWithValidation(dist_cp.FileSystemReader): """FileSystemReader that validates checkpoint files prior to reading.""" @@ -496,9 +505,10 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): # For older versions of torch, we load optimizer separately. if version.parse(torch.__version__) < version.parse('2.2.9'): cur_state_dict.pop('optimizers') + num_rng_ranks = _get_num_ranks_that_saved_rng(storage_reader.read_metadata()) state_dict: Dict[str, Any] = { 'state': cur_state_dict, - 'rng': reproducibility.get_rng_state(), + 'rng': reproducibility.get_rng_state()[:num_rng_ranks], } if ignore_keys: diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 5bd416f4c79..0799d815d40 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -561,11 +561,7 @@ def test_checkpoint_loading_with_validation(world_size, tmp_path, is_valid_check # Set the error expectations. expectation = does_not_raise() if not is_valid_checkpoint: - if using_torch_2() and state_dict_type == 'sharded': - from torch.distributed.checkpoint import CheckpointException - expectation = pytest.raises(CheckpointException) - else: - expectation = pytest.raises(ValueError) + expectation = pytest.raises(ValueError) def mock_get_checkpoint_validation_function(): return lambda _: is_valid_checkpoint