diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py
index 98d57e22524902..dcba696ddf83b4 100644
--- a/src/transformers/quantizers/quantizer_bnb_4bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -17,6 +17,7 @@
 
 from packaging import version
 
+from ..utils import is_torch_npu_available
 from .base import HfQuantizer
 from .quantizers_utils import get_module_from_name
 
@@ -172,10 +173,13 @@ def create_quantized_param(
         old_value = getattr(module, tensor_name)
 
         if tensor_name == "bias":
+            # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+            if isinstance(target_device, int) and is_torch_npu_available():
+                target_device = f"npu:{target_device}"
             if param_value is None:
                 new_value = old_value.to(target_device)
             else:
                 new_value = param_value.to(target_device)
 
             new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
             module._parameters[tensor_name] = new_value
@@ -264,6 +268,8 @@ def update_device_map(self, device_map):
         if device_map is None:
             if torch.cuda.is_available():
                 device_map = {"": torch.cuda.current_device()}
+            elif is_torch_npu_available():
+                device_map = {"": f"npu:{torch.npu.current_device()}"}
             elif is_torch_xpu_available():
                 device_map = {"": f"xpu:{torch.xpu.current_device()}"}
             else:
diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py
index 093d612b914cef..bc967e821253b1 100644
--- a/src/transformers/quantizers/quantizer_bnb_8bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -16,6 +16,7 @@
 
 from packaging import version
 
+from ..utils import is_torch_npu_available
 from .base import HfQuantizer
 
 
@@ -135,6 +136,8 @@ def update_device_map(self, device_map):
         if device_map is None:
             if torch.cuda.is_available():
                 device_map = {"": torch.cuda.current_device()}
+            elif is_torch_npu_available():
+                device_map = {"": f"npu:{torch.npu.current_device()}"}
             elif is_torch_xpu_available():
                 device_map = {"": f"xpu:{torch.xpu.current_device()}"}
             else:
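As a rough usage sketch (not part of the patch): on an Ascend NPU host with `torch_npu` and an NPU-enabled `bitsandbytes` build, loading a 4-bit model without an explicit `device_map` should now resolve to the current NPU device rather than falling through to the final `else` branch. The model id below is only an illustrative assumption; any causal LM that fits on one NPU would do.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformers.utils import is_torch_npu_available

# Illustrative model id, not prescribed by this PR.
model_id = "facebook/opt-125m"

if is_torch_npu_available():
    # With device_map left as None, update_device_map() now defaults to
    # {"": f"npu:{torch.npu.current_device()}"} on NPU machines.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    # Bias tensors are moved with an "npu:<index>" device string, sidestepping
    # torch_npu's missing support for Tensor.to(<int device index>).
    print({name: p.device for name, p in model.named_parameters() if name.endswith("bias")})
```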