Trigger weights_only=True by default for all compatible objects (#3036)

* rebase

* Update torch v

* Rename

* Prop to docs

* Actually reverse states

* Rebase fully

* Restore old state

* Keep as load()

* No need for explicit anymore

* Check numpy version, dtypes was added in 1.25

* Clean up diff

* Fix hang
muellerzr authored Oct 10, 2024
1 parent 1d2ca74 commit 6f79b63
Showing 5 changed files with 84 additions and 10 deletions.
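
For context, the PR routes Accelerate's own checkpoint loading through PyTorch's `weights_only` mode whenever the installed `torch` supports it. A minimal sketch of the underlying mechanism being relied on (plain `torch` >= 2.4; the file name and contents are illustrative only):

import numpy as np
import torch

# Create a small checkpoint containing a non-tensor object (a numpy array).
torch.save({"step": 3, "rng": np.arange(4)}, "example_states.pkl")

# Allowlist the globals needed to rebuild numpy arrays, then load with weights_only=True
# so arbitrary pickled callables cannot be executed during unpickling.
torch.serialization.add_safe_globals([np.core.multiarray._reconstruct, np.ndarray, np.dtype])
try:
    states = torch.load("example_states.pkl", weights_only=True)
finally:
    torch.serialization.clear_safe_globals()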
2 changes: 2 additions & 0 deletions docs/source/package_reference/utilities.md
@@ -215,6 +215,8 @@ These include general utilities that should be used when working in parallel.

[[autodoc]] utils.save

[[autodoc]] utils.load

[[autodoc]] utils.wait_for_everyone


22 changes: 13 additions & 9 deletions src/accelerate/checkpointing.py
@@ -35,6 +35,7 @@
is_mlu_available,
is_torch_xla_available,
is_xpu_available,
load,
save,
)

@@ -217,23 +218,24 @@ def load_accelerator_state(
else:
# Load with torch
input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
- state_dict = torch.load(input_model_file, map_location=map_location)
+ state_dict = load(input_model_file, map_location=map_location)
model.load_state_dict(state_dict, **load_model_func_kwargs)
logger.info("All model weights loaded successfully")

# Optimizer states
for i, opt in enumerate(optimizers):
optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
input_optimizer_file = input_dir.joinpath(optimizer_name)
- optimizer_state = torch.load(input_optimizer_file, map_location=map_location)
+ optimizer_state = load(input_optimizer_file, map_location=map_location)
optimizers[i].load_state_dict(optimizer_state)
logger.info("All optimizer states loaded successfully")

# Scheduler states
for i, scheduler in enumerate(schedulers):
scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
input_scheduler_file = input_dir.joinpath(scheduler_name)
- scheduler.load_state_dict(torch.load(input_scheduler_file))
+ scheduler_state = load(input_scheduler_file)
+ scheduler.load_state_dict(scheduler_state)
logger.info("All scheduler states loaded successfully")

for i, dataloader in enumerate(dataloaders):
@@ -245,24 +247,25 @@ def load_accelerator_state(
if isinstance(dataloader.dataset, IterableDatasetShard):
sampler = dataloader.get_sampler()
if isinstance(sampler, SeedableRandomSampler):
- sampler = dataloader.set_sampler(torch.load(input_sampler_file))
+ sampler = dataloader.set_sampler(load(input_sampler_file))
if getattr(dataloader, "use_stateful_dataloader", False):
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
if input_dataloader_state_dict_file.exists():
- state_dict = torch.load(input_dataloader_state_dict_file)
+ state_dict = load(input_dataloader_state_dict_file)
dataloader.load_state_dict(state_dict)
logger.info("All dataloader sampler states loaded successfully")

# GradScaler state
if scaler is not None:
input_scaler_file = input_dir.joinpath(SCALER_NAME)
- scaler.load_state_dict(torch.load(input_scaler_file))
+ scaler_state = load(input_scaler_file)
+ scaler.load_state_dict(scaler_state)
logger.info("GradScaler state loaded successfully")

# Random states
try:
- states = torch.load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
+ states = load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
if "step" in states:
override_attributes["step"] = states["step"]
random.setstate(states["random_state"])
@@ -295,8 +298,9 @@ def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False

def load_custom_state(obj, path, index: int = 0):
"""
- Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`
+ Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`. Will always set `weights_only=False` when
+ loading the state.
"""
load_location = f"{path}/custom_checkpoint_{index}.pkl"
logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}")
- obj.load_state_dict(torch.load(load_location, map_location="cpu"))
+ obj.load_state_dict(load(load_location, map_location="cpu", weights_only=False))
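
Custom registered objects can hold arbitrary Python types, which is why `load_custom_state` above keeps `weights_only=False`. A hedged sketch of the flow these helpers serve (the `MovingAverage` class is hypothetical; `register_for_checkpointing`, `save_state` and `load_state` are the standard Accelerator API):

from accelerate import Accelerator


class MovingAverage:
    # Toy stateful object implementing the state_dict/load_state_dict protocol.
    def __init__(self):
        self.value = 0.0

    def state_dict(self):
        return {"value": self.value}

    def load_state_dict(self, state):
        self.value = state["value"]


accelerator = Accelerator()
tracker = MovingAverage()
accelerator.register_for_checkpointing(tracker)

accelerator.save_state("checkpoint_dir")  # writes custom_checkpoint_0.pkl via save_custom_state
tracker.value = 42.0
accelerator.load_state("checkpoint_dir")  # restores it via load_custom_state (weights_only=False)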
2 changes: 2 additions & 0 deletions src/accelerate/utils/__init__.py
@@ -118,6 +118,7 @@
is_transformers_available,
is_triton_available,
is_wandb_available,
is_weights_only_available,
is_xpu_available,
)
from .modeling import (
@@ -248,6 +249,7 @@
extract_model_from_parallel,
get_pretty_name,
is_port_in_use,
load,
merge_dicts,
patch_environment,
recursive_getattr,
11 changes: 11 additions & 0 deletions src/accelerate/utils/imports.py
@@ -445,3 +445,14 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


def is_weights_only_available():
    # Weights only with allowlist was added in 2.4.0
    # ref: https://github.com/pytorch/pytorch/pull/124331
    return is_torch_version(">=", "2.4.0")


def is_numpy_available(min_version="1.25.0"):
    numpy_version = parse(importlib.metadata.version("numpy"))
    return compare_versions(numpy_version, ">=", min_version)
57 changes: 56 additions & 1 deletion src/accelerate/utils/other.py
@@ -17,11 +17,13 @@
import platform
import re
import socket
from codecs import encode
from contextlib import contextmanager
from functools import partial, reduce
from types import MethodType
from typing import OrderedDict

import numpy as np
import torch
from packaging.version import Version
from safetensors.torch import save_file as safe_save_file
@@ -31,7 +33,13 @@
from ..state import PartialState
from .constants import FSDP_PYTORCH_VERSION
from .dataclasses import DistributedType
- from .imports import is_deepspeed_available, is_torch_distributed_available, is_torch_xla_available
+ from .imports import (
+     is_deepspeed_available,
+     is_numpy_available,
+     is_torch_distributed_available,
+     is_torch_xla_available,
+     is_weights_only_available,
+ )
from .modeling import id_tensor_storage
from .transformer_engine import convert_model
from .versions import is_torch_version
@@ -207,6 +215,53 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
save_func(obj, f)


# The following are considered "safe" globals needed to reconstruct various types of objects when using
# `weights_only=True`. They are registered before loading and removed again once the file has been loaded.
TORCH_SAFE_GLOBALS = [
    # numpy arrays are just numbers, not objects, so we can reconstruct them safely
    np.core.multiarray._reconstruct,
    np.ndarray,
    # The following are needed for the RNG states
    encode,
    np.dtype,
]

if is_numpy_available("1.25.0"):
    TORCH_SAFE_GLOBALS.append(np.dtypes.UInt32DType)


def load(f, map_location=None, **kwargs):
    """
    Compatible drop-in replacement for `torch.load()` that allows `weights_only` to be used if the installed `torch`
    version is 2.4.0 or higher; otherwise the kwarg is ignored.

    Will also add (and then remove) an exception for numpy arrays.

    Args:
        f:
            The file (or file-like object) to use to load the data.
        map_location:
            A function, `torch.device`, string or a dict specifying how to remap storage locations.
        **kwargs:
            Additional keyword arguments to pass to `torch.load()`.
    """
    try:
        if is_weights_only_available():
            old_safe_globals = torch.serialization.get_safe_globals()
            if "weights_only" not in kwargs:
                kwargs["weights_only"] = True
            torch.serialization.add_safe_globals(TORCH_SAFE_GLOBALS)
        else:
            kwargs.pop("weights_only", None)
        loaded_obj = torch.load(f, map_location=map_location, **kwargs)
    finally:
        if is_weights_only_available():
            torch.serialization.clear_safe_globals()
            if old_safe_globals:
                torch.serialization.add_safe_globals(old_safe_globals)
    return loaded_obj


@contextmanager
def clear_environment():
"""
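
Putting the pieces together, a short sketch of how the new helper is intended to be used (the file name is illustrative; `weights_only` only reaches `torch.load` on `torch` >= 2.4):

import torch
from accelerate.utils import load, save

# Save an optimizer state dict with the existing helper, then reload it through the new one.
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(2))], lr=0.1)
save(optimizer.state_dict(), "optimizer.bin")

# load() defaults to weights_only=True on torch >= 2.4 (with the safe globals registered)
# and silently drops the kwarg on older versions.
state = load("optimizer.bin", map_location="cpu")
optimizer.load_state_dict(state)

# Checkpoints that pickle arbitrary Python objects still need an explicit opt-out,
# which is what load_custom_state does: load(path, weights_only=False).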
