The CheckpointServer currently uses torch.save/torch.load, which require allocating the entire checkpoint buffer in CPU memory. We want to use streaming transfers instead so that the amount of CPU memory required is minimized.

It would also be nice to add checksums to these transfers to guard against data corruption introduced by the network.
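On the checksum side, zlib.crc32 can be updated incrementally, so integrity checking does not require buffering the payload either. A minimal sketch of that idea (the helper name and signature are illustrative, not existing torchft code):

```python
import zlib
from typing import IO


def stream_with_crc(src: IO[bytes], dst: IO[bytes], chunk_size: int = 1024 * 1024) -> int:
    """Copy src to dst in fixed-size chunks and return the CRC32 of the copied bytes."""
    crc = 0
    while True:
        chunk = src.read(chunk_size)
        if not chunk:
            break
        crc = zlib.crc32(chunk, crc)  # running checksum; no full buffer required
        dst.write(chunk)
    return crc
```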
Relevant existing code: https://github.com/pytorch-labs/torchft/blob/main/torchft/checkpointing.py#L72
The algorithm is described at: https://gist.github.com/d4l3k/b68094d649a076384967788c9b0a5f08
Existing tests: https://github.com/pytorch-labs/torchft/blob/main/torchft/checkpointing_test.py#L15
Overview of work:

- Copy the write_state_dict and read_state_dict implementations into checkpointing.py (a sketch of what these could look like is below).
- Replace the existing torch.save/torch.load calls with them.
- Add unit tests for write_state_dict/read_state_dict covering the different kinds of torch tensors (different dtypes, strided tensors, storage offsets, scalars, nested structures, etc.).
- Optionally, add a checksum to write_state_dict/read_state_dict using zlib.crc32.
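For concreteness, here is a minimal sketch of what a streaming write_state_dict/read_state_dict pair with an optional per-tensor CRC32 could look like. It only assumes the general shape of the approach linked above (pickle the structure and tensor metadata, then stream the raw tensor bytes in chunks); the wire format, constants, and helper names are illustrative rather than the final implementation:

```python
import pickle
import struct
import zlib
from typing import IO

import torch

CHUNK_SIZE = 1024 * 1024  # stream tensor payloads in 1 MiB chunks


def write_state_dict(state_dict: object, f: IO[bytes]) -> None:
    """Serialize state_dict to f, buffering at most one chunk of tensor data at a time."""
    tensors: list[torch.Tensor] = []

    def _pack(obj: object) -> object:
        # Replace tensors with small placeholders so the pickled header stays tiny.
        if isinstance(obj, torch.Tensor):
            tensors.append(obj)
            return ("__tensor__", len(tensors) - 1)
        if isinstance(obj, dict):
            return {k: _pack(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(_pack(v) for v in obj)
        return obj

    skeleton = _pack(state_dict)
    metas = [(t.dtype, tuple(t.shape)) for t in tensors]
    header = pickle.dumps((skeleton, metas))
    f.write(struct.pack("<Q", len(header)))
    f.write(header)

    for tensor in tensors:
        # contiguous() normalizes strides/storage offsets; reshape(-1) handles scalars.
        flat = tensor.detach().cpu().contiguous().reshape(-1)
        raw = flat.view(torch.uint8)  # reinterpret the buffer as raw bytes
        f.write(struct.pack("<Q", raw.numel()))
        crc = 0
        for start in range(0, raw.numel(), CHUNK_SIZE):
            chunk = raw[start : start + CHUNK_SIZE].numpy().tobytes()
            crc = zlib.crc32(chunk, crc)
            f.write(chunk)
        f.write(struct.pack("<I", crc))  # per-tensor checksum trailer


def read_state_dict(f: IO[bytes]) -> object:
    """Inverse of write_state_dict; assumes f.read(n) returns exactly n bytes."""
    (header_len,) = struct.unpack("<Q", f.read(8))
    skeleton, metas = pickle.loads(f.read(header_len))

    tensors = []
    for dtype, shape in metas:
        (nbytes,) = struct.unpack("<Q", f.read(8))
        out = torch.empty(nbytes, dtype=torch.uint8)
        crc = 0
        for start in range(0, nbytes, CHUNK_SIZE):
            chunk = f.read(min(CHUNK_SIZE, nbytes - start))
            crc = zlib.crc32(chunk, crc)
            out[start : start + len(chunk)] = torch.frombuffer(bytearray(chunk), dtype=torch.uint8)
        (expected,) = struct.unpack("<I", f.read(4))
        if crc != expected:
            raise RuntimeError("CRC32 mismatch while reading checkpoint stream")
        tensors.append(out.view(dtype).reshape(shape))

    def _unpack(obj: object) -> object:
        # Swap the placeholders back for the reconstructed tensors.
        if isinstance(obj, tuple) and len(obj) == 2 and obj[0] == "__tensor__":
            return tensors[obj[1]]
        if isinstance(obj, dict):
            return {k: _unpack(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(_unpack(v) for v in obj)
        return obj

    return _unpack(skeleton)
```

A round-trip check along these lines would also cover part of the unit-test item (extended across dtypes, strided/offset tensors, scalars, and nesting):

```python
import io

sd = {"weight": torch.randn(4, 8, dtype=torch.bfloat16), "step": 3,
      "nested": {"bias": torch.zeros(8)}}
buf = io.BytesIO()
write_state_dict(sd, buf)
buf.seek(0)
out = read_state_dict(buf)
torch.testing.assert_close(out["weight"], sd["weight"])
assert out["step"] == 3
```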