From 1cf918a77fac35ed0b2fb32b3d64113eecf2eba3 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 17 Oct 2022 10:32:47 +0100
Subject: [PATCH 1/5] init

---
 torchsnapshot/flatten.py     | 14 ++++++++
 torchsnapshot/io_preparer.py | 11 ++++++
 torchsnapshot/snapshot.py    | 70 ++++++++++++++++++++++--------------
 3 files changed, 69 insertions(+), 26 deletions(-)

diff --git a/torchsnapshot/flatten.py b/torchsnapshot/flatten.py
index 7afd11a..749d378 100644
--- a/torchsnapshot/flatten.py
+++ b/torchsnapshot/flatten.py
@@ -70,6 +70,20 @@ def flatten(obj: Any, prefix: str = "") -> Tuple[Manifest, Dict[str, Any]]:
     return manifest, flattened
 
 
+def _populate_mnfst_entry(mnfst, logical_path):
+    path = os.path.dirname(logical_path)
+    basename = os.path.basename(logical_path)
+    mnfst_entry = mnfst[path]
+    if isinstance(mnfst_entry, ListEntry):
+        idx_str = os.path.basename(logical_path)
+    elif isinstance(mnfst_entry, DictEntry):
+        mnfst_entry.keys.append(basename)
+    elif isinstance(mnfst_entry, OrderedDictEntry):
+        mnfst_entry.keys.append(basename)
+    else:
+        raise NotImplementedError(f"Unknown entry type {type(mnfst_entry)}")
+
+
 # pyre-ignore[3]: Return annotation cannot be `Any`
 def inflate(
     manifest: Manifest, flattened: Dict[str, Any], prefix: str = ""
diff --git a/torchsnapshot/io_preparer.py b/torchsnapshot/io_preparer.py
index e307d80..4d128fc 100644
--- a/torchsnapshot/io_preparer.py
+++ b/torchsnapshot/io_preparer.py
@@ -902,6 +902,17 @@ def prepare_write(
     return entry, obj_write_req
 
 
+def _make_obj_from_entry(entry: Entry):
+    obj_out = torch.empty(
+        *entry.shape, dtype=string_to_dtype(entry.dtype), device=torch.device("cpu")
+    )
+    if isinstance(entry, ShardedTensorEntry):
+        # Do we need this?
+        obj_out.share_memory_()
+
+    return obj_out
+
+
 def prepare_read(
     entry: Entry,
     obj_out: Optional[Any] = None,
diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py
index c68eb97..199dcb2 100644
--- a/torchsnapshot/snapshot.py
+++ b/torchsnapshot/snapshot.py
@@ -30,9 +30,10 @@
 
 from .dist_store import get_or_create_store, LinearBarrier
 
-from .flatten import flatten, inflate
+from .flatten import _populate_mnfst_entry, flatten, inflate
 from .io_preparer import (
     _identity_tensor_prepare_func,
+    _make_obj_from_entry,
     ObjectBufferConsumer,
     prepare_read,
     prepare_write,
@@ -678,32 +679,49 @@ def _load_stateful(
         del state_dict
 
         read_reqs: List[ReadReq] = []
-        for logical_path, obj in flattened.items():
-            if logical_path not in available_entries:
-                raise RuntimeError(
-                    f"""
-When restoring from the snapshot, stateful object "{stateful_key}" requested
-path "{logical_path}" which was not available to rank {rank}.
-
-- If the entry does not exist in the snapshot, it means that the state dict
-  entry was introduced after the snapshot was taken. To partially restore from
-  the snapshot, please explicitly ignore the state dict entries missing from
-  the snapshot.
-
-- If the entry exists in the snapshot, it could mean that the world size has
-  changed and the entry was not marked as replicated when the snapshot was
-  taken. To resolve the issue, try any of:
-  - Re-taking the snapshot with the new world size
-  - Re-taking the snapshot with the original world size, ensuring all
-    non-sharded values are marked as replicated
-  - Coerce the missing entry into replicated on restore"""
-                )
+        available_entries = {
+            k: item
+            for k, item in available_entries.items()
+            if k.startswith(stateful_key)
+        }
+        flatten_items = list(flattened.items())
+        while True:
+            if len(flatten_items):
+                logical_path, obj = flatten_items[0]
+                flatten_items = flatten_items[1:]
+                if logical_path not in available_entries:
+                    raise RuntimeError(
+                        f"""
+    When restoring from the snapshot, stateful object "{stateful_key}" requested
+    path "{logical_path}" which was not available to rank {rank}.
+
+    - If the entry does not exist in the snapshot, it means that the state dict
+      entry was introduced after the snapshot was taken. To partially restore from
+      the snapshot, please explicitly ignore the state dict entries missing from
+      the snapshot.
+
+    - If the entry exists in the snapshot, it could mean that the world size has
+      changed and the entry was not marked as replicated when the snapshot was
+      taken. To resolve the issue, try any of:
+      - Re-taking the snapshot with the new world size
+      - Re-taking the snapshot with the original world size, ensuring all
+        non-sharded values are marked as replicated
+      - Coerce the missing entry into replicated on restore"""
+                    )
+                entry = available_entries.pop(logical_path)
+                if isinstance(entry, PrimitiveEntry):
+                    # for primitive types, directly materialize from PrimitiveEntry
+                    flattened[logical_path] = entry.get_value()
+                    continue
+            elif len(available_entries):
+                logical_path, entry = available_entries.popitem()
+                obj = _make_obj_from_entry(entry)
+                flattened[logical_path] = obj
+                # populate manifest
+                _populate_mnfst_entry(mnfst, logical_path)
+            else:
+                break
 
-            entry = available_entries[logical_path]
-            if isinstance(entry, PrimitiveEntry):
-                # for primitive types, directly materialize from PrimitiveEntry
-                flattened[logical_path] = entry.get_value()
-                continue
             rrs = prepare_read(
                 entry=entry,
                 obj_out=obj,

From 2ec08505d160a73a0effba7e9a98185c09bf42c7 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 17 Oct 2022 12:36:54 +0100
Subject: [PATCH 2/5] tests

---
 tests/test_snapshot.py       | 26 ++++++++++++++++++++++++++
 torchsnapshot/flatten.py     |  9 +++++++--
 torchsnapshot/io_preparer.py | 21 ++++++++++++++-------
 torchsnapshot/snapshot.py    |  4 ++++
 4 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py
index 96c4031..b4f3553 100644
--- a/tests/test_snapshot.py
+++ b/tests/test_snapshot.py
@@ -31,11 +31,37 @@ def test_state_dict(self) -> None:
             },
         )
         bar = torchsnapshot.StateDict(
+            {
+                "a": torch.rand(40, 40),
+                "b": torch.rand(40, 40),
+                "c": 41,
+                "d/e": 42,
+                "[@x]->&y^%": {"(z)": 43},
+            },
+        )
+        self.assertFalse(check_state_dict_eq(foo.state_dict(), bar.state_dict()))
+        self.assertTrue(type(foo.state_dict()) == dict)
+
+        with tempfile.TemporaryDirectory() as path:
+            snapshot = torchsnapshot.Snapshot.take(path, {"foo": foo})
+            snapshot.restore({"foo": bar})
+            assert_state_dict_eq(self, foo.state_dict(), bar.state_dict())
+
+    def test_incomplete_state_dict(self) -> None:
+        foo = torchsnapshot.StateDict(
             {
                 "a": torch.rand(40, 40),
                 "b": torch.rand(40, 40),
                 "c": 42,
                 "d/e": 43,
+                "f": {"g": torch.rand(40, 40)},
                 "[@x]->&y^%": {"(z)": 44},
             },
         )
+        bar = torchsnapshot.StateDict(
+            {
+                "a": torch.rand(40, 40),
+                "c": 42,
+                "[@x]->&y^%": {"(z)": 44},
+            },
+        )
diff --git a/torchsnapshot/flatten.py b/torchsnapshot/flatten.py
index 749d378..23d6926 100644
--- a/torchsnapshot/flatten.py
+++ b/torchsnapshot/flatten.py
@@ -73,12 +73,17 @@ def flatten(obj: Any, prefix: str = "") -> Tuple[Manifest, Dict[str, Any]]:
 def _populate_mnfst_entry(mnfst, logical_path):
     path = os.path.dirname(logical_path)
     basename = os.path.basename(logical_path)
-    mnfst_entry = mnfst[path]
+    if path not in mnfst:
+        _populate_mnfst_entry(mnfst, path)
+    mnfst_entry = mnfst.setdefault(path, DictEntry(keys=[]))
     if isinstance(mnfst_entry, ListEntry):
-        idx_str = os.path.basename(logical_path)
+        pass
+        # idx_str = os.path.basename(logical_path)
     elif isinstance(mnfst_entry, DictEntry):
+        basename = _filename_to_key(basename)
         mnfst_entry.keys.append(basename)
     elif isinstance(mnfst_entry, OrderedDictEntry):
+        basename = _filename_to_key(basename)
         mnfst_entry.keys.append(basename)
     else:
         raise NotImplementedError(f"Unknown entry type {type(mnfst_entry)}")
diff --git a/torchsnapshot/io_preparer.py b/torchsnapshot/io_preparer.py
index 4d128fc..9b8e822 100644
--- a/torchsnapshot/io_preparer.py
+++ b/torchsnapshot/io_preparer.py
@@ -903,13 +903,20 @@ def prepare_write(
 
 
 def _make_obj_from_entry(entry: Entry):
-    obj_out = torch.empty(
-        *entry.shape, dtype=string_to_dtype(entry.dtype), device=torch.device("cpu")
-    )
-    if isinstance(entry, ShardedTensorEntry):
-        # Do we need this?
-        obj_out.share_memory_()
-
+    if isinstance(entry, PrimitiveEntry):
+        obj_out = entry.get_value()
+    elif isinstance(entry, (ChunkedTensorEntry, TensorEntry, ShardedTensorEntry)):
+        # we could perhaps code a get_value() for those too?
+        obj_out = torch.empty(
+            *entry.shape, dtype=string_to_dtype(entry.dtype), device=torch.device("cpu")
+        )
+        if isinstance(entry, ShardedTensorEntry):
+            # Do we need this?
+            obj_out.share_memory_()
+    else:
+        raise NotImplementedError(
+            f"populating non-instantiated {type(entry)} is not implemented"
+        )
     return obj_out
 
 
diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py
index 199dcb2..0517718 100644
--- a/torchsnapshot/snapshot.py
+++ b/torchsnapshot/snapshot.py
@@ -719,6 +719,10 @@ def _load_stateful(
                 flattened[logical_path] = obj
                 # populate manifest
                 _populate_mnfst_entry(mnfst, logical_path)
+                if isinstance(entry, PrimitiveEntry):
+                    # for primitive types, directly materialize from PrimitiveEntry
+                    flattened[logical_path] = obj
+                    continue
             else:
                 break

From f32ce9b7fae9d407357d79e7f665be40da09e3d9 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 17 Oct 2022 13:32:14 +0100
Subject: [PATCH 3/5] update tests

---
 torchsnapshot/io_preparer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchsnapshot/io_preparer.py b/torchsnapshot/io_preparer.py
index 9b8e822..1d7ba06 100644
--- a/torchsnapshot/io_preparer.py
+++ b/torchsnapshot/io_preparer.py
@@ -905,7 +905,7 @@ def prepare_write(
 def _make_obj_from_entry(entry: Entry):
     if isinstance(entry, PrimitiveEntry):
         obj_out = entry.get_value()
-    elif isinstance(entry, (ChunkedTensorEntry, TensorEntry, ShardedTensorEntry)):
+    elif isinstance(entry, (TensorEntry, )):
         # we could perhaps code a get_value() for those too?
         obj_out = torch.empty(
             *entry.shape, dtype=string_to_dtype(entry.dtype), device=torch.device("cpu")

From c38a161eaa04d0c170539aad22c9655d6fab84bc Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 17 Oct 2022 15:39:46 +0100
Subject: [PATCH 4/5] fix: creating tensors with empty shape

---
 torchsnapshot/io_preparer.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/torchsnapshot/io_preparer.py b/torchsnapshot/io_preparer.py
index 1d7ba06..d7e31d2 100644
--- a/torchsnapshot/io_preparer.py
+++ b/torchsnapshot/io_preparer.py
@@ -905,17 +905,20 @@ def prepare_write(
 def _make_obj_from_entry(entry: Entry):
     if isinstance(entry, PrimitiveEntry):
         obj_out = entry.get_value()
+    elif isinstance(entry, (ObjectEntry,)):
+        obj_out = None
     elif isinstance(entry, (TensorEntry, )):
         # we could perhaps code a get_value() for those too?
         obj_out = torch.empty(
-            *entry.shape, dtype=string_to_dtype(entry.dtype), device=torch.device("cpu")
+            torch.Size(entry.shape), dtype=string_to_dtype(entry.dtype), device=torch.device("cpu")
         )
         if isinstance(entry, ShardedTensorEntry):
             # Do we need this?
             obj_out.share_memory_()
     else:
         raise NotImplementedError(
-            f"populating non-instantiated {type(entry)} is not implemented"
+            f"populating non-instantiated {type(entry)} is not implemented. "
+            f"Got object: {entry}."
         )
     return obj_out

From 5501d83f8ede08a21487c4adbff5a85140ff965b Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 17 Oct 2022 15:40:46 +0100
Subject: [PATCH 5/5] lint

---
 torchsnapshot/io_preparer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchsnapshot/io_preparer.py b/torchsnapshot/io_preparer.py
index d7e31d2..7145fb6 100644
--- a/torchsnapshot/io_preparer.py
+++ b/torchsnapshot/io_preparer.py
@@ -907,10 +907,12 @@ def _make_obj_from_entry(entry: Entry):
         obj_out = entry.get_value()
     elif isinstance(entry, (ObjectEntry,)):
         obj_out = None
-    elif isinstance(entry, (TensorEntry, )):
+    elif isinstance(entry, (TensorEntry,)):
         # we could perhaps code a get_value() for those too?
         obj_out = torch.empty(
-            torch.Size(entry.shape), dtype=string_to_dtype(entry.dtype), device=torch.device("cpu")
+            torch.Size(entry.shape),
+            dtype=string_to_dtype(entry.dtype),
+            device=torch.device("cpu"),
         )
         if isinstance(entry, ShardedTensorEntry):
             # Do we need this?
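
A minimal usage sketch of the behavior this series enables, mirroring the new
test_incomplete_state_dict above (the dictionary contents here are
illustrative, not taken verbatim from the tests):

import tempfile

import torch
import torchsnapshot

# Full application state at snapshot time.
foo = torchsnapshot.StateDict(
    {
        "a": torch.rand(40, 40),
        "c": 42,
        "f": {"g": torch.rand(40, 40)},
    },
)

# A target state dict that is missing "f". With this series, restore()
# materializes the leftover snapshot entries via _make_obj_from_entry and
# registers them in the manifest via _populate_mnfst_entry, instead of only
# filling the keys the target already has.
bar = torchsnapshot.StateDict(
    {
        "a": torch.rand(40, 40),
        "c": 42,
    },
)

with tempfile.TemporaryDirectory() as path:
    snapshot = torchsnapshot.Snapshot.take(path, {"foo": foo})
    snapshot.restore({"foo": bar})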