add bnb support for Ascend NPU
statelesshz committed Dec 7, 2024
1 parent 3ee24e2 commit 0d852b8
Showing 1 changed file with 6 additions and 0 deletions.
src/transformers/quantizers/quantizer_bnb_4bit.py (6 additions, 0 deletions)
@@ -29,6 +29,7 @@
     is_accelerate_available,
     is_bitsandbytes_available,
     is_torch_available,
+    is_torch_npu_available,
     is_torch_xpu_available,
     logging,
 )
@@ -171,6 +172,9 @@ def create_quantized_param(

         old_value = getattr(module, tensor_name)

+        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if isinstance(target_device, int) and is_torch_npu_available():
+            target_device = f"npu:{target_device}"
         if tensor_name == "bias":
             if param_value is None:
                 new_value = old_value.to(target_device)
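
The guard added above works around a torch_npu limitation: `torch.Tensor.to(<int>)` treats a bare integer as a CUDA device index, which torch_npu does not support (see the linked issue), so on Ascend hardware the integer is rewritten as an explicit "npu:<idx>" string before `.to()` is called. A minimal standalone sketch of the same normalization, assuming `torch_npu` is installed and that `is_torch_npu_available` is importable from `transformers.utils`; `normalize_target_device` is a hypothetical helper name:

    import torch

    from transformers.utils import is_torch_npu_available

    def normalize_target_device(target_device):
        # A bare integer passed to torch.Tensor.to() is interpreted as a CUDA
        # device index, which torch_npu does not support, so rewrite it as an
        # explicit "npu:<idx>" device string on Ascend hardware.
        if isinstance(target_device, int) and is_torch_npu_available():
            return f"npu:{target_device}"
        return target_device

    # Hypothetical usage mirroring the bias branch of create_quantized_param:
    bias = torch.zeros(16)
    bias = bias.to(normalize_target_device(0))  # "npu:0" on Ascend, 0 elsewhere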
@@ -264,6 +268,8 @@ def update_device_map(self, device_map):
         if device_map is None:
             if torch.cuda.is_available():
                 device_map = {"": torch.cuda.current_device()}
+            elif is_torch_npu_available():
+                device_map = {"": f"npu:{torch.npu.current_device()}"}
             elif is_torch_xpu_available():
                 device_map = {"": f"xpu:{torch.xpu.current_device()}"}
             else:
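With this branch in place, the default device-map resolution prefers CUDA, then Ascend NPU, then Intel XPU. A minimal sketch of the full fallback chain, assuming the relevant backend plugins are installed; `resolve_default_device_map` is a hypothetical name, and the CPU fallback is truncated in the diff and assumed here:

    import torch

    from transformers.utils import is_torch_npu_available, is_torch_xpu_available

    def resolve_default_device_map():
        # Mirrors update_device_map's fallback order: CUDA, then NPU, then XPU.
        if torch.cuda.is_available():
            return {"": torch.cuda.current_device()}
        if is_torch_npu_available():
            # torch_npu registers the torch.npu namespace on import; the device
            # is spelled as an explicit string for the .to(<int>) reason above.
            return {"": f"npu:{torch.npu.current_device()}"}
        if is_torch_xpu_available():
            return {"": f"xpu:{torch.xpu.current_device()}"}
        return {"": "cpu"}  # assumed fallback; this branch is cut off in the diff

Note that the CUDA branch keeps the integer device index, which `.to()` accepts on CUDA, while the NPU and XPU branches use explicit device strings.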
