diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index 98d57e22524902..6fb54adb32d60a 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -29,6 +29,7 @@ is_accelerate_available, is_bitsandbytes_available, is_torch_available, + is_torch_npu_available, is_torch_xpu_available, logging, ) @@ -171,6 +172,9 @@ def create_quantized_param( old_value = getattr(module, tensor_name) + # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). + if isinstance(target_device, int) and is_torch_npu_available(): + target_device = f"npu:{target_device}" if tensor_name == "bias": if param_value is None: new_value = old_value.to(target_device) @@ -264,6 +268,8 @@ def update_device_map(self, device_map): if device_map is None: if torch.cuda.is_available(): device_map = {"": torch.cuda.current_device()} + elif is_torch_npu_available(): + device_map = {"": f"npu:{torch.npu.current_device()}"} elif is_torch_xpu_available(): device_map = {"": f"xpu:{torch.xpu.current_device()}"} else: