Skip to content

Commit

Permalink
Raise error when saving with param on meta device (#2132)
Browse files
* add error

* style

* Update src/accelerate/accelerator.py

Co-authored-by: Zach Mueller <[email protected]>

* style

* move before creating the directory

---------

Co-authored-by: Zach Mueller <[email protected]>
  • Loading branch information
SunMarc and muellerzr authored Nov 8, 2023
1 parent e638b1e commit 0b0d921
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2574,6 +2574,9 @@ def save_model(
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

if any(param.device == torch.device("meta") for param in model.parameters()):
raise RuntimeError("You can't save the model since some parameters are on the meta device.")

os.makedirs(save_directory, exist_ok=True)

# get the state_dict of the model
Expand Down
24 changes: 24 additions & 0 deletions tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ def create_components():
return model, optimizer, scheduler, train_dl, valid_dl


class ModelForTest(torch.nn.Module):
    """Small three-layer network used as a fixture by the save/offload tests.

    The submodule names (``linear1``, ``batchnorm``, ``linear2``) are used as
    device-map keys by the tests, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Feature sizes: 3 -> 4 -> (batch norm over 4) -> 5.
        self.linear1 = torch.nn.Linear(3, 4)
        self.batchnorm = torch.nn.BatchNorm1d(4)
        self.linear2 = torch.nn.Linear(4, 5)

    def forward(self, x):
        """Apply ``linear1``, then ``batchnorm``, then ``linear2`` to ``x``."""
        hidden = self.linear1(x)
        normalized = self.batchnorm(hidden)
        return self.linear2(normalized)


def get_signature(model):
    """Return the sum of absolute values of ``model.weight`` and ``model.bias`` as a Python number."""
    weight_total = model.weight.abs().sum()
    bias_total = model.bias.abs().sum()
    return (weight_total + bias_total).item()

Expand Down Expand Up @@ -136,6 +147,19 @@ def test_save_model(self, use_safetensors):
load_checkpoint_in_model(model, tmpdirname)
self.assertTrue(abs(model_signature - get_signature(model)) < 1e-3)

@parameterized.expand([True, False], name_func=parameterized_custom_name_func)
def test_save_model_offload(self, use_safetensors):
    """Check that ``save_model`` raises ``RuntimeError`` after reloading with a disk-offloaded module.

    Runs once with ``use_safetensors=True`` and once with ``False``.
    """
    accelerator = Accelerator()
    model = ModelForTest()

    # ``batchnorm`` is mapped to "disk"; the two linear layers stay on CPU.
    device_map = {"linear1": "cpu", "batchnorm": "disk", "linear2": "cpu"}
    save_kwargs = {"safe_serialization": use_safetensors}

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save a regular checkpoint first so there is something to load back.
        accelerator.save_model(model, tmp_dir, **save_kwargs)
        # Reload it with the offloading device map, reusing the same folder for offload.
        load_checkpoint_in_model(model, tmp_dir, device_map=device_map, offload_folder=tmp_dir)
        # NOTE(review): presumably the offloaded parameters are now on the meta
        # device, which is what makes the second save fail — confirm against save_model.
        with self.assertRaises(RuntimeError):
            accelerator.save_model(model, tmp_dir, **save_kwargs)

@parameterized.expand([True, False], name_func=parameterized_custom_name_func)
def test_save_load_model_with_hooks(self, use_safetensors):
accelerator = Accelerator()
Expand Down

0 comments on commit 0b0d921

Please sign in to comment.