Skip to content

Commit

Permalink
add bnb support for Ascend NPU
Browse files Browse the repository at this point in the history
  • Loading branch information
statelesshz committed Nov 18, 2024
1 parent 3ee24e2 commit cf1894c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from packaging import version

from ..utils import is_torch_npu_available
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name

Expand Down Expand Up @@ -172,10 +173,13 @@ def create_quantized_param(
old_value = getattr(module, tensor_name)

if tensor_name == "bias":
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
if isinstance(target_device, int) and is_torch_npu_available():
device = f"npu:{target_device}"
if param_value is None:
new_value = old_value.to(target_device)
new_value = old_value.to(device)
else:
new_value = param_value.to(target_device)
new_value = param_value.to(device)

new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
module._parameters[tensor_name] = new_value
Expand Down Expand Up @@ -264,6 +268,8 @@ def update_device_map(self, device_map):
if device_map is None:
if torch.cuda.is_available():
device_map = {"": torch.cuda.current_device()}
elif is_torch_npu_available():
device_map = {"": f"npu:{torch.npu.current_device()}"}
elif is_torch_xpu_available():
device_map = {"": f"xpu:{torch.xpu.current_device()}"}
else:
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/quantizers/quantizer_bnb_8bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from packaging import version

from ..utils import is_torch_npu_available
from .base import HfQuantizer


Expand Down Expand Up @@ -135,6 +136,8 @@ def update_device_map(self, device_map):
if device_map is None:
if torch.cuda.is_available():
device_map = {"": torch.cuda.current_device()}
elif is_torch_npu_available():
device_map = {"": f"npu:{torch.npu.current_device()}"}
elif is_torch_xpu_available():
device_map = {"": f"xpu:{torch.xpu.current_device()}"}
else:
Expand Down

0 comments on commit cf1894c

Please sign in to comment.