
Optim state using data_ptr as index may be a problem #11

Open

ethansmith2000 opened this issue Nov 21, 2024 · 3 comments

Comments

ethansmith2000 (Contributor) commented Nov 21, 2024

I'm seeing that optimizer.state has far fewer keys than there are parameters. It looks like data_ptr() is not returning unique values, and because shapes aren't unique either, multiple parameters' states are being written to the same slot:

(Pdb) optim.state.keys()
dict_keys([(0, (1, 2, 1024)), (0, (1, 64, 1024)), (0, (1024,)), (0, (1024, 48)), (0, (1024, 1024)), (0, (4096, 1024)), (0, (4096,)), (0, (1024, 4096)), (0, (1, 1, 1024)), (0, (1000, 1024)), (0, (1000,))])
(Pdb) optim.state[(0, (1, 2, 1024))]["exp_avg"].data_ptr()
0
(Pdb) optim.state[(0, (1, 2, 1024))]["exp_avg"]._cdata
93862421754000
(Pdb)

I suggest we use param._cdata, which appears to be unique to each tensor.
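
For example, a minimal collision sketch (mine, not from the repo); the storage resize mimics FSDP freeing a shard's storage, which is one way data_ptr() can legitimately be 0:

import torch

# two stand-ins for two distinct parameters of the same shape
a = torch.empty(1024)
b = torch.empty(1024)

# free the backing storage the way FSDP does for sharded params;
# a tensor over an empty storage reports data_ptr() == 0
a.untyped_storage().resize_(0)
b.untyped_storage().resize_(0)

key_a = (a.data_ptr(), tuple(a.shape))   # (0, (1024,))
key_b = (b.data_ptr(), tuple(b.shape))   # (0, (1024,))
assert key_a == key_b                    # two different tensors collapse to one state slot
assert a._cdata != b._cdata              # _cdata still distinguishes them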

Additionally, the distributed checkpointing system expects state values to be indexed in torch's original format: just an integer index for each parameter, 0 through N-1.
I suggest we map our (param._cdata, shape) keys to that original indexing, and then to the state, so we can retain the original behavior:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict.py#L776
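
A hedged sketch of that mapping; build_key_to_index is a hypothetical helper name, and the (param._cdata, shape) key is the one proposed above:

import torch

# hypothetical helper (not in the repo yet): build a (cdata, shape)-key ->
# original-integer-index table once, in the same order stock torch uses
# (param_groups order, flattened across groups)
def build_key_to_index(optimizer: torch.optim.Optimizer) -> dict:
    key_to_index, idx = {}, 0
    for group in optimizer.param_groups:
        for p in group["params"]:
            key_to_index[(p._cdata, tuple(p.shape))] = idx
            idx += 1
    return key_to_index

# state could then be saved/loaded under the integer indices the distributed
# checkpointing code expects, while lookups keep using the (cdata, shape) key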

ethansmith2000 (Contributor, Author) commented:

PyTorch uses the entire parameter as a key, which feels like it might be a bit overkill:
https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py#L119
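
For what it's worth, this is effectively identity-based keying: torch.Tensor hashes by id(), so two equal-looking parameters still map to distinct state entries. A small sketch (mine, not from torch) of why that works:

import torch

state = {}                                    # stands in for optimizer.state
p1 = torch.nn.Parameter(torch.zeros(1024))
p2 = torch.nn.Parameter(torch.zeros(1024))    # same shape and values as p1

state[p1] = {"exp_avg": torch.zeros_like(p1)}
state[p2] = {"exp_avg": torch.zeros_like(p2)}

# Tensor.__hash__ is id()-based, so equal-looking parameters still get
# separate entries; "the entire parameter" as a key is really just identity
assert len(state) == 2 and p1 in state and p2 in state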

ethansmith2000 (Contributor, Author) commented:

Changing your state_ function

# old version: keyed by self.key(arg), i.e. the (data_ptr, shape) tuple
# def state_(self, arg: torch.Tensor):
#     return self.state[self.key(arg)]

# new version: keyed by the parameter tensor itself (identity hash),
# matching stock torch.optim behavior
def state_(self, arg: torch.Tensor):
    return self.state[arg]

actually restores the original behavior: FSDP saving works now, and the state is indexed the typical way.
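
As I understand it, this works because Optimizer.state_dict() remaps the tensor keys to integer indices in param_groups order, which is the format the distributed checkpointing code expects. A self-contained sketch of that remapping (a simplified paraphrase, not the exact upstream code):

import torch

model = torch.nn.Linear(8, 8)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(2, 8)).sum().backward()
optim.step()                                  # state is now keyed by the Parameter objects

# build the id(param) -> integer index mapping in param_groups order
param_mappings, start = {}, 0
for group in optim.param_groups:
    param_mappings.update({id(p): i for i, p in enumerate(group["params"], start)})
    start += len(group["params"])

# repack the tensor-keyed state under integer indices, as state_dict() does
packed = {param_mappings[id(p)]: s for p, s in optim.state.items()}
print(sorted(packed.keys()))                  # [0, 1]  (weight, bias)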

ethansmith2000 (Contributor, Author) commented Nov 22, 2024

Another sharp edge I found with distributed saving/loading of checkpoints: when the state_dicts are initialized, the step function is run, but at lr=0.
If this is our first step, we've already created the fake_param_groups, which all have lr=0, so we need to somehow sync the fake groups' attributes back to the original group:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict.py#L586
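
My rough paraphrase of what the linked helper does (simplified, not the exact upstream code): it zeroes every group's lr, runs one step with zero gradients so the state tensors get materialized, then restores the lrs, which only touches whatever is in optim.param_groups at that moment:

import torch

def init_optim_state(optim: torch.optim.Optimizer) -> None:
    # materialize optimizer state without moving the weights
    if optim.state:
        return                          # state already exists, nothing to do
    saved_lrs = [g["lr"] for g in optim.param_groups]
    for g in optim.param_groups:
        g["lr"] = 0.0                   # a step at lr=0 must not change the params
    for g in optim.param_groups:
        for p in g["params"]:
            if p.requires_grad:
                p.grad = torch.zeros_like(p)
    optim.step()                        # creates step / exp_avg / exp_avg_sq entries
    optim.zero_grad(set_to_none=True)
    for g, lr in zip(optim.param_groups, saved_lrs):
        g["lr"] = lr                    # restore, but only on optim.param_groups

init_optim_state(torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=1e-3))

So if that lr=0 step is the first call into the optimizer and the fake_param_groups are created there, they inherit lr=0 and the restore above never sees them, which is why their attributes need to be synced with the original group afterwards.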
