Fix dtype bug when offload_state_dict=True and dtype is specified (#2116)

* fix bug when using offload_state_dict

* fix wrong docstring & type hint

* fix & add test

* style

* fix device_map

* Update tests/test_modeling_utils.py

* fix style
fxmarty authored Dec 5, 2023
1 parent 8f871f4 commit 9569150
Showing 2 changed files with 26 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/accelerate/utils/modeling.py
@@ -310,7 +310,9 @@ def set_module_tensor_to_device(
         if value is None:
             new_value = old_value.to(device)
             if dtype is not None and device in ["meta", torch.device("meta")]:
-                new_value = new_value.to(dtype)
+                if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+                    new_value = new_value.to(dtype)
+
             if not is_buffer:
                 module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
         elif isinstance(value, torch.Tensor):
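
The new guard skips the dtype cast only for integer and boolean tensors, so float weights are still converted to the requested dtype. Below is a minimal sketch (not part of the diff) of the resulting behavior, calling the public set_module_tensor_to_device helper directly; the module and the "position_ids" buffer name are made up for illustration.

# Sketch: integer buffers keep their dtype, float parameters are still cast.
import torch
from accelerate.utils import set_module_tensor_to_device

module = torch.nn.Module()
module.register_buffer("position_ids", torch.arange(128))  # int64 buffer, name is illustrative

# Request float16 while moving to the meta device (the code path changed above).
set_module_tensor_to_device(module, "position_ids", "meta", dtype=torch.float16)
print(module.position_ids.dtype)  # torch.int64 -- the cast is skipped for integer tensors

linear = torch.nn.Linear(4, 4)
set_module_tensor_to_device(linear, "weight", "meta", dtype=torch.float16)
print(linear.weight.dtype)  # torch.float16 -- float parameters are still converted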
23 changes: 23 additions & 0 deletions tests/test_modeling_utils.py
@@ -67,6 +67,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return torch.nn.functional.linear(input, self.weight, self.bias)
 
 
+class ModelSeveralDtypes(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.register_buffer("int_param", torch.randint(high=10, size=(15, 30)))
+        self.register_parameter("float_param", torch.nn.Parameter(torch.rand(10, 5)))
+
+    def forward(self, x):
+        return x + 2
+
+
 def sequential_model(num_layers):
     layers = OrderedDict([(f"linear{i}", nn.Linear(1000, 1000)) for i in range(1, num_layers + 1)])
     return nn.Sequential(layers)
@@ -425,6 +435,19 @@ def test_load_checkpoint_in_model_two_gpu(self):
         self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
         self.assertEqual(model.linear2.weight.device, torch.device(1))
 
+    def test_load_checkpoint_in_model_dtype(self):
+        with tempfile.NamedTemporaryFile(suffix=".pt") as tmpfile:
+            model = ModelSeveralDtypes()
+            torch.save(model.state_dict(), tmpfile.name)
+
+            new_model = ModelSeveralDtypes()
+            load_checkpoint_in_model(
+                new_model, tmpfile.name, offload_state_dict=True, dtype=torch.float16, device_map={"": "cpu"}
+            )
+
+            self.assertEqual(new_model.int_param.dtype, torch.int64)
+            self.assertEqual(new_model.float_param.dtype, torch.float16)
+
     def test_clean_device_map(self):
         # Regroup everything if all is on the same device
         self.assertDictEqual(clean_device_map({"a": 0, "b": 0, "c": 0}), {"": 0})
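
The test asserts that int_param stays torch.int64 because forcing integer tensors into a low-precision float type silently corrupts them. A standalone illustration (not from the repository) of roughly what happened before this fix when a dtype was requested during offload:

# Illustration: float16 cannot represent large integers exactly, and out-of-range values overflow.
import torch

token_ids = torch.arange(70000)         # e.g. an int64 index buffer
as_fp16 = token_ids.to(torch.float16)   # roughly what the old cast did to integer buffers

print(token_ids[-1].item())   # 69999
print(as_fp16[-1].item())     # inf -- above float16's largest finite value (65504)
print(as_fp16[2049].item())   # 2048.0 -- 2049 is not representable, so the index is wrong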
