Skip to content

Commit

Permalink
Use torch.device instead of current device index for BnB quantizer (#…
Browse files Browse the repository at this point in the history
…10069)

* update

* apply review suggestion

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
a-r-r-o-w and sayakpaul authored Dec 5, 2024
1 parent 0d11ab2 commit 98d0cd5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def load_model_dict_into_meta(
hf_quantizer=None,
keep_in_fp32_modules=None,
) -> List[str]:
if device is not None and not isinstance(device, (str, torch.device)):
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
if hf_quantizer is None:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
param_device = "cpu"
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
elif is_quant_method_bnb:
param_device = torch.cuda.current_device()
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)

Expand Down

0 comments on commit 98d0cd5

Please sign in to comment.