From c9f2047f7a9da9b0dfd54391cac9945c158fa17b Mon Sep 17 00:00:00 2001 From: statelesshz Date: Mon, 17 Jun 2024 10:50:08 +0800 Subject: [PATCH 1/2] add bnb support for Ascend NPU --- src/transformers/quantizers/quantizer_bnb_4bit.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index 98d57e22524902..6fb54adb32d60a 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -29,6 +29,7 @@ is_accelerate_available, is_bitsandbytes_available, is_torch_available, + is_torch_npu_available, is_torch_xpu_available, logging, ) @@ -171,6 +172,9 @@ def create_quantized_param( old_value = getattr(module, tensor_name) + # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). + if isinstance(target_device, int) and is_torch_npu_available(): + target_device = f"npu:{target_device}" if tensor_name == "bias": if param_value is None: new_value = old_value.to(target_device) @@ -264,6 +268,8 @@ def update_device_map(self, device_map): if device_map is None: if torch.cuda.is_available(): device_map = {"": torch.cuda.current_device()} + elif is_torch_npu_available(): + device_map = {"": f"npu:{torch.npu.current_device()}"} elif is_torch_xpu_available(): device_map = {"": f"xpu:{torch.xpu.current_device()}"} else: From a411f18e4223bf3ed542922eea563c07c74e1d60 Mon Sep 17 00:00:00 2001 From: Huazhong Ji Date: Mon, 9 Dec 2024 22:52:22 +0800 Subject: [PATCH 2/2] delete comment --- src/transformers/quantizers/quantizer_bnb_4bit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index 6fb54adb32d60a..8657bda166254d 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -263,7 +263,6 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": torch_dtype = torch.float16 return torch_dtype - # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_device_map def update_device_map(self, device_map): if device_map is None: if torch.cuda.is_available():