Add --quantize option to finetune script, fix auto_gptq logging
Glavin001 committed Sep 11, 2023
1 parent 19a30cf commit 894a4be
Showing 6 changed files with 71 additions and 8 deletions.
README.md: 2 changes (1 addition & 1 deletion)
@@ -703,7 +703,7 @@ Pass the appropriate flag to the train command:
Add below flag to train command above

```bash
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
--merge_lora --lora_model_dir="./completed-model"
```

If you run out of CUDA memory, you can try to merge in system RAM with
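For context, a full merge invocation would now look roughly like the sketch below. The base `accelerate launch scripts/finetune.py <config>` command is assumed from the surrounding README (it is not part of this diff); only the trailing flags come from the changed line above.

```bash
# Assumed base command plus the simplified flags from the updated README line;
# the config path is a placeholder.
accelerate launch scripts/finetune.py your_config.yml \
    --merge_lora --lora_model_dir="./completed-model"
```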
scripts/finetune.py: 35 changes (33 additions & 2 deletions)
@@ -1,5 +1,7 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

import gc

import importlib
import logging
import os
@@ -90,7 +92,7 @@ def do_merge_lora(
):
new_cfg = DictDefault({
**cfg,
'lora_model_dir': cfg['output_dir'],
'lora_model_dir': cfg.get('lora_model_dir', cfg['output_dir']),
'load_in_8bit': False,
'load_in_4bit': False,
})
@@ -285,11 +287,24 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.quantize:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

tokenizer = load_tokenizer(parsed_cfg)
# Load merged model with AutoGPTQ
merged_model = load_merged_model(parsed_cfg)

# Quantize & save
n_samples = 128
examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples)
quantize_and_save(parsed_cfg, merged_model, tokenizer, examples)

else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
model, tokenizer = train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
# model, tokenizer = train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
# tokenizer = None
should_quantize = True

@@ -310,8 +325,24 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
# })
# lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
# do_merge_lora(cfg=new_cfg, cli_args=parsed_cli_args)

def log_gpu_memory():
print("GPU Memory:", torch.cuda.memory_allocated())

log_gpu_memory()
print(len(gc.get_referrers(model)))
print(sys.getrefcount(model))

# TODO: release old model from GPU memory
print(gc.collect())
del model
# del tokenizer
print(gc.collect())
torch.cuda.empty_cache()
print(gc.collect())

log_gpu_memory()

do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)

# Load merged model with AutoGPTQ
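The new `elif parsed_cli_args.quantize:` branch loads the datasets and tokenizer, loads the already-merged model through AutoGPTQ, draws 128 calibration examples, and quantizes and saves the result. A hedged sketch of how the flag would be used from the command line, assuming the same `accelerate launch` entry point as the README (the config path is a placeholder):

```bash
# Assumed invocation; only the --quantize flag itself is introduced by this commit.
accelerate launch scripts/finetune.py your_config.yml --quantize
```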
scripts/quantize.py: 28 changes (27 additions & 1 deletion)
@@ -10,6 +10,7 @@
import time
from pathlib import Path
import logging
import re

import torch
from datasets import load_dataset, Dataset
@@ -26,7 +27,32 @@
from axolotl.utils.quantize import load_merged_model, get_quantized_model, quantize_and_save, push_model, get_quantized_model_id, get_quantized_model_dir, get_examples_for_quantization

configure_logging()
LOG = logging.getLogger("axolotl")
LOG = logging.getLogger("axolotl.quantize")

import debugpy
debugpy.listen(('0.0.0.0', 5678))
debugpy.wait_for_client()
debugpy.breakpoint()

class ProgressExtractingHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
progress_info = self.extract_progress(log_entry)
if progress_info:
print(f"Progress: {progress_info}")

@staticmethod
def extract_progress(log_entry):
# [2023-09-11 07:20:37,502] [INFO] [auto_gptq.modeling._base.quantize:364] [PID:3962] [RANK:0] Quantizing self_attn.k_proj in layer 4/32...
match = re.search(r'layer (\d+/\d+)', log_entry)
return match.group(1) if match else None
# [2023-09-11 07:27:52,208] [INFO] [auto_gptq.modeling._utils.pack_model:129] [PID:3962] [RANK:0] model.layers.15.self_attn.o_proj

handler = ProgressExtractingHandler()
# logging.getLogger('auto_gptq.modeling._base.quantize').addHandler(handler)
logger = logging.getLogger('auto_gptq.modeling._base.quantize')
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)

# logging.basicConfig(
# format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.DEBUG, datefmt="%Y-%m-%d %H:%M:%S"
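The handler's pattern can be exercised in isolation against the sample auto_gptq log line quoted in the comment above; a minimal sketch:

```python
import re

# Sample log line copied from the comment in ProgressExtractingHandler above.
log_entry = (
    "[2023-09-11 07:20:37,502] [INFO] [auto_gptq.modeling._base.quantize:364] "
    "[PID:3962] [RANK:0] Quantizing self_attn.k_proj in layer 4/32..."
)

# Same regex as ProgressExtractingHandler.extract_progress
match = re.search(r"layer (\d+/\d+)", log_entry)
print(match.group(1) if match else None)  # -> 4/32
```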
src/axolotl/common/cli.py: 1 change (1 addition & 0 deletions)
@@ -25,6 +25,7 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
quantize: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
src/axolotl/logging_config.py: 5 changes (5 additions & 0 deletions)
@@ -61,6 +61,11 @@ def format(self, record):
"level": "DEBUG",
"propagate": False,
},
"auto_gptq": {
"handlers": ["color_console"],
"level": "DEBUG",
"propagate": False,
},
},
}

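A self-contained sketch of what the new `auto_gptq` logger entry does: records emitted by any `auto_gptq.*` logger are handled by the configured console handler and no longer propagate to the root logger. The formatter and handler below are simplified stand-ins for axolotl's actual `color_console` setup; only the `auto_gptq` entry mirrors the diff.

```python
import logging
from logging.config import dictConfig

dictConfig({
    "version": 1,
    "formatters": {"simple": {"format": "[%(levelname)s] [%(name)s] %(message)s"}},
    "handlers": {
        # Plain StreamHandler standing in for axolotl's color_console handler.
        "color_console": {"class": "logging.StreamHandler", "formatter": "simple"},
    },
    "loggers": {
        # Entry added by this commit: route auto_gptq records to the console handler.
        "auto_gptq": {
            "handlers": ["color_console"],
            "level": "DEBUG",
            "propagate": False,
        },
    },
})

# Child loggers such as auto_gptq.modeling._base inherit the configuration,
# so quantization progress messages now reach the console handler.
logging.getLogger("auto_gptq.modeling._base").info("Quantizing self_attn.k_proj in layer 4/32...")
```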
src/axolotl/utils/quantize.py: 8 changes (4 additions & 4 deletions)
@@ -42,10 +42,9 @@ def load_merged_model(cfg: DictDefault):
# Check if the merged model exists
if not merged_out_dir.exists():
# If not, merge the model
print("Merged model not found. Merging...")
# model, tokenizer = load_model(cfg, inference=True)
raise FileNotFoundError("Merged model not found. Please ensure the model has been merged.")
# do_merge_lora_model_and_tokenizer(cfg=cfg, model=model, tokenizer=tokenizer)
raise NotImplementedError("Merging model is not implemented yet.")
# raise NotImplementedError("Merging model is not implemented yet.")

# load un-quantized model, by default, the model will always be loaded into CPU memory
model = AutoGPTQForCausalLM.from_pretrained(merged_out_dir, quantize_config)
@@ -114,7 +113,8 @@ def get_quantized_model_dir(cfg: DictDefault):
# def get_quantized_model_dir(cfg: DictDefault, quantize_config):
if not cfg.output_dir:
raise ValueError("Missing output_dir in the configuration.")
return f"{cfg.output_dir.lstrip('./')}-GPTQ"
p = Path(cfg.output_dir) / "quantized"
return str(p).lstrip('./')

def get_examples_for_quantization(dataset, n_samples):
print("Loading dataset...")
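The effect of the `get_quantized_model_dir` change can be shown with a small sketch; the `output_dir` value is only an example:

```python
from pathlib import Path

output_dir = "./completed-model"  # example value, not taken from the diff

# Previous behaviour: suffix the output directory name with -GPTQ.
old_dir = f"{output_dir.lstrip('./')}-GPTQ"

# New behaviour: put quantized artifacts in a "quantized" subdirectory.
new_dir = str(Path(output_dir) / "quantized").lstrip('./')

print(old_dir)  # completed-model-GPTQ
print(new_dir)  # completed-model/quantized
```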
