Add AutoGPTQ quantization script #545
examples/llama-2/lora-short.yml
@@ -0,0 +1,70 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
# val_set_size: 0.01
val_set_size: 0.001
output_dir: ./lora-out

sequence_len: 4096
sample_packing: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
# num_epochs: 3
# num_epochs: 1
num_epochs: 0.1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
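A quick sketch of how this example config ties into the new merge/quantize helpers (assuming this is the examples/llama-2/lora-short.yml file referenced in scripts/quantize.py further down): the merged model that the quantization step later loads is written under output_dir/merged.

```python
from pathlib import Path

# load_cfg and get_merged_out_dir come from scripts/finetune.py (see below);
# the bare import works when run from the scripts/ directory, as scripts/quantize.py does.
from finetune import load_cfg, get_merged_out_dir

cfg = load_cfg(Path("./examples/llama-2/lora-short.yml"))
print(cfg.output_dir)           # ./lora-out
print(get_merged_out_dir(cfg))  # lora-out/merged
```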
scripts/finetune.py
@@ -1,5 +1,7 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" | ||
|
||
import gc | ||
|
||
import importlib | ||
import logging | ||
import os | ||
|
@@ -27,6 +29,7 @@ | |
from axolotl.utils.models import load_tokenizer | ||
from axolotl.utils.tokenization import check_dataset_labels | ||
from axolotl.utils.wandb import setup_wandb_env_vars | ||
from axolotl.utils.quantize import get_examples_for_quantization, load_merged_model, quantize_and_save | ||
|
||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | ||
src_dir = os.path.join(project_root, "src") | ||
|
@@ -57,27 +60,44 @@ def get_multi_line_input() -> Optional[str]: | |
# instruction = pathlib.Path("/proc/self/fd/0").read_text() | ||
return instruction | ||
|
||
def get_merged_out_dir(cfg: DictDefault): | ||
return Path(cfg.output_dir) / "merged" | ||
|
||
def do_merge_lora( | ||
def do_merge_lora_model_and_tokenizer( | ||
*, | ||
cfg: DictDefault, | ||
cli_args: TrainerCliArgs, | ||
model, | ||
tokenizer, | ||
): | ||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) | ||
safe_serialization = cfg.save_safetensors is True | ||
|
||
LOG.info("running merge of LoRA with base model") | ||
model = model.merge_and_unload() | ||
model.to(dtype=torch.float16) | ||
|
||
merged_out_dir = str(get_merged_out_dir(cfg)) | ||
|
||
if cfg.local_rank == 0: | ||
LOG.info("saving merged model") | ||
model.save_pretrained( | ||
str(Path(cfg.output_dir) / "merged"), | ||
merged_out_dir, | ||
safe_serialization=safe_serialization, | ||
) | ||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) | ||
tokenizer.save_pretrained(merged_out_dir) | ||
|
||
def do_merge_lora( | ||
*, | ||
cfg: DictDefault, | ||
cli_args: TrainerCliArgs, | ||
): | ||
new_cfg = DictDefault({ | ||
**cfg, | ||
'lora_model_dir': cfg.get('lora_model_dir', cfg['output_dir']), | ||
'load_in_8bit': False, | ||
'load_in_4bit': False, | ||
}) | ||
model, tokenizer = load_model_and_tokenizer(cfg=new_cfg, cli_args=cli_args) | ||
do_merge_lora_model_and_tokenizer(cfg=new_cfg, model=model, tokenizer=tokenizer) | ||
|
||
def shard( | ||
*, | ||
|
@@ -267,11 +287,72 @@ def do_cli(config: Path = Path("examples/"), **kwargs): | |
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) | ||
elif parsed_cli_args.shard: | ||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args) | ||
elif parsed_cli_args.quantize: | ||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) | ||
|
||
tokenizer = load_tokenizer(parsed_cfg) | ||
# Load merged model with AutoGPTQ | ||
merged_model = load_merged_model(parsed_cfg) | ||
|
||
# Quantize & save | ||
n_samples = 128 | ||
examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples) | ||
quantize_and_save(parsed_cfg, merged_model, tokenizer, examples) | ||
|
||
else: | ||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) | ||
if parsed_cli_args.prepare_ds_only: | ||
return | ||
# model, tokenizer = train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) | ||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) | ||
# tokenizer = None | ||
should_quantize = True | ||
|
||
if should_quantize: | ||
# Merge model | ||
# do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) | ||
# do_merge_lora_model_and_tokenizer(cfg=parsed_cfg, model=model, tokenizer=tokenizer) | ||
# new_cfg = parsed_cfg.copy() | ||
# new_cfg['lora_model_dir'] = new_cfg['output_dir'] | ||
# new_cfg['load_in_8bit'] = False | ||
# new_cfg['load_in_4bit'] = False | ||
|
||
# new_cfg = DictDefault({ | ||
# **parsed_cfg, | ||
# 'lora_model_dir': parsed_cfg['output_dir'], | ||
# 'load_in_8bit': False, | ||
# 'load_in_4bit': False, | ||
# }) | ||
# lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False | ||
# do_merge_lora(cfg=new_cfg, cli_args=parsed_cli_args) | ||
|
||
def log_gpu_memory(): | ||
print("GPU Memory:", torch.cuda.memory_allocated()) | ||
|
||
log_gpu_memory() | ||
print(len(gc.get_referrers(model))) | ||
print(sys.getrefcount(model)) | ||
|
||
# TODO: release old model from GPU memory | ||
print(gc.collect()) | ||
del model | ||
# del tokenizer | ||
print(gc.collect()) | ||
torch.cuda.empty_cache() | ||
print(gc.collect()) | ||
|
||
log_gpu_memory() | ||
|
||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) | ||
Review comment (Help Wanted): I kept getting: [...] when calling [...]. Running [...]
            # Load merged model with AutoGPTQ
            merged_model = load_merged_model(parsed_cfg)

            # Quantize & save
            n_samples = 128
            examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples)
            quantize_and_save(parsed_cfg, merged_model, tokenizer, examples)


if __name__ == "__main__":
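Taken together, the work-in-progress else branch above appears to be aiming at roughly this post-training flow (a sketch only, reusing names already imported in scripts/finetune.py; not the final implementation):

```python
def train_then_quantize(parsed_cfg, parsed_cli_args, dataset_meta):
    """Rough outline: train, free GPU memory, merge LoRA, then GPTQ-quantize."""
    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

    # Release whatever the trainer still holds before re-loading weights for the merge.
    gc.collect()
    torch.cuda.empty_cache()

    # Merge the LoRA adapter into the base model; writes <output_dir>/merged.
    do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)

    # Re-load the merged checkpoint with AutoGPTQ, then quantize and save.
    tokenizer = load_tokenizer(parsed_cfg)
    merged_model = load_merged_model(parsed_cfg)
    examples = get_examples_for_quantization(dataset_meta.train_dataset, 128)
    quantize_and_save(parsed_cfg, merged_model, tokenizer, examples)
```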
scripts/quantize.py
@@ -0,0 +1,147 @@
# pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

# import debugpy
# debugpy.listen(('0.0.0.0', 5678))
# debugpy.wait_for_client()
# debugpy.breakpoint()
Review comment: @Glavin001 clean up old code
import json
import random
import time
from pathlib import Path
import logging
import re

import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, LlamaTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from axolotl.prompters import AlpacaPrompter
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
# from scripts.finetune import load_cfg
from finetune import load_cfg, get_merged_out_dir, do_merge_lora_model_and_tokenizer, load_datasets

from axolotl.utils.quantize import load_merged_model, get_quantized_model, quantize_and_save, push_model, get_quantized_model_id, get_quantized_model_dir, get_examples_for_quantization

configure_logging()
LOG = logging.getLogger("axolotl.quantize")

import debugpy
debugpy.listen(('0.0.0.0', 5678))
debugpy.wait_for_client()
debugpy.breakpoint()

class ProgressExtractingHandler(logging.Handler):
    def emit(self, record):
        log_entry = self.format(record)
        progress_info = self.extract_progress(log_entry)
        if progress_info:
            print(f"Progress: {progress_info}")

    @staticmethod
    def extract_progress(log_entry):
        # [2023-09-11 07:20:37,502] [INFO] [auto_gptq.modeling._base.quantize:364] [PID:3962] [RANK:0] Quantizing self_attn.k_proj in layer 4/32...
        match = re.search(r'layer (\d+/\d+)', log_entry)
        return match.group(1) if match else None
        # [2023-09-11 07:27:52,208] [INFO] [auto_gptq.modeling._utils.pack_model:129] [PID:3962] [RANK:0] model.layers.15.self_attn.o_proj

handler = ProgressExtractingHandler()
# logging.getLogger('auto_gptq.modeling._base.quantize').addHandler(handler)
logger = logging.getLogger('auto_gptq.modeling._base.quantize')
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)
# logging.basicConfig(
#     format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.DEBUG, datefmt="%Y-%m-%d %H:%M:%S"
# )

# LOG.setLevel(logging.DEBUG)
# handler = logging.StreamHandler()
# formatter = logging.Formatter('%(asctime)s %(levelname)s [%(name)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
# handler.setFormatter(formatter)
# LOG.addHandler(handler)

Review comment (Help Wanted): I couldn't get any logging to work from AutoGPTQ. Would be nice to fix logging.
Review comment: This is my old code, which works when not calling Axolotl's [...]
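One possible workaround for the logging issue above (an untested sketch): assuming AutoGPTQ logs through the standard logging module under the auto_gptq namespace, as the sample log lines in extract_progress suggest, attaching a handler to the parent auto_gptq logger after configure_logging() should let records from its child loggers actually print.

```python
import logging

def enable_autogptq_logging(level: int = logging.DEBUG) -> None:
    """Attach a stream handler to the top-level auto_gptq logger so records from
    child loggers (e.g. auto_gptq.modeling._base) are emitted."""
    gptq_logger = logging.getLogger("auto_gptq")
    gptq_logger.setLevel(level)
    gptq_logger.propagate = True  # let records also reach any root handlers
    if not gptq_logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter(
                "%(asctime)s %(levelname)s [%(name)s] %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S",
            )
        )
        gptq_logger.addHandler(handler)

# Call after configure_logging() so axolotl's logging setup does not silence it.
enable_autogptq_logging()
```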
print("Done importing...") | ||
|
||
## CHANGE BELOW ## | ||
# config_path: Path = Path("./examples/llama-2/lora.yml") | ||
config_path: Path = Path("./examples/llama-2/lora-short.yml") | ||
|
||
# pretrained_model_dir = "facebook/opt-125m" | ||
# quantized_model_dir = "opt-125m-4bit" | ||
dataset_name = "teknium/GPT4-LLM-Cleaned" | ||
# huggingface_username = "CHANGE_ME" | ||
## CHANGE ABOVE | ||
|
||
def main(): | ||
print("Starting...") | ||
# return | ||
# prompt = "<|prompt|>How can entrepreneurs start building their own communities even before launching their product?</s><|answer|>" | ||
|
||
should_quantize = True | ||
# tokenizer = get_tokenizer() | ||
|
||
cfg = load_cfg(config_path) | ||
|
||
cfg['lora_model_dir'] = cfg['output_dir'] | ||
|
||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") | ||
tokenizer = load_tokenizer(cfg) | ||
|
||
if should_quantize: | ||
print("Quantizing...") | ||
|
||
print("Loading dataset...") | ||
datasets = load_datasets(cfg=cfg, cli_args=TrainerCliArgs()) | ||
train_dataset = datasets.train_dataset | ||
n_samples = 128 | ||
# # n_samples = 2 | ||
# examples = train_dataset.shuffle(seed=42).select( | ||
# [ | ||
# random.randrange(0, len(train_dataset) - 1) # nosec | ||
# for _ in range(n_samples) | ||
# ] | ||
# ) | ||
|
||
LOG.info("loading model and (optionally) peft_config...") | ||
# model, peft_config = load_model(cfg, tokenizer, inference=True) | ||
model = load_merged_model(cfg) | ||
# model = get_model() | ||
|
||
# examples = load_data(dataset_name, tokenizer, n_samples) | ||
|
||
# print(examples) | ||
# examples_for_quant = [ | ||
# {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]} | ||
# for example in examples | ||
# ] | ||
# print(examples_for_quant) | ||
examples_for_quant = get_examples_for_quantization(train_dataset, n_samples) | ||
|
||
modelq = quantize_and_save(cfg, model, tokenizer, examples_for_quant) | ||
else: | ||
print("Loading quantized model...") | ||
modelq = get_quantized_model(cfg) | ||
|
||
push_model(cfg, modelq, tokenizer) | ||
|
||
if __name__ == "__main__": | ||
main() | ||
|
||
|
||
# Load configure | ||
# Load dataset | ||
# Load tokenizer | ||
# Prepare database | ||
# Load previous model, final checkpoint | ||
|
||
|
||
# --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False | ||
# accelerate launch ./scripts/finetune.py ./examples/llama-2/lora.yml --merge_lora --lora_model_dir="./lora-out" --load_in_8bit=False --load_in_4bit=False | ||
# CUDA_VISIBLE_DEVICES="1" accelerate launch ./scripts/finetune.py ./examples/llama-2/lora.yml --merge_lora --lora_model_dir="./lora-out" --load_in_8bit=False --load_in_4bit=False | ||
|
||
# HUB_MODEL_ID="Glavin001/llama-2-7b-alpaca_2k_test" accelerate launch ./scripts/quantize.py | ||
Review comment: @Glavin001 delete test notes
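The commented-out sampling code in main() above hints at what get_examples_for_quantization in axolotl.utils.quantize presumably does; a self-contained sketch under that assumption (not the actual implementation in this PR):

```python
import random
from datasets import Dataset

def get_examples_for_quantization_sketch(dataset: Dataset, n_samples: int) -> list:
    """Randomly sample n tokenized rows and keep only the fields AutoGPTQ needs."""
    sampled = dataset.shuffle(seed=42).select(
        [random.randrange(0, len(dataset) - 1) for _ in range(n_samples)]  # nosec
    )
    return [
        {"input_ids": row["input_ids"], "attention_mask": row["attention_mask"]}
        for row in sampled
    ]
```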
src/axolotl/utils/config.py
@@ -82,6 +82,11 @@ def normalize_config(cfg):
    log_gpu_memory_usage(LOG, "baseline", cfg.device)

    if os.environ.get("WANDB_PROJECT") and len(os.environ.get("WANDB_PROJECT", "")) > 0:
        cfg.wandb_project = os.environ.get("WANDB_PROJECT")

    if os.environ.get("HUB_MODEL_ID") and len(os.environ.get("HUB_MODEL_ID", "")) > 0:
        cfg.hub_model_id = os.environ.get("HUB_MODEL_ID")
Review comment: FYI, this is used for upcoming work of starting [...]
Review comment: might be better off in the [...]
def validate_config(cfg):
    if cfg.max_packed_sequence_len and cfg.sample_packing:
Review comment: TODO: @Glavin001 should make this based off the config
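If this TODO refers to the hardcoded should_quantize = True in scripts/finetune.py (an assumption), one minimal way to drive it from the config and CLI instead:

```python
# Hypothetical: a `quantize` key in the YAML config is assumed here, not an existing option.
should_quantize = bool(parsed_cfg.get("quantize")) or bool(parsed_cli_args.quantize)
```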