[Feature] Load partially instantiated state-dict #103

Open

wants to merge 5 commits into base: main
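This PR teaches Snapshot.restore to accept a partially instantiated state dict: entries that exist in the snapshot but are missing from the restoring state dict are materialized from the snapshot's manifest instead of being silently skipped. A minimal sketch of the intended usage, inferred from the new test below (the exact post-restore contents of the materialized entries are an assumption, not spelled out in this diff):

import tempfile

import torch
import torchsnapshot

full = torchsnapshot.StateDict({"a": torch.rand(4, 4), "c": 42})
partial = torchsnapshot.StateDict({"c": 0})  # "a" is not instantiated

with tempfile.TemporaryDirectory() as path:
    snapshot = torchsnapshot.Snapshot.take(path, {"state": full})
    # With this change, the missing entry "a" is materialized from the
    # snapshot rather than being skipped.
    snapshot.restore({"state": partial})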
26 changes: 26 additions & 0 deletions tests/test_snapshot.py
@@ -31,11 +31,37 @@ def test_state_dict(self) -> None:
            },
        )
        bar = torchsnapshot.StateDict(
            {
                "a": torch.rand(40, 40),
                "b": torch.rand(40, 40),
                "c": 41,
                "d/e": 42,
                "[@x]->&y^%": {"(z)": 43},
            },
        )
        self.assertFalse(check_state_dict_eq(foo.state_dict(), bar.state_dict()))
        self.assertTrue(type(foo.state_dict()) == dict)

        with tempfile.TemporaryDirectory() as path:
            snapshot = torchsnapshot.Snapshot.take(path, {"foo": foo})
            snapshot.restore({"foo": bar})
            assert_state_dict_eq(self, foo.state_dict(), bar.state_dict())

    def test_incomplete_state_dict(self) -> None:
        foo = torchsnapshot.StateDict(
            {
                "a": torch.rand(40, 40),
                "b": torch.rand(40, 40),
                "c": 42,
                "d/e": 43,
                "f": {"g": torch.rand(40, 40)},
                "[@x]->&y^%": {"(z)": 44},
            },
        )
        bar = torchsnapshot.StateDict(
            {
                "a": torch.rand(40, 40),
                "c": 42,
                "[@x]->&y^%": {"(z)": 44},
            },
        )
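The rest of this test is collapsed in the diff view. A plausible continuation, following the pattern of test_state_dict above (these assertions are an assumption, not part of the visible diff):

        with tempfile.TemporaryDirectory() as path:
            snapshot = torchsnapshot.Snapshot.take(path, {"foo": foo})
            snapshot.restore({"foo": bar})
            # Entries missing from bar ("b", "d/e", "f") should now have been
            # materialized from the snapshot.
            assert_state_dict_eq(self, foo.state_dict(), bar.state_dict())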
19 changes: 19 additions & 0 deletions torchsnapshot/flatten.py
@@ -70,6 +70,25 @@ def flatten(obj: Any, prefix: str = "") -> Tuple[Manifest, Dict[str, Any]]:
    return manifest, flattened


def _populate_mnfst_entry(mnfst: Manifest, logical_path: str) -> None:
    """Register logical_path in the manifest, recursively creating any
    missing ancestor entries along the way."""
    if logical_path == "":
        # The root entry is created by flatten(); nothing to register.
        return
    path = os.path.dirname(logical_path)
    basename = os.path.basename(logical_path)
    if path not in mnfst:
        _populate_mnfst_entry(mnfst, path)
    mnfst_entry = mnfst.setdefault(path, DictEntry(keys=[]))
    if isinstance(mnfst_entry, ListEntry):
        # List entries are keyed by position; there is no key to record.
        pass
    elif isinstance(mnfst_entry, (DictEntry, OrderedDictEntry)):
        mnfst_entry.keys.append(_filename_to_key(basename))
    else:
        raise NotImplementedError(f"Unknown entry type {type(mnfst_entry)}")
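A quick illustration of the recursive population (hypothetical manifest contents, using the entry types from torchsnapshot.manifest):

mnfst = {"foo": DictEntry(keys=["a"])}
_populate_mnfst_entry(mnfst, "foo/f/g")
# "f" was registered under its parent and "foo/f" was created on the way down:
assert mnfst["foo"].keys == ["a", "f"]
assert mnfst["foo/f"].keys == ["g"]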


# pyre-ignore[3]: Return annotation cannot be `Any`
def inflate(
    manifest: Manifest, flattened: Dict[str, Any], prefix: str = ""
23 changes: 23 additions & 0 deletions torchsnapshot/io_preparer.py
@@ -902,6 +902,29 @@ def prepare_write(
    return entry, obj_write_req


# pyre-ignore[3]: Return annotation cannot be `Any`
def _make_obj_from_entry(entry: Entry) -> Any:
    """Materialize a placeholder object for an entry that is absent from the
    state dict being restored, so a read request can be prepared for it."""
    if isinstance(entry, PrimitiveEntry):
        obj_out = entry.get_value()
    elif isinstance(entry, ObjectEntry):
        # Objects are fully reconstructed by ObjectBufferConsumer, so no
        # placeholder needs to be allocated up front.
        obj_out = None
    elif isinstance(entry, TensorEntry):
        # A get_value() on TensorEntry could perhaps replace this.
        obj_out = torch.empty(
            torch.Size(entry.shape),
            dtype=string_to_dtype(entry.dtype),
            device=torch.device("cpu"),
        )
    else:
        # Notably, ShardedTensorEntry carries no top-level shape/dtype and
        # is not handled yet.
        raise NotImplementedError(
            f"Materializing a non-instantiated {type(entry)} is not implemented. "
            f"Got entry: {entry}."
        )
    return obj_out
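For example (field values as they would appear in a snapshot manifest; the TensorEntry constructor arguments are elided here):

# Given a TensorEntry with dtype="float32" and shape=[40, 40],
# _make_obj_from_entry returns the placeholder
#     torch.empty(torch.Size([40, 40]), dtype=torch.float32, device="cpu")
# which prepare_read can then fill from the persisted tensor data.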


def prepare_read(
    entry: Entry,
    obj_out: Optional[Any] = None,
74 changes: 48 additions & 26 deletions torchsnapshot/snapshot.py
@@ -30,9 +30,10 @@

from .dist_store import get_or_create_store, LinearBarrier

-from .flatten import flatten, inflate
+from .flatten import _populate_mnfst_entry, flatten, inflate
from .io_preparer import (
    _identity_tensor_prepare_func,
+    _make_obj_from_entry,
    ObjectBufferConsumer,
    prepare_read,
    prepare_write,
@@ -678,32 +679,53 @@ def _load_stateful(
        del state_dict

        read_reqs: List[ReadReq] = []
-        for logical_path, obj in flattened.items():
-            if logical_path not in available_entries:
-                raise RuntimeError(
-                    f"""
-When restoring from the snapshot, stateful object "{stateful_key}" requested
-path "{logical_path}" which was not available to rank {rank}.
-
-- If the entry does not exist in the snapshot, it means that the state dict
-entry was introduced after the snapshot was taken. To partially restore from
-the snapshot, please explicitly ignore the state dict entries missing from
-the snapshot.
-
-- If the entry exists in the snapshot, it could mean that the world size has
-changed and the entry was not marked as replicated when the snapshot was
-taken. To resolve the issue, try any of:
-- Re-taking the snapshot with the new world size
-- Re-taking the snapshot with the original world size, ensuring all
-non-sharded values are marked as replicated
-- Coerce the missing entry into replicated on restore"""
-                )
+        # Only consider entries that belong to this stateful object.
+        available_entries = {
+            k: item
+            for k, item in available_entries.items()
+            if k.startswith(stateful_key)
+        }
+        flatten_items = list(flattened.items())
+        while True:
+            if flatten_items:
+                # Phase 1: read requests for entries the restoring state
+                # dict already contains.
+                logical_path, obj = flatten_items.pop(0)
+                if logical_path not in available_entries:
+                    raise RuntimeError(
+                        f"""
+When restoring from the snapshot, stateful object "{stateful_key}" requested
+path "{logical_path}" which was not available to rank {rank}.
+
+- If the entry does not exist in the snapshot, it means that the state dict
+entry was introduced after the snapshot was taken. To partially restore from
+the snapshot, please explicitly ignore the state dict entries missing from
+the snapshot.
+
+- If the entry exists in the snapshot, it could mean that the world size has
+changed and the entry was not marked as replicated when the snapshot was
+taken. To resolve the issue, try any of:
+- Re-taking the snapshot with the new world size
+- Re-taking the snapshot with the original world size, ensuring all
+non-sharded values are marked as replicated
+- Coerce the missing entry into replicated on restore"""
+                    )
+                entry = available_entries.pop(logical_path)
+                if isinstance(entry, PrimitiveEntry):
+                    # for primitive types, directly materialize from PrimitiveEntry
+                    flattened[logical_path] = entry.get_value()
+                    continue
+            elif available_entries:
+                # Phase 2: entries present in the snapshot but missing from
+                # the restoring state dict; materialize placeholders for
+                # them and register them in the manifest.
+                logical_path, entry = available_entries.popitem()
+                obj = _make_obj_from_entry(entry)
+                flattened[logical_path] = obj
+                _populate_mnfst_entry(mnfst, logical_path)
+                if isinstance(entry, PrimitiveEntry):
+                    # primitives are already fully materialized by
+                    # _make_obj_from_entry; no read request is needed
+                    continue
+            else:
+                break

-            entry = available_entries[logical_path]
-            if isinstance(entry, PrimitiveEntry):
-                # for primitive types, directly materialize from PrimitiveEntry
-                flattened[logical_path] = entry.get_value()
-                continue
            rrs = prepare_read(
                entry=entry,
                obj_out=obj,
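The net effect of the rewritten loop: restore proceeds in two phases — first the entries the caller's state dict already holds, then whatever remains in the snapshot's manifest for this stateful object. A stripped-down, runnable sketch of the control flow (materialize_placeholder and consume_entry are hypothetical stand-ins for _make_obj_from_entry and prepare_read):

from typing import Any, Dict

def materialize_placeholder(entry: Any) -> Any:
    # Hypothetical stand-in for _make_obj_from_entry.
    return None

def consume_entry(entry: Any, obj: Any) -> None:
    # Hypothetical stand-in for prepare_read.
    pass

def restore_two_phase(flattened: Dict[str, Any], available_entries: Dict[str, Any]) -> None:
    pending = list(flattened.items())
    while True:
        if pending:
            # Phase 1: paths the state dict instantiated itself.
            path, obj = pending.pop(0)
            entry = available_entries.pop(path)  # KeyError if not in snapshot
        elif available_entries:
            # Phase 2: leftover snapshot entries get placeholders.
            path, entry = available_entries.popitem()
            obj = materialize_placeholder(entry)
            flattened[path] = obj
        else:
            break
        consume_entry(entry, obj)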