From 51294d909e914f80d2f2c85f9e9229dd4386ef85 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 30 Sep 2024 11:58:52 -0400 Subject: [PATCH] Fix optimizer support for Python <= 3.9 (#1379) --- bitsandbytes/optim/optimizer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index ab8f0d4ea..03e0e01d7 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -173,7 +173,7 @@ def load_state_dict(self, state_dict, move_to_device=True): raise ValueError("loaded state dict has a different number of parameter groups") param_lens = (len(g["params"]) for g in groups) saved_lens = (len(g["params"]) for g in saved_groups) - if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens, strict=True)): + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): raise ValueError( "loaded state dict contains a parameter group that doesn't match the size of optimizer's group", ) @@ -184,7 +184,6 @@ def load_state_dict(self, state_dict, move_to_device=True): for old_id, p in zip( chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups), - strict=True, ) } @@ -226,7 +225,7 @@ def update_group(group, new_group): new_group["params"] = group["params"] return new_group - param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups, strict=True)] + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self):