Skip to content

Commit

Permalink
Fix optimizer support for Python <= 3.9 (#1379)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdouglas authored Sep 30, 2024
1 parent 776140a commit 51294d9
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions bitsandbytes/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def load_state_dict(self, state_dict, move_to_device=True):
raise ValueError("loaded state dict has a different number of parameter groups")
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens, strict=True)):
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError(
"loaded state dict contains a parameter group that doesn't match the size of optimizer's group",
)
Expand All @@ -184,7 +184,6 @@ def load_state_dict(self, state_dict, move_to_device=True):
for old_id, p in zip(
chain.from_iterable(g["params"] for g in saved_groups),
chain.from_iterable(g["params"] for g in groups),
strict=True,
)
}

Expand Down Expand Up @@ -226,7 +225,7 @@ def update_group(group, new_group):
new_group["params"] = group["params"]
return new_group

param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups, strict=True)]
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({"state": state, "param_groups": param_groups})

def to_gpu(self):
Expand Down

0 comments on commit 51294d9

Please sign in to comment.