From 4880e63eb4adff032805914287e1f57c63177088 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Tue, 7 Nov 2023 18:39:12 -0500
Subject: [PATCH] Refactor to use consts

---
 src/accelerate/checkpointing.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py
index 5e2577b2778..11d30d9fef1 100644
--- a/src/accelerate/checkpointing.py
+++ b/src/accelerate/checkpointing.py
@@ -22,8 +22,10 @@
 from torch.cuda.amp import GradScaler
 
 from .utils import (
+    MODEL_NAME,
     OPTIMIZER_NAME,
     RNG_STATE_NAME,
+    SAFE_MODEL_NAME,
     SAFE_WEIGHTS_NAME,
     SAMPLER_NAME,
     SCALER_NAME,
@@ -192,15 +194,13 @@ def load_accelerator_state(
     input_dir = Path(input_dir)
     # Model states
     for i, model in enumerate(models):
-        weights_name = SAFE_WEIGHTS_NAME
-        if i > 0:
-            weights_name = weights_name.replace(".", f"_{i}.")
-        input_model_file = input_dir.joinpath(weights_name)
+        ending = f"_{i}" if i > 0 else ""
+        input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
         if input_model_file.exists():
             state_dict = load_file(input_model_file, device=str(map_location))
         else:
             # Load with torch
-            input_model_file = input_model_file.with_suffix(".bin")
+            input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
             state_dict = torch.load(input_model_file, map_location=map_location)
         models[i].load_state_dict(state_dict, **load_model_func_kwargs)
     logger.info("All model weights loaded successfully")
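
Note (not part of the applied patch): a minimal sketch of how the new filename construction resolves per model index. The constant values below are assumptions for illustration (they are defined in accelerate's utils, not in this diff), and "checkpoint_dir" is a hypothetical directory.

    from pathlib import Path

    # Assumed values; the real constants live in accelerate.utils.
    SAFE_MODEL_NAME = "model"
    MODEL_NAME = "pytorch_model"

    input_dir = Path("checkpoint_dir")
    for i in range(2):
        ending = f"_{i}" if i > 0 else ""
        # Preferred safetensors file: model.safetensors, then model_1.safetensors, ...
        safe_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
        # Torch fallback: pytorch_model.bin, then pytorch_model_1.bin, ...
        bin_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
        print(safe_file, bin_file)

The key behavioral point is that the safetensors and torch fallback files no longer share one base name: the safetensors path is built from SAFE_MODEL_NAME and the .bin fallback from MODEL_NAME, with the same "_{i}" suffix for models after the first.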