Loading tensors in lists/dict that have not yet been instantiated #101

Open
vmoens opened this issue Oct 16, 2022 · 1 comment · May be fixed by #103

vmoens commented Oct 16, 2022

🚀 The feature

We'd like to be able to load tensors that are saved on disk but are not yet present in the destination module.

Motivation, pitch

Say we have a module that stores a list of tensors. During training, we append to that list.

If I'm using regular torch.save(state_dict), we end up with a dictionary containing a list of tensors, and we can just load it back where it belongs (since loading is not done in place).

With torchsnapshot, my understanding is that the snapshot looks at my current state_dict and repopulates it in place. Hence, if my list of tensors is empty (which I expect it to be when I load a checkpoint), all the tensors saved for that list will be discarded.

Example:

from torchsnapshot import StateDict, Snapshot
import torch
import os

def list_files(startpath):
    for root, dirs, files in os.walk(startpath):
        level = root.replace(startpath, '').count(os.sep)
        indent = ' ' * 4 * (level)
        print('{}{}/'.format(indent, os.path.basename(root)))
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            print('{}{}'.format(subindent, f))

class ClassWithSD:
    def __init__(self):
        self.obj = []
    def state_dict(self):
        return {"obj": self.obj}
    def load_state_dict(self, sd):
        self.obj = sd["obj"]


x = ClassWithSD()

# let's put 2 tensors in our list. We'd like to get them back when loading
x.obj.append(torch.tensor([1.0]))
x.obj.append(torch.tensor([2.0]))

app_state = {"x": x}
Snapshot.take(app_state=app_state, path="./")


snapshot = Snapshot(path="./")
y = ClassWithSD()
app_state = {"x": y}
snapshot.restore(app_state=app_state)

list_files("./0")
print("content before take:", x.obj)
print("content after restore:", y.obj)

# with torch.save

torch.save(x.state_dict(), "torch_saved.pt")
y = ClassWithSD()
y.load_state_dict(torch.load("torch_saved.pt"))
print("torch.save:", y.obj)

Alternatives

No response

Additional context

Looking at this:

for logical_path, obj in flattened.items():
    if logical_path not in available_entries:
        raise RuntimeError(
            f"""
When restoring from the snapshot, stateful object "{stateful_key}" requested
path "{logical_path}" which was not available to rank {rank}.
- If the entry does not exist in the snapshot, it means that the state dict
entry was introduced after the snapshot was taken. To partially restore from
the snapshot, please explicitly ignore the state dict entries missing from
the snapshot.
- If the entry exists in the snapshot, it could mean that the world size has
changed and the entry was not marked as replicated when the snapshot was
taken. To resolve the issue, try any of:
    - Re-taking the snapshot with the new world size
    - Re-taking the snapshot with the original world size, ensuring all
      non-sharded values are marked as replicated
    - Coerce the missing entry into replicated on restore"""
        )
    entry = available_entries[logical_path]
    if isinstance(entry, PrimitiveEntry):
        # for primitive types, directly materialize from PrimitiveEntry
        flattened[logical_path] = entry.get_value()
        continue
    rrs = prepare_read(
        entry=entry,
        obj_out=obj,
    )
    for rr in rrs:
        buffer_consumer = rr.buffer_consumer
        if isinstance(buffer_consumer, ObjectBufferConsumer):
            # ObjectBufferConsumer deals with objects that can not be
            # in-place restored. We need to replace the original object
            # in the flattened dictionary with the object materialized
            # by the buffer consumer.
            buffer_consumer.set_consume_callback(
                functools.partial(dict.__setitem__, flattened, logical_path)
            )
    read_reqs += rrs
if get_is_batching_enabled():
    read_reqs = batch_read_requests(read_reqs=read_reqs)
memory_budget_bytes = get_process_memory_budget_bytes(pg=pg)
sync_execute_read_reqs(
    read_reqs=read_reqs,
    storage=storage,
    memory_budget_bytes=memory_budget_bytes,
    rank=pg.get_rank(),
    event_loop=event_loop,
)
state_dict = inflate(mnfst, flattened, prefix=stateful_key)
stateful.load_state_dict(state_dict)

I guess what I would like is: when not all available_entries have a matching object in my current state_dict, the remaining logical_paths should still be loaded into the state_dict that is passed to stateful.load_state_dict(...) at line 736.
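
For illustration, here is a rough, hypothetical sketch of that behavior, reusing the names from the excerpt above (available_entries, flattened, PrimitiveEntry, prepare_read, ObjectBufferConsumer, read_reqs). It assumes prepare_read can allocate a fresh destination when obj_out is None, which may not match the actual API; it is meant to convey the idea, not to be a working patch.

# Hypothetical sketch, not the actual torchsnapshot implementation:
# after the in-place restore loop, also materialize snapshot entries that
# have no counterpart in the current state_dict (e.g. tensors appended to
# a list that is empty at restore time), so they end up in the state_dict
# handed to stateful.load_state_dict(...).
for logical_path, entry in available_entries.items():
    if logical_path in flattened:
        continue  # already handled by the in-place restore loop above
    if isinstance(entry, PrimitiveEntry):
        flattened[logical_path] = entry.get_value()
        continue
    # Assumption: prepare_read can build read requests without an existing
    # destination object; the materialized object is then written back into
    # `flattened` so that inflate() picks it up.
    rrs = prepare_read(entry=entry, obj_out=None)
    for rr in rrs:
        buffer_consumer = rr.buffer_consumer
        if isinstance(buffer_consumer, ObjectBufferConsumer):
            buffer_consumer.set_consume_callback(
                functools.partial(dict.__setitem__, flattened, logical_path)
            )
    read_reqs += rrs

With something like this, the tensors saved from x.obj would reach y.load_state_dict(...) even though y.obj is empty when restore is called.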

vmoens linked a pull request Oct 17, 2022 that will close this issue
ananthsub (Contributor) commented:
@yifuwang was this fixed by #104 ?
