diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py
index 7b81c93edf1..4dd818f6465 100644
--- a/src/transformers/quantizers/quantizer_awq.py
+++ b/src/transformers/quantizers/quantizer_awq.py
@@ -87,6 +87,7 @@ def validate_environment(self, device_map, **kwargs):
     def update_torch_dtype(self, torch_dtype):
         if torch_dtype is None:
             torch_dtype = torch.float16
+            logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
         elif torch_dtype != torch.float16:
             logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
         return torch_dtype
diff --git a/src/transformers/quantizers/quantizer_gptq.py b/src/transformers/quantizers/quantizer_gptq.py
index 233a5279d3f..bf5079435d6 100644
--- a/src/transformers/quantizers/quantizer_gptq.py
+++ b/src/transformers/quantizers/quantizer_gptq.py
@@ -64,6 +64,7 @@ def validate_environment(self, *args, **kwargs):
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
             torch_dtype = torch.float16
+            logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
         elif torch_dtype != torch.float16:
             logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")
         return torch_dtype
diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
index 10d2b184ef1..bcc9c57dfa0 100644
--- a/src/transformers/quantizers/quantizer_torchao.py
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -114,6 +114,9 @@ def update_torch_dtype(self, torch_dtype):
                 torch_dtype = torch.bfloat16
         if self.quantization_config.quant_type == "int8_dynamic_activation_int8_weight":
             if torch_dtype is None:
+                logger.info(
+                    "Setting torch_dtype to torch.float32 for int8_dynamic_activation_int8_weight quantization as no torch_dtype was specified in from_pretrained"
+                )
                 # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
                 torch_dtype = torch.float32
         return torch_dtype
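
The AWQ and GPTQ hunks share the same defaulting logic, which the patch makes visible in the logs instead of silent. A minimal standalone sketch of that logic (the free function below is illustrative, not the library method, which lives on the quantizer class and takes `self`):

```python
import logging
from typing import Optional

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def update_torch_dtype(torch_dtype: Optional[torch.dtype]) -> torch.dtype:
    # Mirrors the patched AWQ/GPTQ behavior: default to float16 and log it,
    # rather than defaulting silently as before this change.
    if torch_dtype is None:
        torch_dtype = torch.float16
        logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
    elif torch_dtype != torch.float16:
        logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
    return torch_dtype


print(update_torch_dtype(None))            # logs the new info message, returns torch.float16
print(update_torch_dtype(torch.bfloat16))  # logs the existing warning, returns torch.bfloat16
```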
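
The torchao hunk only touches the `int8_dynamic_activation_int8_weight` path, where the quantizer forces `float32` to avoid a dtype mismatch in the quantized linear op; the patch logs that implicit default. A hedged sketch of that branch, with `quant_type` passed explicitly instead of read from `self.quantization_config` (the function name is illustrative):

```python
import logging
from typing import Optional

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def update_torch_dtype_torchao(torch_dtype: Optional[torch.dtype], quant_type: str) -> Optional[torch.dtype]:
    # Only the int8_dynamic_activation_int8_weight quant type forces float32;
    # the patch makes that default visible instead of silent.
    if quant_type == "int8_dynamic_activation_int8_weight":
        if torch_dtype is None:
            logger.info(
                "Setting torch_dtype to torch.float32 for int8_dynamic_activation_int8_weight "
                "quantization as no torch_dtype was specified in from_pretrained"
            )
            # avoid a dtype mismatch when performing the quantized linear op
            torch_dtype = torch.float32
    return torch_dtype


print(update_torch_dtype_torchao(None, "int8_dynamic_activation_int8_weight"))  # logs, returns torch.float32
print(update_torch_dtype_torchao(None, "int8_weight_only"))                     # silent, returns None
```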