Optim state using data_ptr as index may be a problem #11
PyTorch uses the entire parameter as the key, which feels maybe a bit overkill.
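For reference, stock torch optimizers key state by the parameter tensor itself (hashed by object identity); a tiny illustration, mine rather than quoted from anywhere:

```python
import torch

p = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([p], lr=1e-3)

# Give the parameter a gradient so step() materializes its state.
p.grad = torch.zeros_like(p)
opt.step()

# State is keyed by the parameter object itself, not by an address or shape.
assert p in opt.state and "exp_avg" in opt.state[p]
```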
Changing your state function actually restores the original behavior: FSDP saving works now, plus the typical state indexing as-is.
Another sharp edge I found with distributed saving/loading of checkpoints: on initialization of the state_dicts, the step function will be run, but at LR=0.
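For context, a rough sketch of that initialization behavior, paraphrased from the state_dict.py file linked further down; the function name and details here are illustrative, not the exact upstream code:

```python
import torch

def init_optim_state(optim: torch.optim.Optimizer) -> None:
    # Sketch: materialize optimizer state without changing parameters by
    # running one step with lr=0 and zero gradients, roughly what the
    # distributed checkpoint state_dict code does before saving.
    if optim.state:
        return  # state already exists, nothing to do
    saved_lrs = [group["lr"] for group in optim.param_groups]
    for group in optim.param_groups:
        group["lr"] = 0.0
        for p in group["params"]:
            if p.requires_grad and p.grad is None:
                p.grad = torch.zeros_like(p)
    optim.step()                       # allocates exp_avg / exp_avg_sq, etc.
    optim.zero_grad(set_to_none=True)
    for group, lr in zip(optim.param_groups, saved_lrs):
        group["lr"] = lr
```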
I'm seeing that optimizer.state has far fewer keys than the number of parameters. It seems data_ptr() is not returning unique values, and because shapes aren't unique either, multiple parameter states are being written to the same slot:
```
(Pdb) optim.state.keys()
dict_keys([(0, (1, 2, 1024)), (0, (1, 64, 1024)), (0, (1024,)), (0, (1024, 48)), (0, (1024, 1024)), (0, (4096, 1024)), (0, (4096,)), (0, (1024, 4096)), (0, (1, 1, 1024)), (0, (1000, 1024)), (0, (1000,))])
(Pdb) optim.state[(0, (1, 2, 1024))]["exp_avg"].data_ptr()
0
(Pdb) optim.state[(0, (1, 2, 1024))]["exp_avg"]._cdata
93862421754000
```
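A minimal illustration of why a (data_ptr, shape) key collides; this uses two views over the same storage to force a shared data_ptr, while in the trace above the shared value happens to be 0, which is what an unallocated/sharded storage typically reports:

```python
import torch

# data_ptr() identifies a storage address, not a tensor, so distinct
# parameters can share it; with equal shapes the composite key collides.
flat = torch.zeros(2048)
a = flat[:1024]
b = flat[:1024]
assert a.data_ptr() == b.data_ptr() and a.shape == b.shape

state = {}
state[(a.data_ptr(), tuple(a.shape))] = {"exp_avg": torch.zeros_like(a)}
state[(b.data_ptr(), tuple(b.shape))] = {"exp_avg": torch.zeros_like(b)}
assert len(state) == 1  # b's entry silently overwrote a's
```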
I suggest we use param._cdata, which appears to be unique to each tensor.
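A sketch of what keying on _cdata could look like; state_key and get_param_state are names I made up for illustration, not functions in this repo:

```python
import torch

def state_key(p: torch.Tensor):
    # _cdata is the address of the tensor's underlying TensorImpl, so it is
    # distinct for every live tensor object, even when data_ptr() is 0 or
    # shared between views.
    return (p._cdata, tuple(p.shape))

def get_param_state(optim_state: dict, p: torch.Tensor) -> dict:
    # Lazily create per-parameter state under the _cdata-based key, the way
    # an optimizer's step() would on first use.
    return optim_state.setdefault(
        state_key(p),
        {"exp_avg": torch.zeros_like(p), "exp_avg_sq": torch.zeros_like(p)},
    )
```

One caveat: _cdata is only stable for the lifetime of the tensor object, so the key is meaningful within a run but not across save/load, which is part of why the integer remapping below is still needed.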
Additionally, the distributed checkpointing system expects state values to be indexed in torch's original format: just an integer index for each parameter in [0, N). I suggest we map our (param._cdata, shape) keys to the original integer indexing, and from there to the state, so we retain the original behavior (a sketch follows the link below).
https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict.py#L776
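And a sketch of the remapping; the helper names are mine, and the integer index here is simply the parameter's position counted across all param_groups, which matches the packed format torch's own Optimizer.state_dict() emits:

```python
def build_key_to_index(optim) -> dict:
    # (param._cdata, shape) -> integer index, counting parameters in order
    # across all param_groups, as torch's packed state_dict format does.
    mapping, index = {}, 0
    for group in optim.param_groups:
        for p in group["params"]:
            mapping[(p._cdata, tuple(p.shape))] = index
            index += 1
    return mapping

def repack_state(optim) -> dict:
    # Re-key our internal state (assumed keyed by (param._cdata, shape)) into
    # the {int: {...}} layout the distributed checkpointing system expects.
    key_to_index = build_key_to_index(optim)
    return {key_to_index[key]: value for key, value in optim.state.items()}
```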