From eceed128d1988a4b06de8591f620b80220399af1 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 18 Mar 2024 20:24:09 -0400 Subject: [PATCH] One more time --- bitsandbytes/optim/adamw.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index f4660588e..c81e8ca61 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -220,7 +220,7 @@ def step(self, closure=None): lor_update = torch.zeros_like( grad, dtype=p.data.dtype, device=p.data.device, requires_grad=grad.requires_grad ) - lor_update.grad = grad + p.grad = grad if "state1" not in state: self.init_state(group, p, gindex, pindex) @@ -228,7 +228,7 @@ def step(self, closure=None): self.prefetch_state(p) if "rank" in group: - self.update_step(group, lor_update, gindex, pindex, return_updates=lor_update) + self.update_step(group, p, gindex, pindex, return_updates=lor_update) # GaLore Projection Back p.data.add_(state["projector"].project_back(lor_update))