diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index e792140ac8c..eb3fc5634a9 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -310,7 +310,9 @@ def set_module_tensor_to_device(
         if value is None:
             new_value = old_value.to(device)
             if dtype is not None and device in ["meta", torch.device("meta")]:
-                new_value = new_value.to(dtype)
+                if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+                    new_value = new_value.to(dtype)
+
                 if not is_buffer:
                     module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
         elif isinstance(value, torch.Tensor):
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index d258938fe44..19b4944acf9 100644
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -67,6 +67,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return torch.nn.functional.linear(input, self.weight, self.bias)
 
 
+class ModelSeveralDtypes(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.register_buffer("int_param", torch.randint(high=10, size=(15, 30)))
+        self.register_parameter("float_param", torch.nn.Parameter(torch.rand(10, 5)))
+
+    def forward(self, x):
+        return x + 2
+
+
 def sequential_model(num_layers):
     layers = OrderedDict([(f"linear{i}", nn.Linear(1000, 1000)) for i in range(1, num_layers + 1)])
     return nn.Sequential(layers)
@@ -425,6 +435,19 @@ def test_load_checkpoint_in_model_two_gpu(self):
         self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
         self.assertEqual(model.linear2.weight.device, torch.device(1))
 
+    def test_load_checkpoint_in_model_dtype(self):
+        with tempfile.NamedTemporaryFile(suffix=".pt") as tmpfile:
+            model = ModelSeveralDtypes()
+            torch.save(model.state_dict(), tmpfile.name)
+
+            new_model = ModelSeveralDtypes()
+            load_checkpoint_in_model(
+                new_model, tmpfile.name, offload_state_dict=True, dtype=torch.float16, device_map={"": "cpu"}
+            )
+
+            self.assertEqual(new_model.int_param.dtype, torch.int64)
+            self.assertEqual(new_model.float_param.dtype, torch.float16)
+
     def test_clean_device_map(self):
         # Regroup everything if all is on the same device
         self.assertDictEqual(clean_device_map({"a": 0, "b": 0, "c": 0}), {"": 0})
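
A minimal sketch (not part of the diff) of the behavior the patch and test exercise, using the public `accelerate.load_checkpoint_in_model` API: when an explicit `dtype` is requested, integer and boolean tensors should keep their original dtype while floating-point tensors are cast. The `TinyModel` class and temporary-file handling below are hypothetical illustrations, not code from this PR.

```python
import tempfile

import torch
from accelerate import load_checkpoint_in_model


# Hypothetical toy model mirroring the new test fixture:
# one integer buffer and one floating-point weight.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("counts", torch.randint(high=10, size=(4, 4)))
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x.float())


with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
    torch.save(TinyModel().state_dict(), tmp.name)

    model = TinyModel()
    load_checkpoint_in_model(model, tmp.name, dtype=torch.float16, device_map={"": "cpu"})

    print(model.counts.dtype)     # expected: torch.int64  (integer buffer left untouched)
    print(model.fc.weight.dtype)  # expected: torch.float16 (float weights cast to the requested dtype)
```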