diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 41314a4f0..f4660588e 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -228,7 +228,7 @@ def step(self, closure=None): self.prefetch_state(p) if "rank" in group: - self.update_step(group, p, gindex, pindex, return_updates=lor_update) + self.update_step(group, lor_update, gindex, pindex, return_updates=lor_update) # GaLore Projection Back p.data.add_(state["projector"].project_back(lor_update))