Only load RNG keys that exist (mosaicml#2901)
* cut

* fix call

* fix test
mvpatel2000 authored and ShashankMosaicML committed Feb 3, 2024
1 parent 3d8165e commit aaaa1db
Showing 2 changed files with 12 additions and 6 deletions.
12 changes: 11 additions & 1 deletion composer/utils/checkpoint.py
@@ -388,6 +388,15 @@ def load_sharded_checkpoint(
 from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
 from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner

+def _get_num_ranks_that_saved_rng(metadata: Metadata):
+    rng_inds = []
+    for field_name, field_value in metadata.planner_data.items():
+        if 'rng' in field_name:
+            _, rng_rank_index, _ = field_value
+            rng_inds.append(rng_rank_index)
+    rng_inds = set(rng_inds)
+    return len(rng_inds)
+
 class FileSystemReaderWithValidation(dist_cp.FileSystemReader):
     """FileSystemReader that validates checkpoint files prior to reading."""

@@ -496,9 +505,10 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
 # For older versions of torch, we load optimizer separately.
 if version.parse(torch.__version__) < version.parse('2.2.9'):
     cur_state_dict.pop('optimizers')
+num_rng_ranks = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
 state_dict: Dict[str, Any] = {
     'state': cur_state_dict,
-    'rng': reproducibility.get_rng_state(),
+    'rng': reproducibility.get_rng_state()[:num_rng_ranks],
 }

 if ignore_keys:
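For context on the change above, here is a minimal, standalone sketch (not part of the commit) of how counting distinct RNG ranks from the checkpoint's planner metadata behaves. It assumes `planner_data` maps checkpoint keys to tuples whose second element is the index of the rank that saved the entry, which is what the helper above unpacks; the sample keys and tuples below are illustrative only.

```python
# Illustrative sketch only: mimics the metadata.planner_data layout that
# _get_num_ranks_that_saved_rng expects, where each 'rng' entry's value
# unpacks as (_, rng_rank_index, _). The keys and tuples here are made up.
from typing import Any, Dict, Tuple


def count_rng_ranks(planner_data: Dict[str, Tuple[Any, int, Any]]) -> int:
    """Count how many distinct ranks contributed an RNG entry to the checkpoint."""
    rng_inds = {rank for name, (_, rank, _) in planner_data.items() if 'rng' in name}
    return len(rng_inds)


# Example: a checkpoint saved on 2 ranks, each contributing one RNG entry.
planner_data = {
    'rng.0.torch': ('rng', 0, 'torch'),
    'rng.1.torch': ('rng', 1, 'torch'),
    'state.model': ('state', 0, 'model'),
}
assert count_rng_ranks(planner_data) == 2

# If the resuming run has more ranks than saved RNG states, slicing the current
# run's RNG list (e.g. reproducibility.get_rng_state()[:2]) keeps only the RNG
# keys that actually exist in the checkpoint, so the sharded load does not
# request keys that were never saved.
```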
6 changes: 1 addition & 5 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -561,11 +561,7 @@ def test_checkpoint_loading_with_validation(world_size, tmp_path, is_valid_check
 # Set the error expectations.
 expectation = does_not_raise()
 if not is_valid_checkpoint:
-    if using_torch_2() and state_dict_type == 'sharded':
-        from torch.distributed.checkpoint import CheckpointException
-        expectation = pytest.raises(CheckpointException)
-    else:
-        expectation = pytest.raises(ValueError)
+    expectation = pytest.raises(ValueError)

 def mock_get_checkpoint_validation_function():
     return lambda _: is_valid_checkpoint
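The test change above collapses the torch-version branching into a single ValueError expectation. As a hedged illustration of the expectation pattern the test relies on (assuming `does_not_raise` is `contextlib.nullcontext`; the import is not shown in this diff), the parametrized context manager works roughly like this:

```python
# Sketch of the expectation pattern used in the test; does_not_raise is
# assumed to be contextlib.nullcontext (the actual import is not in the diff).
from contextlib import nullcontext as does_not_raise

import pytest


def check_validation(is_valid_checkpoint: bool) -> None:
    expectation = does_not_raise() if is_valid_checkpoint else pytest.raises(ValueError)
    with expectation:
        # Stand-in for the real checkpoint load; raise only for invalid checkpoints.
        if not is_valid_checkpoint:
            raise ValueError('Checkpoint validation failed')


check_validation(True)   # no error expected
check_validation(False)  # ValueError expected and caught by pytest.raises
```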
