From 5e63a5c9d0657f9d47534c305ead43703105e92d Mon Sep 17 00:00:00 2001
From: "clementine@huggingface.co"
Date: Thu, 11 Jul 2024 14:25:54 +0000
Subject: [PATCH 1/3] fix #200

---
 src/lighteval/models/model_config.py | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py
index b6f4bb5d9..419bd6149 100644
--- a/src/lighteval/models/model_config.py
+++ b/src/lighteval/models/model_config.py
@@ -95,7 +95,8 @@ class BaseModelConfig:
             Use `dtype="auto"` to derive the type from the model's weights.
         device (Union[int, str]): device to use for model training.
         quantization_config (Optional[BitsAndBytesConfig]): quantization
-            configuration for the model. Needed for 4-bit and 8-bit precision.
+            configuration for the model, provided manually to load a model that is normally
+            stored in floating point at a quantized precision. Needed for 4-bit and 8-bit precision.
         trust_remote_code (bool): Whether to trust remote code during model
             loading.
 
@@ -144,13 +145,28 @@ def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedCon
             cache_dir=env_config.cache_dir,
             token=env_config.token,
         )
-        if getattr(auto_config, "quantization_config", False) and self.quantization_config is None:
+
+        # Gathering the model's automatic quantization config, if available
+        try:
+            model_auto_quantization_config = auto_config.quantization_config
+            hlog("An automatic quantization config was found in the model's config. Using it to load the model")
+        except (AttributeError, KeyError):
+            model_auto_quantization_config = None
+
+        # We don't load models that are already quantized with a different user-provided config
+        if model_auto_quantization_config is not None and self.quantization_config is not None:
+            raise ValueError("You manually requested quantization on a model already quantized!")
+
+        # We add the quantization config to the model params we store
+        if model_auto_quantization_config["quant_method"] == "gptq":
             if not is_autogptq_available():
                 raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
-            hlog(
-                "`quantization_config` is None but was found in the model's config, using the one found in config.json"
-            )
+            auto_config.quantization_config["use_exllama"] = None
             self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
+        elif model_auto_quantization_config["quant_method"] == "bitsandbytes":
+            if not is_bnb_available():
+                raise ImportError(NO_BNB_ERROR_MSG)
+            self.quantization_config = BitsAndBytesConfig(**auto_config.quantization_config)
 
         return auto_config


From e0e99fc537cba31ae322f58a5efc21bdf9198ee6 Mon Sep 17 00:00:00 2001
From: "clementine@huggingface.co"
Date: Thu, 11 Jul 2024 15:03:21 +0000
Subject: [PATCH 2/3] fix #176

---
 src/lighteval/models/base_model.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py
index 3e483d448..fbd0f9337 100644
--- a/src/lighteval/models/base_model.py
+++ b/src/lighteval/models/base_model.py
@@ -29,7 +29,7 @@
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
 from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn
@@ -88,9 +88,7 @@ def __init__(
         self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
 
         # We are in DP (and launch the script with `accelerate launch`)
-        if not config.model_parallel and config.quantization_config is None:
-            # might need to use accelerate instead
-            # self.model = config.accelerator.prepare(self.model)
+        if not config.model_parallel and not isinstance(config.quantization_config, BitsAndBytesConfig):
             hlog(f"Using Data Parallelism, putting model on device {self._device}")
             self.model = self.model.to(self._device)
 

From 7858e823f4d144598774abf760e86e4d06325079 Mon Sep 17 00:00:00 2001
From: Clémentine
Date: Thu, 11 Jul 2024 17:11:57 +0200
Subject: [PATCH 3/3] fix

---
 src/lighteval/models/model_config.py | 29 +++++++++++++++--------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py
index 419bd6149..75a29d02c 100644
--- a/src/lighteval/models/model_config.py
+++ b/src/lighteval/models/model_config.py
@@ -153,20 +153,21 @@ def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedCon
         except (AttributeError, KeyError):
             model_auto_quantization_config = None
 
-        # We don't load models that are already quantized with a different user-provided config
-        if model_auto_quantization_config is not None and self.quantization_config is not None:
-            raise ValueError("You manually requested quantization on a model already quantized!")
-
-        # We add the quantization config to the model params we store
-        if model_auto_quantization_config["quant_method"] == "gptq":
-            if not is_autogptq_available():
-                raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
-            auto_config.quantization_config["use_exllama"] = None
-            self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
-        elif model_auto_quantization_config["quant_method"] == "bitsandbytes":
-            if not is_bnb_available():
-                raise ImportError(NO_BNB_ERROR_MSG)
-            self.quantization_config = BitsAndBytesConfig(**auto_config.quantization_config)
+        if model_auto_quantization_config is not None:
+            if self.quantization_config is not None:
+                # We don't load models that are already quantized with a different user-provided config
+                raise ValueError("You manually requested quantization on a model already quantized!")
+
+            # We add the quantization config to the model params we store
+            if model_auto_quantization_config["quant_method"] == "gptq":
+                if not is_autogptq_available():
+                    raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
+                auto_config.quantization_config["use_exllama"] = None
+                self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
+            elif model_auto_quantization_config["quant_method"] == "bitsandbytes":
+                if not is_bnb_available():
+                    raise ImportError(NO_BNB_ERROR_MSG)
+                self.quantization_config = BitsAndBytesConfig(**auto_config.quantization_config)
 
         return auto_config
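
Taken together, the series converges on one resolution rule for quantization configs. Below is a minimal, standalone sketch of that rule in its final shape (patch 3); `resolve_quantization_config` and the example config dict are illustrative only, not lighteval's API, and the sketch assumes nothing beyond an installed `transformers`:

# A sketch of the quantization-config resolution these patches implement.
# `resolve_quantization_config` is a hypothetical helper, not lighteval API.
from typing import Optional

from transformers import BitsAndBytesConfig, GPTQConfig


def resolve_quantization_config(auto_quantization_config: Optional[dict], user_quantization_config):
    if auto_quantization_config is None:
        # Nothing embedded in config.json: keep the user's config (possibly None).
        return user_quantization_config
    if user_quantization_config is not None:
        # Same guard as the patches: an already-quantized checkpoint cannot be
        # re-loaded with a different, user-provided quantization config.
        raise ValueError("You manually requested quantization on a model already quantized!")
    quant_method = auto_quantization_config["quant_method"]
    if quant_method == "gptq":
        # As in the patches: clear `use_exllama` and pass `disable_exllama=True`.
        return GPTQConfig(**{**auto_quantization_config, "use_exllama": None}, disable_exllama=True)
    if quant_method == "bitsandbytes":
        return BitsAndBytesConfig(**auto_quantization_config)
    raise ValueError(f"Unsupported quant_method: {quant_method}")


# Hypothetical embedded config, as a GPTQ checkpoint might advertise in its config.json:
print(resolve_quantization_config({"quant_method": "gptq", "bits": 4, "group_size": 128}, None))

Patch 2 applies the complementary rule at load time: under data parallelism the model is moved with `.to(self._device)` only when its quantization config is not a `BitsAndBytesConfig`, since bitsandbytes handles device placement itself and `transformers` blocks `.to()` on bnb-quantized models; GPTQ-quantized models can still be moved, which the previous `is None` check wrongly prevented.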