diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 932a94571107..751117f8f247 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -176,6 +176,8 @@ def load_model_dict_into_meta( hf_quantizer=None, keep_in_fp32_modules=None, ) -> List[str]: + if device is not None and not isinstance(device, (str, torch.device)): + raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.") if hf_quantizer is None: device = device or torch.device("cpu") dtype = dtype or torch.float32 diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 76f6c5f6309d..7b2022798d41 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -836,7 +836,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P param_device = "cpu" # TODO (sayakpaul, SunMarc): remove this after model loading refactor elif is_quant_method_bnb: - param_device = torch.cuda.current_device() + param_device = torch.device(torch.cuda.current_device()) state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict)