Skip to content

Commit

Permalink
Refactor to use consts
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Nov 7, 2023
1 parent 28f263f commit 4880e63
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/accelerate/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
from torch.cuda.amp import GradScaler

from .utils import (
MODEL_NAME,
OPTIMIZER_NAME,
RNG_STATE_NAME,
SAFE_MODEL_NAME,
SAFE_WEIGHTS_NAME,
SAMPLER_NAME,
SCALER_NAME,
Expand Down Expand Up @@ -192,15 +194,13 @@ def load_accelerator_state(
input_dir = Path(input_dir)
# Model states
for i, model in enumerate(models):
weights_name = SAFE_WEIGHTS_NAME
if i > 0:
weights_name = weights_name.replace(".", f"_{i}.")
input_model_file = input_dir.joinpath(weights_name)
ending = f"_{i}" if i > 0 else ""
input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
if input_model_file.exists():
state_dict = load_file(input_model_file, device=str(map_location))
else:
# Load with torch
input_model_file = input_model_file.with_suffix(".bin")
input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
state_dict = torch.load(input_model_file, map_location=map_location)
models[i].load_state_dict(state_dict, **load_model_func_kwargs)
logger.info("All model weights loaded successfully")
Expand Down

0 comments on commit 4880e63

Please sign in to comment.