Quantization related issues #224

Merged (7 commits) on Jul 17, 2024

Changes from all commits
6 changes: 2 additions & 4 deletions src/lighteval/models/base_model.py
@@ -29,7 +29,7 @@
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
 from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn
@@ -88,9 +88,7 @@ def __init__(
         self.multichoice_continuations_start_space = config.multichoice_continuations_start_space

         # We are in DP (and launch the script with `accelerate launch`)
-        if not config.model_parallel and config.quantization_config is None:
-            # might need to use accelerate instead
-            # self.model = config.accelerator.prepare(self.model)
+        if not config.model_parallel and not isinstance(config.quantization_config, BitsAndBytesConfig):
@clefourrier (Member, Author) commented on Jul 11, 2024:
We need to enter this branch when the provided config is a GPTQ config, but not when it is a bitsandbytes config.

hlog(f"Using Data Parallelism, putting model on device {self._device}")
self.model = self.model.to(self._device)

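A minimal standalone sketch of what the updated condition does (the helper `should_move_to_device` is hypothetical, not part of lighteval): only a bitsandbytes config skips the manual device move, while GPTQ-quantized and unquantized models are still moved with `.to(device)` under data parallelism.

```python
from transformers import BitsAndBytesConfig, GPTQConfig

def should_move_to_device(model_parallel: bool, quantization_config) -> bool:
    # Mirrors the new `if`: enter the DP branch unless bitsandbytes already
    # handles weight placement on the GPU during loading.
    return not model_parallel and not isinstance(quantization_config, BitsAndBytesConfig)

print(should_move_to_device(False, None))                                   # True: unquantized model
print(should_move_to_device(False, GPTQConfig(bits=4)))                     # True: GPTQ model still needs .to(device)
print(should_move_to_device(False, BitsAndBytesConfig(load_in_4bit=True)))  # False: bnb places weights itself
```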
33 changes: 25 additions & 8 deletions src/lighteval/models/model_config.py
@@ -95,7 +95,8 @@ class BaseModelConfig:
             Use `dtype="auto"` to derive the type from the model's weights.
         device (Union[int, str]): device to use for model training.
         quantization_config (Optional[BitsAndBytesConfig]): quantization
-            configuration for the model. Needed for 4-bit and 8-bit precision.
+            configuration for the model, manually provided to load a normally floating point
+            model at a quantized precision. Needed for 4-bit and 8-bit precision.
         trust_remote_code (bool): Whether to trust remote code during model
             loading.
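As an illustration of the manual path described in this docstring (the model id and the `pretrained` field name are assumptions, not taken from this PR), a user can pass their own bitsandbytes config to load a full-precision checkpoint in 4-bit:

```python
from transformers import BitsAndBytesConfig

# Request 4-bit NF4 loading for a checkpoint that is stored in full precision.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")

# Hypothetical usage; the exact BaseModelConfig fields may differ.
# from lighteval.models.model_config import BaseModelConfig
# config = BaseModelConfig(pretrained="some-org/some-fp16-model", quantization_config=bnb_config)
```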

@@ -144,13 +145,29 @@ def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedConfig:
             cache_dir=env_config.cache_dir,
             token=env_config.token,
         )
-        if getattr(auto_config, "quantization_config", False) and self.quantization_config is None:
-            if not is_autogptq_available():
-                raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
-            hlog(
-                "`quantization_config` is None but was found in the model's config, using the one found in config.json"
-            )
-            self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
+
+        # Gathering the model's automatic quantization config, if available
+        try:
+            model_auto_quantization_config = auto_config.quantization_config
+            hlog("An automatic quantization config was found in the model's config. Using it to load the model")
+        except (AttributeError, KeyError):
+            model_auto_quantization_config = None
+
+        if model_auto_quantization_config is not None:
+            if self.quantization_config is not None:
+                # We don't load models quantized by default with a different user provided conf
+                raise ValueError("You manually requested quantization on a model already quantized!")
+
+            # We add the quantization to the model params we store
+            if model_auto_quantization_config["quant_method"] == "gptq":
+                if not is_autogptq_available():
+                    raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
+                auto_config.quantization_config["use_exllama"] = None
+                self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
+            elif model_auto_quantization_config["quant_method"] == "bitsandbytes":
+                if not is_bnb_available():
+                    raise ImportError(NO_BNB_ERROR_MSG)
+                self.quantization_config = BitsAndBytesConfig(**auto_config.quantization_config)

         return auto_config
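A rough, self-contained sketch of the dispatch implemented above, assuming the checkpoint's config.json exposes a `quantization_config` dict with a `quant_method` key (the dict values here are illustrative, not from a real model):

```python
from transformers import BitsAndBytesConfig, GPTQConfig

# What AutoConfig would expose for an already GPTQ-quantized checkpoint (illustrative values).
quant_dict = {"quant_method": "gptq", "bits": 4, "group_size": 128}
user_quantization_config = None  # whatever the user passed manually, if anything

if user_quantization_config is not None:
    # Mirrors the new guard: refuse to re-quantize an already quantized checkpoint.
    raise ValueError("You manually requested quantization on a model already quantized!")

if quant_dict["quant_method"] == "gptq":
    quantization_config = GPTQConfig(**quant_dict, disable_exllama=True)
elif quant_dict["quant_method"] == "bitsandbytes":
    quantization_config = BitsAndBytesConfig(**quant_dict)

print(type(quantization_config).__name__)  # GPTQConfig
```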
