diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 40fd064508f..e2945d02aba 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -2580,9 +2580,8 @@ def save_model( os.makedirs(save_directory, exist_ok=True) - for param in model.parameters(): - if param.device == torch.device("meta"): - raise RuntimeError("You can't save the model since some parameters are on the meta device.") + if any(param.device == torch.device("meta") for param in model.parameters()): + raise RuntimeError("You can't save the model since some parameters are on the meta device.") # get the state_dict of the model state_dict = self.get_state_dict(model)