Low Bit Optim Instability. #1218

Open
nighting0le01 opened this issue Nov 4, 2024 · 4 comments

Comments

@nighting0le01

Hi, will the torchao low-bit optim allow per-layer selection, i.e. switching specific layers back to 32-bit Adam for stability? Also, what about StableEmbedding layers?
1. Optimize unstable parameters: https://huggingface.co/docs/bitsandbytes/main/en/optimizers#optimize-unstable-parameters
2. Stable Embeddings: https://huggingface.co/docs/bitsandbytes/main/en/reference/nn/embeddings#bitsandbytes.nn.StableEmbedding

I see some divergence with the torchao optimizer that I don't see with bitsandbytes.
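
For reference, a minimal sketch of the bitsandbytes per-parameter override from link 1 above (the model and layer choices here are placeholders; GlobalOptimManager and override_config are the bnb APIs those docs describe):

import torch.nn as nn
import bitsandbytes as bnb

model = nn.Sequential(nn.Embedding(50000, 512), nn.Linear(512, 512))

# register params with the manager while they are still on CPU
mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_parameters(model.parameters())

model = model.cuda()
optim = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# keep 32-bit optimizer states for the embedding weight only
mng.override_config(model[0].weight, "optim_bits", 32)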

@gau-nernst
Collaborator

Regarding "per layer selection", I planned for this feature before but forgot about it. Should be easy to add. Do you have an idea how you want this API to look like? I'm thinking like this

optim = AdamW8bit(model.parameters(), exclude_low_bit_optim_params=[model.output.weight])

Regarding StableEmbedding, it looks like an ordinary nn.Module with some custom defaults. You should be able to use it directly from bnb. I don't think we need to re-implement it in ao.
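
For example (a sketch; bnb.nn.StableEmbedding takes the same constructor arguments as nn.Embedding, and the dimensions here are placeholders):

import bitsandbytes as bnb

# drop-in replacement for nn.Embedding: xavier-uniform init plus a LayerNorm on the output
emb = bnb.nn.StableEmbedding(num_embeddings=50000, embedding_dim=512)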

Would you be interested in contributing a PR for the 1st feature?

@nighting0le01
Author

Hi @gau-nernst, yes, I believe this is a decent API. Sure, I can take it up if you can give me some pointers.

@gau-nernst
Collaborator

We currently have some checks so that the low-bit optim is only applied to certain params:

# follow bitsandbytes, only quantize tensors >= 4096 values
# also wrap subclass in DTensor when needed
def _new_buffer(self, p: Tensor, signed: bool):
    if p.numel() >= 4096 and p.numel() % self.block_size == 0:
        if isinstance(p, DTensor):
            out = DTensor.from_local(
                local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
                device_mesh=p.device_mesh,
                placements=p.placements,
                run_check=False,
            )
        else:
            out = self._subclass_zeros(p, signed, self.block_size)
    else:
        out = torch.zeros_like(p)
    return out

You can simply add an extra check that the param is not in the exclude_low_bit_optim_params list (or set), and add this extra argument to each Adam variant.
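
A rough sketch of what that check could look like (exclude_low_bit_optim_params is the proposed new argument, assumed here to be stored on the optimizer; the identity comparison avoids Tensor's elementwise ==):

def _new_buffer(self, p: Tensor, signed: bool):
    # skip quantized state for params the user explicitly excluded
    excluded = any(p is q for q in getattr(self, "exclude_low_bit_optim_params", ()))
    if not excluded and p.numel() >= 4096 and p.numel() % self.block_size == 0:
        # DTensor wrapping omitted here for brevity; see the snippet above
        out = self._subclass_zeros(p, signed, self.block_size)
    else:
        out = torch.zeros_like(p)
    return out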

@nighting0le01
Author

Sounds good! I'll create a PR!
