diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py
index 96c4031..b4f3553 100644
--- a/tests/test_snapshot.py
+++ b/tests/test_snapshot.py
@@ -31,11 +31,37 @@ def test_state_dict(self) -> None:
             },
         )
         bar = torchsnapshot.StateDict(
+            {
+                "a": torch.rand(40, 40),
+                "b": torch.rand(40, 40),
+                "c": 41,
+                "d/e": 42,
+                "[@x]->&y^%": {"(z)": 43},
+            },
+        )
+        self.assertFalse(check_state_dict_eq(foo.state_dict(), bar.state_dict()))
+        self.assertIs(type(foo.state_dict()), dict)
+
+        with tempfile.TemporaryDirectory() as path:
+            snapshot = torchsnapshot.Snapshot.take(path, {"foo": foo})
+            snapshot.restore({"foo": bar})
+        assert_state_dict_eq(self, foo.state_dict(), bar.state_dict())
+
+    def test_incomplete_state_dict(self) -> None:
+        foo = torchsnapshot.StateDict(
             {
                 "a": torch.rand(40, 40),
                 "b": torch.rand(40, 40),
                 "c": 42,
                 "d/e": 43,
+                "f": {"g": torch.rand(40, 40)},
+                "[@x]->&y^%": {"(z)": 44},
+            },
+        )
+        bar = torchsnapshot.StateDict(
+            {
+                "a": torch.rand(40, 40),
+                "c": 42,
                 "[@x]->&y^%": {"(z)": 44},
             },
         )
diff --git a/torchsnapshot/flatten.py b/torchsnapshot/flatten.py
index 7afd11a..23d6926 100644
--- a/torchsnapshot/flatten.py
+++ b/torchsnapshot/flatten.py
@@ -70,6 +70,24 @@ def flatten(obj: Any, prefix: str = "") -> Tuple[Manifest, Dict[str, Any]]:
     return manifest, flattened
 
 
+def _populate_mnfst_entry(mnfst: Manifest, logical_path: str) -> None:
+    """Record logical_path in the manifest, creating parent entries as needed."""
+    path = os.path.dirname(logical_path)
+    basename = os.path.basename(logical_path)
+    if path and path not in mnfst:
+        _populate_mnfst_entry(mnfst, path)
+    mnfst_entry = mnfst.setdefault(path, DictEntry(keys=[]))
+    if isinstance(mnfst_entry, ListEntry):
+        # List elements are addressed by their index; there are no keys to record.
+        pass
+    elif isinstance(mnfst_entry, (DictEntry, OrderedDictEntry)):
+        key = _filename_to_key(basename)
+        if key not in mnfst_entry.keys:
+            mnfst_entry.keys.append(key)
+    else:
+        raise NotImplementedError(f"Unknown entry type: {type(mnfst_entry)}")
+
+
 # pyre-ignore[3]: Return annotation cannot be `Any`
 def inflate(
     manifest: Manifest, flattened: Dict[str, Any], prefix: str = ""
diff --git a/torchsnapshot/io_preparer.py b/torchsnapshot/io_preparer.py
index e307d80..7145fb6 100644
--- a/torchsnapshot/io_preparer.py
+++ b/torchsnapshot/io_preparer.py
@@ -902,6 +902,30 @@ def prepare_write(
     return entry, obj_write_req
 
 
+# pyre-ignore[3]: Return annotation cannot be `Any`
+def _make_obj_from_entry(entry: Entry):
+    """Create a placeholder object that a read request can populate from entry."""
+    if isinstance(entry, PrimitiveEntry):
+        obj_out = entry.get_value()
+    elif isinstance(entry, ObjectEntry):
+        # Objects are fully deserialized by the read request; no placeholder needed.
+        obj_out = None
+    elif isinstance(entry, TensorEntry):
+        # Allocate an uninitialized CPU tensor for the read request to fill in.
+        obj_out = torch.empty(
+            torch.Size(entry.shape),
+            dtype=string_to_dtype(entry.dtype),
+            device=torch.device("cpu"),
+        )
+    else:
+        # ShardedTensorEntry and other entry types are not supported here yet.
+        raise NotImplementedError(
+            f"Populating a non-instantiated {type(entry)} is not implemented. "
+            f"Got entry: {entry}."
+        )
+    return obj_out
+
+
 def prepare_read(
     entry: Entry,
     obj_out: Optional[Any] = None,
diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py
index c68eb97..0517718 100644
--- a/torchsnapshot/snapshot.py
+++ b/torchsnapshot/snapshot.py
@@ -30,9 +30,10 @@
 from .dist_store import get_or_create_store, LinearBarrier
-from .flatten import flatten, inflate
+from .flatten import _populate_mnfst_entry, flatten, inflate
 from .io_preparer import (
     _identity_tensor_prepare_func,
+    _make_obj_from_entry,
     ObjectBufferConsumer,
     prepare_read,
     prepare_write,
@@ -678,32 +679,54 @@ def _load_stateful(
     del state_dict
 
     read_reqs: List[ReadReq] = []
-    for logical_path, obj in flattened.items():
-        if logical_path not in available_entries:
-            raise RuntimeError(
-                f"""
-When restoring from the snapshot, stateful object "{stateful_key}" requested
-path "{logical_path}" which was not available to rank {rank}.
-
-- If the entry does not exist in the snapshot, it means that the state dict
-  entry was introduced after the snapshot was taken. To partially restore from
-  the snapshot, please explicitly ignore the state dict entries missing from
-  the snapshot.
-
-- If the entry exists in the snapshot, it could mean that the world size has
-  changed and the entry was not marked as replicated when the snapshot was
-  taken. To resolve the issue, try any of:
-    - Re-taking the snapshot with the new world size
-    - Re-taking the snapshot with the original world size, ensuring all
-      non-sharded values are marked as replicated
-    - Coerce the missing entry into replicated on restore"""
-            )
+    # Restrict available_entries to the entries that belong to this stateful object.
+    available_entries = {
+        k: item
+        for k, item in available_entries.items()
+        if k == stateful_key or k.startswith(stateful_key + "/")
+    }
+    flatten_items = list(flattened.items())
+    while True:
+        if flatten_items:
+            # First satisfy every path requested by the state dict.
+            logical_path, obj = flatten_items.pop(0)
+            if logical_path not in available_entries:
+                raise RuntimeError(
+                    f"""
+When restoring from the snapshot, stateful object "{stateful_key}" requested
+path "{logical_path}" which was not available to rank {rank}.
+
+- If the entry does not exist in the snapshot, it means that the state dict
+  entry was introduced after the snapshot was taken. To partially restore from
+  the snapshot, please explicitly ignore the state dict entries missing from
+  the snapshot.
+
+- If the entry exists in the snapshot, it could mean that the world size has
+  changed and the entry was not marked as replicated when the snapshot was
+  taken. To resolve the issue, try any of:
+    - Re-taking the snapshot with the new world size
+    - Re-taking the snapshot with the original world size, ensuring all
+      non-sharded values are marked as replicated
+    - Coerce the missing entry into replicated on restore"""
+                )
+            entry = available_entries.pop(logical_path)
+            if isinstance(entry, PrimitiveEntry):
+                # For primitive types, directly materialize from the PrimitiveEntry.
+                flattened[logical_path] = entry.get_value()
+                continue
+        elif available_entries:
+            # Then materialize snapshot entries that the state dict did not request.
+            logical_path, entry = available_entries.popitem()
+            obj = _make_obj_from_entry(entry)
+            flattened[logical_path] = obj
+            # Record the restored path in the manifest so it can be inflated later.
+            _populate_mnfst_entry(mnfst, logical_path)
+            if isinstance(entry, PrimitiveEntry):
+                # The primitive value is already materialized; no read request needed.
+                continue
+        else:
+            break
 
-        entry = available_entries[logical_path]
-        if isinstance(entry, PrimitiveEntry):
-            # for primitive types, directly materialize from PrimitiveEntry
-            flattened[logical_path] = entry.get_value()
-            continue
         rrs = prepare_read(
             entry=entry,
             obj_out=obj,
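
For reference, a minimal sketch of the restore flow this change enables, mirroring
test_incomplete_state_dict above. The keys and shapes are illustrative only and are
not taken from the diff:

    import tempfile

    import torch
    import torchsnapshot

    # Snapshot a state dict containing a tensor, a primitive, and a nested container.
    foo = torchsnapshot.StateDict(
        {
            "a": torch.rand(40, 40),
            "c": 42,
            "f": {"g": torch.rand(40, 40)},
        }
    )
    with tempfile.TemporaryDirectory() as path:
        snapshot = torchsnapshot.Snapshot.take(path, {"foo": foo})

        # Restore into a state dict that requests only "a". Previously, restore()
        # read back only the requested paths; with _make_obj_from_entry and
        # _populate_mnfst_entry, the remaining snapshot entries ("c" and "f") are
        # materialized into the flattened state dict and its manifest as well.
        bar = torchsnapshot.StateDict({"a": torch.rand(40, 40)})
        snapshot.restore({"foo": bar})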