add bnb support for Ascend NPU
statelesshz committed Nov 18, 2024
1 parent 3ee24e2 commit 06c3eb8
Showing 2 changed files with 9 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -17,6 +17,7 @@
 
 from packaging import version
 
+from ..utils import is_torch_npu_available
 from .base import HfQuantizer
 from .quantizers_utils import get_module_from_name
 
@@ -171,6 +172,9 @@ def create_quantized_param(
 
         old_value = getattr(module, tensor_name)
 
+        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if isinstance(target_device, int) and is_torch_npu_available():
+            target_device = f"npu:{target_device}"
         if tensor_name == "bias":
             if param_value is None:
                 new_value = old_value.to(target_device)
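For reference, the normalization this hunk inlines can be read as a tiny standalone helper. A minimal sketch, assuming torch_npu is installed; the name normalize_target_device is hypothetical and not part of the commit:

from transformers.utils import is_torch_npu_available


def normalize_target_device(target_device):
    # torch_npu rejects a bare integer index in `tensor.to(<int>)`
    # (https://github.com/Ascend/pytorch/issues/16), so rewrite integers
    # as explicit "npu:<idx>" device strings on Ascend hosts.
    if isinstance(target_device, int) and is_torch_npu_available():
        return f"npu:{target_device}"
    return target_device

# e.g. t.to(normalize_target_device(0)) moves t to "npu:0" on an Ascend
# host; on CUDA machines the integer passes through unchanged.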
@@ -264,6 +268,8 @@ def update_device_map(self, device_map):
         if device_map is None:
             if torch.cuda.is_available():
                 device_map = {"": torch.cuda.current_device()}
+            elif is_torch_npu_available():
+                device_map = {"": f"npu:{torch.npu.current_device()}"}
             elif is_torch_xpu_available():
                 device_map = {"": f"xpu:{torch.xpu.current_device()}"}
             else:
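With both changes in place, a 4-bit load should resolve to the NPU automatically when no device map is given. A hedged usage sketch, assuming a bitsandbytes build with the Ascend NPU backend plus torch_npu is installed; the checkpoint name is a placeholder, not taken from the commit:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Sketch only: "facebook/opt-350m" stands in for any causal LM checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
# With no explicit device_map, update_device_map now falls back to
# {"": f"npu:{torch.npu.current_device()}"} when CUDA is unavailable.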
3 changes: 3 additions & 0 deletions src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -16,6 +16,7 @@
 
 from packaging import version
 
+from ..utils import is_torch_npu_available
 from .base import HfQuantizer
 
 
@@ -135,6 +136,8 @@ def update_device_map(self, device_map):
         if device_map is None:
             if torch.cuda.is_available():
                 device_map = {"": torch.cuda.current_device()}
+            elif is_torch_npu_available():
+                device_map = {"": f"npu:{torch.npu.current_device()}"}
             elif is_torch_xpu_available():
                 device_map = {"": f"xpu:{torch.xpu.current_device()}"}
             else:
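The 8-bit quantizer reuses the same fallback order as the 4-bit one. A quick sanity-check sketch, assuming torch_npu is installed (it registers the torch.npu namespace on import):

import torch
from transformers.utils import is_torch_npu_available

# Mirrors the resolution order in update_device_map: CUDA first, then NPU.
if torch.cuda.is_available():
    print(f"bnb weights default to cuda:{torch.cuda.current_device()}")
elif is_torch_npu_available():
    print(f"bnb weights default to npu:{torch.npu.current_device()}")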
