🚀 The feature
We'd like to be able to load tensors that are saved on disk but do not yet exist in the destination module.
Motivation, pitch
Say we have a module that stores a list of tensors. During training, we append to that list.
If I'm using regular `torch.save(state_dict)`, we end up with a dictionary containing a list of tensors, and we can simply load it back where it belongs (since loading is not done in place).
With torchsnapshot, my understanding is that the snapshot will look at my current state_dict and repopulate it in place. Hence, if my list of tensors is empty (which I expect it to be when loading a checkpoint), all the tensors in the list will be discarded.
Example:
```python
from torchsnapshot import StateDict, Snapshot
import torch
import os


def list_files(startpath):
    for root, dirs, files in os.walk(startpath):
        level = root.replace(startpath, "").count(os.sep)
        indent = " " * 4 * level
        print("{}{}/".format(indent, os.path.basename(root)))
        subindent = " " * 4 * (level + 1)
        for f in files:
            print("{}{}".format(subindent, f))


class ClassWithSD:
    def __init__(self):
        self.obj = []

    def state_dict(self):
        return {"obj": self.obj}

    def load_state_dict(self, sd):
        self.obj = sd["obj"]


x = ClassWithSD()
# let's put 2 tensors in our list. We'd like to get them back when loading
x.obj.append(torch.tensor([1.0]))
x.obj.append(torch.tensor([2.0]))

app_state = {"x": x}
Snapshot.take(app_state=app_state, path="./")

snapshot = Snapshot(path="./")
y = ClassWithSD()
app_state = {"x": y}
snapshot.restore(app_state=app_state)

list_files("./0")
print("content before take:", x.obj)
print("content after restore:", y.obj)

# with torch.save
torch.save(x.state_dict(), "torch_saved.pt")
y = ClassWithSD()
y.load_state_dict(torch.load("torch_saved.pt"))
print("torch.save:", y.obj)
```
Alternatives
No response
Additional context
Looking at this:
`torchsnapshot/torchsnapshot/snapshot.py`, lines 681 to 736 (at `4596fc6`)
I guess what I would like is that if not all `available_entries` are loaded, the remaining `logical_path`s are still included in the state_dict that is given to `stateful.load_state_dict(...)` at line 736.
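To make the ask concrete, here is a minimal sketch of the requested semantics. This is illustrative only, not torchsnapshot's actual code: `materialize_entry` is a hypothetical stand-in, and the real logic around lines 681 to 736 (manifest handling, container paths, I/O batching) is more involved:

```python
# Illustrative sketch only. The point: logical paths that exist on disk but
# are absent from the in-memory state_dict should still be materialized and
# included in the dict handed to load_state_dict, instead of being dropped.

def materialize_entry(entry):
    # hypothetical stand-in for torchsnapshot's entry -> object loading
    raise NotImplementedError


def restore_stateful(stateful, available_entries, state_dict):
    consumed = set()
    # current behavior: fill entries the destination state_dict already declares
    for logical_path in list(state_dict):
        if logical_path in available_entries:
            state_dict[logical_path] = materialize_entry(available_entries[logical_path])
            consumed.add(logical_path)
    # requested behavior: also load the remaining on-disk entries
    for logical_path, entry in available_entries.items():
        if logical_path not in consumed:
            state_dict[logical_path] = materialize_entry(entry)
    stateful.load_state_dict(state_dict)
```

The intent being that `ClassWithSD.load_state_dict` would end up seeing both saved tensors even though `obj` starts out empty, matching the `torch.save` behavior.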