Add AutoGPTQ quantization script #545

Draft: Glavin001 wants to merge 24 commits into main

Changes from 4 commits

Commits (24)
84d4476
WIP Add training callback to send predictions to WandB table
Glavin001 Sep 3, 2023
766875f
Merge branch 'main' of github.com:OpenAccess-AI-Collective/axolotl in…
Glavin001 Sep 5, 2023
0c743e3
WIP improve wandb table reporting callback
Glavin001 Sep 5, 2023
5a7f301
WIP improve wandb table reporting callback (cont)
Glavin001 Sep 5, 2023
8c7b7c5
Add VSCode launching for debugging
Glavin001 Sep 5, 2023
88c31f1
Add tiny llama example
Glavin001 Sep 5, 2023
06a44de
WIP attempt to improve post-eval prediction generation for table
Glavin001 Sep 7, 2023
ab3cffa
WIP attempt to improve post-eval prediction generation for table - pa…
Glavin001 Sep 8, 2023
b22d1c6
WIP batch generation
Glavin001 Sep 8, 2023
6f3216e
WIP attempt to handle sample_packing using position_ids for wandb pre…
Glavin001 Sep 8, 2023
e9eae77
WIP add code for debugging
Glavin001 Sep 8, 2023
83e6b29
Fix sample_packing support for wandb prediction table
Glavin001 Sep 9, 2023
aaf4d1e
Clean up code for PR review
Glavin001 Sep 9, 2023
e4c1a2e
WIP Add AutoGPTQ quantization script
Glavin001 Sep 9, 2023
19a30cf
WIP Integrate quantization into finetune script
Glavin001 Sep 10, 2023
894a4be
Add --quantize option to finetune script, fix auto_gptq logging
Glavin001 Sep 11, 2023
24c0483
Disable quantizing directly after fine tuning
Glavin001 Sep 11, 2023
14d26e1
Add eval_table_size, eval_table_max_new_tokens configs & clean up code
Glavin001 Sep 12, 2023
c6c54ee
Clean up PR, delete VSCode config, add tiny-llama example
Glavin001 Sep 12, 2023
dee3d54
Add eval_table_size, eval_table_max_new_tokens configs & clean up code
Glavin001 Sep 12, 2023
09b16d8
Clean up PR, delete VSCode config, add tiny-llama example
Glavin001 Sep 12, 2023
578d8b6
Merge branch 'feat/wandb-pred-table' of github.com:Glavin001/axolotl …
Glavin001 Sep 12, 2023
cf23998
WIP quantize model & push model
Glavin001 Sep 12, 2023
8a26ab3
WIP
Glavin001 Sep 12, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -703,7 +703,7 @@ Pass the appropriate flag to the train command:
Add below flag to train command above

```bash
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
--merge_lora --lora_model_dir="./completed-model"
```

If you run out of CUDA memory, you can try to merge in system RAM with
70 changes: 70 additions & 0 deletions examples/llama-2/lora-short.yml
@@ -0,0 +1,70 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
# val_set_size: 0.01
val_set_size: 0.001
output_dir: ./lora-out

sequence_len: 4096
sample_packing: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
# num_epochs: 3
# num_epochs: 1
num_epochs: 0.1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
91 changes: 86 additions & 5 deletions scripts/finetune.py
@@ -1,5 +1,7 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

import gc

import importlib
import logging
import os
@@ -27,6 +29,7 @@
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb import setup_wandb_env_vars
from axolotl.utils.quantize import get_examples_for_quantization, load_merged_model, quantize_and_save

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
@@ -57,27 +60,44 @@ def get_multi_line_input() -> Optional[str]:
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction

def get_merged_out_dir(cfg: DictDefault):
return Path(cfg.output_dir) / "merged"

def do_merge_lora(
def do_merge_lora_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
model,
tokenizer,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True

LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)

merged_out_dir = str(get_merged_out_dir(cfg))

if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
merged_out_dir,
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
tokenizer.save_pretrained(merged_out_dir)

def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
new_cfg = DictDefault({
**cfg,
'lora_model_dir': cfg.get('lora_model_dir', cfg['output_dir']),
'load_in_8bit': False,
'load_in_4bit': False,
})
model, tokenizer = load_model_and_tokenizer(cfg=new_cfg, cli_args=cli_args)
do_merge_lora_model_and_tokenizer(cfg=new_cfg, model=model, tokenizer=tokenizer)

def shard(
*,
@@ -267,11 +287,72 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.quantize:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

tokenizer = load_tokenizer(parsed_cfg)
# Load merged model with AutoGPTQ
merged_model = load_merged_model(parsed_cfg)

# Quantize & save
n_samples = 128
examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples)
quantize_and_save(parsed_cfg, merged_model, tokenizer, examples)

else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
# model, tokenizer = train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
# tokenizer = None
should_quantize = False

if should_quantize:
# Merge model
# do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
# do_merge_lora_model_and_tokenizer(cfg=parsed_cfg, model=model, tokenizer=tokenizer)
# new_cfg = parsed_cfg.copy()
# new_cfg['lora_model_dir'] = new_cfg['output_dir']
# new_cfg['load_in_8bit'] = False
# new_cfg['load_in_4bit'] = False

# new_cfg = DictDefault({
# **parsed_cfg,
# 'lora_model_dir': parsed_cfg['output_dir'],
# 'load_in_8bit': False,
# 'load_in_4bit': False,
# })
# lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
# do_merge_lora(cfg=new_cfg, cli_args=parsed_cli_args)

def log_gpu_memory():
print("GPU Memory:", torch.cuda.memory_allocated())

log_gpu_memory()
print(len(gc.get_referrers(model)))
print(sys.getrefcount(model))

# TODO: release old model from GPU memory
print(gc.collect())
del model
# del tokenizer
print(gc.collect())
torch.cuda.empty_cache()
print(gc.collect())

log_gpu_memory()

do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
@Glavin001 (Contributor Author) commented on Sep 11, 2023:

Help Wanted

I kept getting:

    Expected a cuda device, but got: cpu

when calling do_merge_lora. Running nvidia-smi always showed lots of GPU memory still taken up / unreleased.
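Not part of this diff — just a minimal sketch of one way to release the trained model from GPU memory before reloading it for the merge, where `model` stands for the module returned by `train()` (as in the commented-out line above) and assuming nothing else (the Trainer, optimizer state, or a callback) still holds a reference to it:

```python
import gc

import torch

# Move the weights off the GPU before dropping the reference; 8-bit/4-bit
# bitsandbytes modules may refuse .to("cpu"), hence the guard.
try:
    model.to("cpu")
except (ValueError, RuntimeError):
    pass

del model                 # drop the last strong reference to the module
gc.collect()              # collect Python garbage still holding CUDA tensors
torch.cuda.empty_cache()  # hand cached blocks back to the CUDA driver
torch.cuda.synchronize()  # make sure the frees complete before reloading
```

If nvidia-smi still reports the memory as allocated after this, something else (often the Trainer or an eval callback) is usually still holding a reference, which would match the behaviour described above.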


# Load merged model with AutoGPTQ
merged_model = load_merged_model(parsed_cfg)

# Quantize & save
n_samples = 128
examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples)
quantize_and_save(parsed_cfg, merged_model, tokenizer, examples)



if __name__ == "__main__":
147 changes: 147 additions & 0 deletions scripts/quantize.py
@@ -0,0 +1,147 @@
# pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

# import debugpy
# debugpy.listen(('0.0.0.0', 5678))
# debugpy.wait_for_client()
# debugpy.breakpoint()
@Glavin001 (Contributor Author) commented:

@Glavin001 clean up old code

import json
import random
import time
from pathlib import Path
import logging
import re

import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, LlamaTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from axolotl.prompters import AlpacaPrompter
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
# from scripts.finetune import load_cfg
from finetune import load_cfg, get_merged_out_dir, do_merge_lora_model_and_tokenizer, load_datasets

from axolotl.utils.quantize import load_merged_model, get_quantized_model, quantize_and_save, push_model, get_quantized_model_id, get_quantized_model_dir, get_examples_for_quantization

configure_logging()
LOG = logging.getLogger("axolotl.quantize")

import debugpy
debugpy.listen(('0.0.0.0', 5678))
debugpy.wait_for_client()
debugpy.breakpoint()

class ProgressExtractingHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
progress_info = self.extract_progress(log_entry)
if progress_info:
print(f"Progress: {progress_info}")

@staticmethod
def extract_progress(log_entry):
# [2023-09-11 07:20:37,502] [INFO] [auto_gptq.modeling._base.quantize:364] [PID:3962] [RANK:0] Quantizing self_attn.k_proj in layer 4/32...
match = re.search(r'layer (\d+/\d+)', log_entry)
return match.group(1) if match else None
# [2023-09-11 07:27:52,208] [INFO] [auto_gptq.modeling._utils.pack_model:129] [PID:3962] [RANK:0] model.layers.15.self_attn.o_proj

handler = ProgressExtractingHandler()
# logging.getLogger('auto_gptq.modeling._base.quantize').addHandler(handler)
logger = logging.getLogger('auto_gptq.modeling._base.quantize')
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)

# logging.basicConfig(
@Glavin001 (Contributor Author) commented:

Help Wanted

I couldn't get any logging to work from AutoGPTQ. Would be nice to fix logging.

@Glavin001 (Contributor Author) commented:

This is my old code which works when not calling Axolotl's configure_logging():

# format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.DEBUG, datefmt="%Y-%m-%d %H:%M:%S"
# )

# LOG.setLevel(logging.DEBUG)
# handler = logging.StreamHandler()
# formatter = logging.Formatter('%(asctime)s %(levelname)s [%(name)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
# handler.setFormatter(formatter)
# LOG.addHandler(handler)

print("Done importing...")

## CHANGE BELOW ##
# config_path: Path = Path("./examples/llama-2/lora.yml")
config_path: Path = Path("./examples/llama-2/lora-short.yml")

# pretrained_model_dir = "facebook/opt-125m"
# quantized_model_dir = "opt-125m-4bit"
dataset_name = "teknium/GPT4-LLM-Cleaned"
# huggingface_username = "CHANGE_ME"
## CHANGE ABOVE

def main():
print("Starting...")
# return
# prompt = "<|prompt|>How can entrepreneurs start building their own communities even before launching their product?</s><|answer|>"

should_quantize = True
# tokenizer = get_tokenizer()

cfg = load_cfg(config_path)

cfg['lora_model_dir'] = cfg['output_dir']

LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)

if should_quantize:
print("Quantizing...")

print("Loading dataset...")
datasets = load_datasets(cfg=cfg, cli_args=TrainerCliArgs())
train_dataset = datasets.train_dataset
n_samples = 128
# # n_samples = 2
# examples = train_dataset.shuffle(seed=42).select(
# [
# random.randrange(0, len(train_dataset) - 1) # nosec
# for _ in range(n_samples)
# ]
# )

LOG.info("loading model and (optionally) peft_config...")
# model, peft_config = load_model(cfg, tokenizer, inference=True)
model = load_merged_model(cfg)
# model = get_model()

# examples = load_data(dataset_name, tokenizer, n_samples)

# print(examples)
# examples_for_quant = [
# {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
# for example in examples
# ]
# print(examples_for_quant)
examples_for_quant = get_examples_for_quantization(train_dataset, n_samples)

modelq = quantize_and_save(cfg, model, tokenizer, examples_for_quant)
else:
print("Loading quantized model...")
modelq = get_quantized_model(cfg)

push_model(cfg, modelq, tokenizer)

if __name__ == "__main__":
main()


# Load configure
# Load dataset
# Load tokenizer
# Prepare database
# Load previous model, final checkpoint


# --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
# accelerate launch ./scripts/finetune.py ./examples/llama-2/lora.yml --merge_lora --lora_model_dir="./lora-out" --load_in_8bit=False --load_in_4bit=False
# CUDA_VISIBLE_DEVICES="1" accelerate launch ./scripts/finetune.py ./examples/llama-2/lora.yml --merge_lora --lora_model_dir="./lora-out" --load_in_8bit=False --load_in_4bit=False

# HUB_MODEL_ID="Glavin001/llama-2-7b-alpaca_2k_test" accelerate launch ./scripts/quantize.py
@Glavin001 (Contributor Author) commented:

@Glavin001 delete test notes
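The helpers imported from axolotl.utils.quantize (e.g. `get_examples_for_quantization`) are not included in this 4-commit view. Based on the commented-out sampling code in `main()` above, a plausible sketch of that helper is roughly:

```python
import random

from datasets import Dataset


def get_examples_for_quantization(dataset: Dataset, n_samples: int) -> list:
    """Sample n_samples tokenized rows to use as GPTQ calibration examples."""
    indices = [
        random.randrange(0, len(dataset) - 1)  # nosec
        for _ in range(n_samples)
    ]
    examples = dataset.shuffle(seed=42).select(indices)
    # AutoGPTQ's quantize() consumes a list of dicts carrying the tokenized
    # input_ids and attention_mask for each calibration example.
    return [
        {"input_ids": ex["input_ids"], "attention_mask": ex["attention_mask"]}
        for ex in examples
    ]
```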


1 change: 1 addition & 0 deletions src/axolotl/common/cli.py
@@ -25,6 +25,7 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
quantize: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
5 changes: 5 additions & 0 deletions src/axolotl/logging_config.py
@@ -61,6 +61,11 @@ def format(self, record):
"level": "DEBUG",
"propagate": False,
},
"auto_gptq": {
"handlers": ["color_console"],
"level": "DEBUG",
"propagate": False,
},
},
}
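Side note on the AutoGPTQ logging question raised in scripts/quantize.py: records from child loggers such as auto_gptq.modeling._base.quantize propagate up to this "auto_gptq" entry, so with the configuration above the per-layer progress lines should come out through axolotl's color_console handler without extra handlers (assuming configure_logging() is applied before quantization starts). A tiny illustration; the message text is copied from the log sample in scripts/quantize.py:

```python
import logging

# Once configure_logging() has applied the dict config above, AutoGPTQ's own
# loggers inherit the "auto_gptq" entry and emit through color_console.
log = logging.getLogger("auto_gptq.modeling._base.quantize")
log.info("Quantizing self_attn.k_proj in layer 4/32...")
```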

5 changes: 5 additions & 0 deletions src/axolotl/utils/config.py
@@ -82,6 +82,11 @@ def normalize_config(cfg):

log_gpu_memory_usage(LOG, "baseline", cfg.device)

if os.environ.get("WANDB_PROJECT") and len(os.environ.get("WANDB_PROJECT", "")) > 0:
cfg.wandb_project = os.environ.get("WANDB_PROJECT")

if os.environ.get("HUB_MODEL_ID") and len(os.environ.get("HUB_MODEL_ID", "")) > 0:
cfg.hub_model_id = os.environ.get("HUB_MODEL_ID")
@Glavin001 (Contributor Author) commented:

FYI: this is used for upcoming work of starting scripts/finetune.py and having it run without any custom / run-specific / user-specific info in the Axolotl config.

A collaborator commented:

might be better off in the setup_wandb_env_vars function
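Not part of the diff — a rough sketch of what that refactor could look like; the current body of setup_wandb_env_vars is not shown here, so everything around the two new checks is assumed:

```python
import os

from axolotl.utils.dict import DictDefault


def setup_wandb_env_vars(cfg: DictDefault):
    # ... existing wandb environment handling ...

    # Let environment variables override run-specific settings so the YAML
    # config can stay free of user-specific values (see the FYI note above).
    # A truthy check covers both an unset and an empty variable, matching the
    # len(...) > 0 guard in the original change.
    if os.environ.get("WANDB_PROJECT"):
        cfg.wandb_project = os.environ["WANDB_PROJECT"]
    if os.environ.get("HUB_MODEL_ID"):
        cfg.hub_model_id = os.environ["HUB_MODEL_ID"]
```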


def validate_config(cfg):
if cfg.max_packed_sequence_len and cfg.sample_packing: