From 06c3eb841ef1a81e02109a17925ed46df346d9e0 Mon Sep 17 00:00:00 2001
From: statelesshz
Date: Mon, 17 Jun 2024 10:50:08 +0800
Subject: [PATCH] add bnb support for Ascend NPU

---
 src/transformers/quantizers/quantizer_bnb_4bit.py | 6 ++++++
 src/transformers/quantizers/quantizer_bnb_8bit.py | 3 +++
 2 files changed, 9 insertions(+)

diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py
index 98d57e22524902..dc9f5ab39029ac 100644
--- a/src/transformers/quantizers/quantizer_bnb_4bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -17,6 +17,7 @@

 from packaging import version

+from ..utils import is_torch_npu_available
 from .base import HfQuantizer
 from .quantizers_utils import get_module_from_name

@@ -171,6 +172,9 @@ def create_quantized_param(

         old_value = getattr(module, tensor_name)

+        # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if isinstance(target_device, int) and is_torch_npu_available():
+            target_device = f"npu:{target_device}"
         if tensor_name == "bias":
             if param_value is None:
                 new_value = old_value.to(target_device)
@@ -264,6 +268,8 @@ def update_device_map(self, device_map):
         if device_map is None:
             if torch.cuda.is_available():
                 device_map = {"": torch.cuda.current_device()}
+            elif is_torch_npu_available():
+                device_map = {"": f"npu:{torch.npu.current_device()}"}
             elif is_torch_xpu_available():
                 device_map = {"": f"xpu:{torch.xpu.current_device()}"}
             else:
diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py
index 093d612b914cef..bc967e821253b1 100644
--- a/src/transformers/quantizers/quantizer_bnb_8bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -16,6 +16,7 @@

 from packaging import version

+from ..utils import is_torch_npu_available
 from .base import HfQuantizer

@@ -135,6 +136,8 @@ def update_device_map(self, device_map):
         if device_map is None:
             if torch.cuda.is_available():
                 device_map = {"": torch.cuda.current_device()}
+            elif is_torch_npu_available():
+                device_map = {"": f"npu:{torch.npu.current_device()}"}
             elif is_torch_xpu_available():
                 device_map = {"": f"xpu:{torch.xpu.current_device()}"}
             else:
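
Note (not part of the patch): a minimal usage sketch of how the new device-map fallback would be exercised, assuming a bitsandbytes build with Ascend NPU support and torch_npu installed. The checkpoint name and the 4-bit options below are illustrative assumptions, not taken from the patch; only the device-map behavior reflects the change above.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Illustrative 4-bit quantization settings (assumed, not from the patch).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # placeholder checkpoint
    quantization_config=quantization_config,
    # device_map is left unset on purpose: with this patch, update_device_map()
    # falls back to {"": f"npu:{torch.npu.current_device()}"} when torch_npu is
    # available and CUDA is not.
)

The patch also normalizes an integer target_device to the "npu:<index>" string form in create_quantized_param(), since passing a bare device index to `.to()` is not supported by torch_npu.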