Assemble state dictionary for offloaded models (#2156)
* changed meta alignment device to cpu

* reverted alignment device and init weight map

* trace on values

* trace on values

* trace on values

* added offload model state dict save and test

* removed hook traces

* removed n

* Update src/accelerate/accelerator.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/accelerate/utils/modeling.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/accelerate/utils/modeling.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/accelerate/utils/modeling.py

Co-authored-by: Marc Sun <[email protected]>

* suggestions and make style

* fixed circular import and make style

* debugged test

* Update src/accelerate/utils/modeling.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/accelerate/utils/modeling.py

Co-authored-by: Marc Sun <[email protected]>

* function level import and make style

* Update src/accelerate/utils/modeling.py

Co-authored-by: Zach Mueller <[email protected]>

* Update tests/test_accelerator.py

Co-authored-by: Marc Sun <[email protected]>

* Update tests/test_accelerator.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/accelerate/utils/modeling.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/accelerate/utils/modeling.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/accelerate/utils/modeling.py

Co-authored-by: Marc Sun <[email protected]>

* make style

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Zach Mueller <[email protected]>
3 people authored Nov 30, 2023
1 parent 68d63ee commit 3499cf2
Showing 3 changed files with 73 additions and 8 deletions.
18 changes: 14 additions & 4 deletions src/accelerate/accelerator.py
@@ -34,6 +34,7 @@
 
 from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
 from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
+from .hooks import AlignDevicesHook
 from .logging import get_logger
 from .optimizer import AcceleratedOptimizer
 from .scheduler import AcceleratedScheduler
@@ -96,6 +97,7 @@
     wait_for_everyone,
 )
 from .utils.constants import FSDP_PYTORCH_VERSION
+from .utils.modeling import get_state_dict_offloaded_model
 from .utils.other import is_compiled_module
 
 
@@ -2490,13 +2492,21 @@ def save_model(
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
-        if any(param.device == torch.device("meta") for param in model.parameters()):
-            raise RuntimeError("You can't save the model since some parameters are on the meta device.")
-
         os.makedirs(save_directory, exist_ok=True)
 
         # get the state_dict of the model
-        state_dict = self.get_state_dict(model)
+        if any(
+            [
+                module._hf_hook.offload
+                for module in model.modules()
+                if hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook)
+            ]
+        ):
+            state_dict = get_state_dict_offloaded_model(model)
+        else:
+            if any(param.device == torch.device("meta") for param in model.parameters()):
+                raise RuntimeError("You can't save the model since some parameters are on the meta device.")
+            state_dict = self.get_state_dict(model)
 
         if safe_serialization:
             state_dict = clean_state_dict_for_safetensors(state_dict)
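Note: a minimal sketch (not part of the commit) of the save path this change enables. TinyModel, the device_map, and the temporary directories are illustrative assumptions; dispatch_model, Accelerator, and save_model are existing accelerate APIs.

import tempfile

from torch import nn

from accelerate import Accelerator, dispatch_model


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        self.linear2 = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear2(self.linear1(x))


# Offload linear2 to disk: its parameters are replaced by meta tensors
# and an AlignDevicesHook with offload=True is attached to the module.
model = dispatch_model(
    TinyModel(),
    device_map={"linear1": "cpu", "linear2": "disk"},
    offload_dir=tempfile.mkdtemp(),
)

accelerator = Accelerator()
# Before this commit this call raised a RuntimeError because of the meta
# parameters; it now assembles the full state dict from the offloaded weights.
accelerator.save_model(model, tempfile.mkdtemp())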
48 changes: 48 additions & 0 deletions src/accelerate/utils/modeling.py
@@ -1260,6 +1260,54 @@ def load_state_dict(checkpoint_file, device_map=None):
         return torch.load(checkpoint_file, map_location=torch.device("cpu"))
 
 
+def get_state_dict_offloaded_model(model: nn.Module):
+    """
+    Returns the state dictionary for an offloaded model via iterative onloading
+
+    Args:
+        model (`torch.nn.Module`):
+            The offloaded model we want to save
+    """
+    from ..hooks import AlignDevicesHook
+
+    state_dict = {}
+    placeholders = set()
+    for name, module in model.named_modules():
+        if name == "":
+            continue
+        if hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload:
+            original_device = module._hf_hook.execution_device
+            # assign hook execution device to cpu
+            module._hf_hook.execution_device = "cpu"
+            # onload meta tensors to execution device
+            try:
+                module._hf_hook.pre_forward(module)
+            except MemoryError:
+                raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None
+            module_state_dict = module.state_dict()
+            # offload meta tensors from cpu
+            module._hf_hook.post_forward(module, torch.tensor([]))
+            # re-assign hook to original execution device
+            module._hf_hook.execution_device = original_device
+        else:
+            module_state_dict = module.state_dict()
+
+        for key in module_state_dict:
+            # ignore placeholder parameters that are still on the meta device
+            if module_state_dict[key].device == torch.device("meta"):
+                placeholders.add(name + f".{key}")
+                continue
+            params = module_state_dict[key]
+            state_dict[name + f".{key}"] = params
+    for key in placeholders.copy():
+        if key in state_dict:
+            placeholders.remove(key)
+    if placeholders:
+        logger.warning(f"The following tensors were not saved because they were still on meta device: {placeholders}")
+
+    return state_dict
+
+
 def load_checkpoint_in_model(
     model: nn.Module,
     checkpoint: Union[str, os.PathLike],
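The helper works by iterative onloading: for each offloaded submodule it temporarily points the AlignDevicesHook at the CPU, calls pre_forward to materialize the weights, copies them into the state dict, then calls post_forward to return them to meta. It can also be called directly; a sketch, with the toy model and device_map as illustrative assumptions:

import tempfile

import torch
from torch import nn

from accelerate import dispatch_model
from accelerate.utils.modeling import get_state_dict_offloaded_model


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        self.linear2 = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = dispatch_model(
    TinyModel(),
    device_map={"linear1": "cpu", "linear2": "disk"},
    offload_dir=tempfile.mkdtemp(),
)

state_dict = get_state_dict_offloaded_model(model)
# Every tensor was onloaded module by module; nothing remains on meta.
assert all(t.device == torch.device("cpu") for t in state_dict.values())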
15 changes: 11 additions & 4 deletions tests/test_accelerator.py
@@ -8,7 +8,7 @@
 from parameterized import parameterized
 from torch.utils.data import DataLoader, TensorDataset
 
-from accelerate import DistributedType, infer_auto_device_map, init_empty_weights
+from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
 from accelerate.accelerator import Accelerator
 from accelerate.state import GradientState, PartialState
 from accelerate.test_utils import require_bnb, require_multi_gpu, slow
@@ -153,12 +153,19 @@ def test_save_model_offload(self, use_safetensors):
 
         device_map = {"linear1": "cpu", "batchnorm": "disk", "linear2": "cpu"}
 
+        inputs = torch.randn(3, 3)
         model = ModelForTest()
+        expected = model(inputs)
         with tempfile.TemporaryDirectory() as tmp_dir:
             accelerator.save_model(model, tmp_dir, safe_serialization=use_safetensors)
-            load_checkpoint_in_model(model, tmp_dir, device_map=device_map, offload_folder=tmp_dir)
-            with self.assertRaises(RuntimeError):
-                accelerator.save_model(model, tmp_dir, safe_serialization=use_safetensors)
+            # load and save offloaded model
+            load_checkpoint_and_dispatch(model, tmp_dir, device_map=device_map, offload_folder=tmp_dir)
+            accelerator.save_model(model, tmp_dir, safe_serialization=use_safetensors)
+
+            # load weights that were saved from the offloaded model
+            load_checkpoint_and_dispatch(model, tmp_dir)
+            output = model(inputs)
+            self.assertTrue(torch.allclose(expected, output, atol=1e-5))
 
     @parameterized.expand([True, False], name_func=parameterized_custom_name_func)
     def test_save_load_model_with_hooks(self, use_safetensors):
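The test now verifies a full round trip instead of expecting a failure. Outside the test harness that flow looks roughly like the sketch below; the toy model and the 3x3 input mirror the test's ModelForTest setup and are assumptions here:

import tempfile

import torch
from torch import nn

from accelerate import Accelerator, load_checkpoint_and_dispatch


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3, 3)
        self.linear2 = nn.Linear(3, 3)

    def forward(self, x):
        return self.linear2(self.linear1(x))


accelerator = Accelerator()
model = TinyModel()
inputs = torch.randn(3, 3)
expected = model(inputs)

with tempfile.TemporaryDirectory() as tmp_dir:
    accelerator.save_model(model, tmp_dir)
    # Reload with one submodule offloaded to disk, then save the offloaded model.
    load_checkpoint_and_dispatch(
        model, tmp_dir, device_map={"linear1": "cpu", "linear2": "disk"}, offload_folder=tmp_dir
    )
    accelerator.save_model(model, tmp_dir)

    # Weights saved from the offloaded model reproduce the original outputs.
    load_checkpoint_and_dispatch(model, tmp_dir)
    assert torch.allclose(expected, model(inputs), atol=1e-5)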
