Skip to content

Commit

Permalink
fix #200
Browse files Browse the repository at this point in the history
  • Loading branch information
clefourrier committed Jul 11, 2024
1 parent 66e6aae commit 5e63a5c
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/lighteval/models/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ class BaseModelConfig:
Use `dtype="auto"` to derive the type from the model's weights.
device (Union[int, str]): device to use for model training.
quantization_config (Optional[BitsAndBytesConfig]): quantization
configuration for the model, manually provided to load a normally floating point
model at a quantized precision. Needed for 4-bit and 8-bit precision.
trust_remote_code (bool): Whether to trust remote code during model
loading.
Expand Down Expand Up @@ -144,13 +145,28 @@ def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedCon
cache_dir=env_config.cache_dir,
token=env_config.token,
)
if getattr(auto_config, "quantization_config", False) and self.quantization_config is None:

# Gathering the model's automatic quantization config, if available
try:
model_auto_quantization_config = auto_config.quantization_config
hlog("An automatic quantization config was found in the model's config. Using it to load the model")
except (AttributeError, KeyError):
model_auto_quantization_config = None

# We don't load models quantized by default with a different user provided conf
if model_auto_quantization_config is not None and self.quantization_config is not None:
raise ValueError("You manually requested quantization on a model already quantized!")

# We add the quantization to the model params we store
if model_auto_quantization_config["quant_method"] == "gptq":
if not is_autogptq_available():
raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
hlog(
"`quantization_config` is None but was found in the model's config, using the one found in config.json"
)
auto_config.quantization_config["use_exllama"] = None
self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
elif model_auto_quantization_config["quant_method"] == "bitsandbytes":
if not is_bnb_available():
raise ImportError(NO_BNB_ERROR_MSG)
self.quantization_config = BitsAndBytesConfig(**auto_config.quantization_config)

return auto_config

Expand Down

0 comments on commit 5e63a5c

Please sign in to comment.