diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000000..3f60f05d36
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,37 @@
+{
+ // Use IntelliSense to learn about possible attributes.
+ // Hover to view descriptions of existing attributes.
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "Python: Remote Attach",
+ "type": "python",
+ "request": "attach",
+ "connect": {
+ "host": "0.0.0.0",
+ "port": 5678
+ },
+ "pathMappings": [
+ {
+ "localRoot": "${workspaceFolder}",
+ "remoteRoot": "/workspace/axolotl/"
+ }
+ ],
+ "justMyCode": false
+ },
+ {
+ "name": "train",
+ "type": "python",
+ "request": "launch",
+ "module": "accelerate.commands.launch",
+ "args": [
+ "${workspaceFolder}/scripts/finetune.py",
+ // "${file}",
+ "${workspaceFolder}/examples/llama-2/tiny-random.yml",
+            ], // other args come after finetune.py
+ "console": "integratedTerminal",
+ // "env": {"CUDA_LAUNCH_BLOCKING": "1"}
+ },
+ ]
+}
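The "Python: Remote Attach" configuration above expects a debugpy server listening inside the container (the docker-compose change below maps port 5678 for this). A minimal sketch of the listener side, mirroring the snippet used in scripts/quantize.py in this diff; the host/port values are assumptions that must match the launch configuration:

```python
# Minimal debugpy listener sketch (assumed to run inside the training container).
# Host/port must match the "Python: Remote Attach" config above and the 5678:5678
# port mapping added in docker-compose.yaml.
import debugpy

debugpy.listen(("0.0.0.0", 5678))   # accept connections from outside the container
print("Waiting for debugger to attach on port 5678...")
debugpy.wait_for_client()           # block until VS Code attaches
debugpy.breakpoint()                # optional: pause here once attached
```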
diff --git a/README.md b/README.md
index 775592efe6..4189165577 100644
--- a/README.md
+++ b/README.md
@@ -703,7 +703,7 @@ Pass the appropriate flag to the train command:
Add below flag to train command above
```bash
---merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
+--merge_lora --lora_model_dir="./completed-model"
```
If you run out of CUDA memory, you can try to merge in system RAM with
diff --git a/docker-compose.yaml b/docker-compose.yaml
index a16be726cf..6708dbf6a0 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -1,9 +1,10 @@
# version: '3.8'
services:
axolotl:
- build:
- context: .
- dockerfile: ./docker/Dockerfile
+ # build:
+ # context: .
+ # dockerfile: ./docker/Dockerfile
+ image: winglian/axolotl:main-py3.10-cu118-2.0.1
volumes:
- .:/workspace/axolotl
- ~/.cache/huggingface/:/root/.cache/huggingface/
@@ -15,6 +16,8 @@ services:
- GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME}
- GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL}
- WANDB_API_KEY=${WANDB_API_KEY}
+ ports:
+ - "5678:5678"
deploy:
resources:
reservations:
diff --git a/examples/llama-2/llama-68.yml b/examples/llama-2/llama-68.yml
new file mode 100644
index 0000000000..85616aba90
--- /dev/null
+++ b/examples/llama-2/llama-68.yml
@@ -0,0 +1,70 @@
+base_model: JackFram/llama-68m
+base_model_config: JackFram/llama-68m
+
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: mhenrichsen/alpaca_2k_test
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./lora-out
+
+# sequence_len: 4096
+sequence_len: 2048
+sample_packing: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 20
+eval_table_size: 5
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
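A hedged usage note: the tiny llama-68m config above is meant to be launched the same way as the other examples, i.e. through the accelerate entry point referenced by the "train" launch.json configuration and the commands noted in scripts/quantize.py. A minimal sketch, with paths that are illustrative only:

```python
# Sketch: run a quick smoke-test with the tiny llama-68m example config.
# Equivalent to: accelerate launch scripts/finetune.py examples/llama-2/llama-68.yml
import subprocess

subprocess.run(
    [
        "accelerate", "launch",
        "scripts/finetune.py",
        "examples/llama-2/llama-68.yml",
    ],
    check=True,
)
```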
diff --git a/examples/llama-2/lora-short.yml b/examples/llama-2/lora-short.yml
new file mode 100644
index 0000000000..bd2b51b962
--- /dev/null
+++ b/examples/llama-2/lora-short.yml
@@ -0,0 +1,70 @@
+base_model: meta-llama/Llama-2-7b-hf
+base_model_config: meta-llama/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: mhenrichsen/alpaca_2k_test
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+# val_set_size: 0.01
+val_set_size: 0.001
+output_dir: ./lora-out
+
+sequence_len: 4096
+sample_packing: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+# num_epochs: 3
+# num_epochs: 1
+num_epochs: 0.1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 20
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml
index a54799b408..2438b0d884 100644
--- a/examples/llama-2/lora.yml
+++ b/examples/llama-2/lora.yml
@@ -27,7 +27,7 @@ lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
-wandb_project:
+wandb_project: test-issue-490-7b-2
wandb_entity:
wandb_watch:
wandb_run_id:
@@ -56,6 +56,8 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
+eval_table_size: 5
+eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml
index dd029859ed..ef20d9fbe3 100644
--- a/examples/llama-2/qlora.yml
+++ b/examples/llama-2/qlora.yml
@@ -58,6 +58,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
+eval_table_size: 5
save_steps:
debug:
deepspeed:
diff --git a/examples/llama-2/tiny-llama.yml b/examples/llama-2/tiny-llama.yml
new file mode 100644
index 0000000000..a53c9c831b
--- /dev/null
+++ b/examples/llama-2/tiny-llama.yml
@@ -0,0 +1,69 @@
+base_model: PY007/TinyLlama-1.1B-step-50K-105b
+base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
+
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: mhenrichsen/alpaca_2k_test
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./lora-out
+
+sequence_len: 4096
+sample_packing: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 20
+eval_table_size: 5
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/examples/llama-2/tiny-puffed-llama.yml b/examples/llama-2/tiny-puffed-llama.yml
new file mode 100644
index 0000000000..ac02b7b27b
--- /dev/null
+++ b/examples/llama-2/tiny-puffed-llama.yml
@@ -0,0 +1,71 @@
+base_model: PY007/TinyLlama-1.1B-step-50K-105b
+base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
+
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+ # - path: mhenrichsen/alpaca_2k_test
+ # type: alpaca
+ - path: LDJnr/Puffin
+ type: sharegpt:chat
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./lora-tiny-puffed-out
+
+sequence_len: 2048
+sample_packing: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 4
+num_epochs: 2
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 20
+eval_table_size: 10
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/scripts/finetune.py b/scripts/finetune.py
index b998edc798..bcf828202c 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -1,5 +1,7 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
+import gc
+
import importlib
import logging
import os
@@ -27,6 +29,7 @@
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb import setup_wandb_env_vars
+from axolotl.utils.quantize import get_examples_for_quantization, load_merged_model, push_model, quantize_and_save
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
@@ -57,27 +60,44 @@ def get_multi_line_input() -> Optional[str]:
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
+def get_merged_out_dir(cfg: DictDefault):
+ return Path(cfg.output_dir) / "merged"
-def do_merge_lora(
+def do_merge_lora_model_and_tokenizer(
*,
cfg: DictDefault,
- cli_args: TrainerCliArgs,
+ model,
+ tokenizer,
):
- model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
+ merged_out_dir = str(get_merged_out_dir(cfg))
+
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
- str(Path(cfg.output_dir) / "merged"),
+ merged_out_dir,
safe_serialization=safe_serialization,
)
- tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+ tokenizer.save_pretrained(merged_out_dir)
+def do_merge_lora(
+ *,
+ cfg: DictDefault,
+ cli_args: TrainerCliArgs,
+):
+ new_cfg = DictDefault({
+ **cfg,
+ 'lora_model_dir': cfg.get('lora_model_dir', cfg['output_dir']),
+ 'load_in_8bit': False,
+ 'load_in_4bit': False,
+ })
+ model, tokenizer = load_model_and_tokenizer(cfg=new_cfg, cli_args=cli_args)
+ do_merge_lora_model_and_tokenizer(cfg=new_cfg, model=model, tokenizer=tokenizer)
def shard(
*,
@@ -135,6 +155,9 @@ def do_inference(
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
+ print(prompt)
+ print("=" * 20)
+
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
@@ -267,11 +290,74 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
+ elif parsed_cli_args.quantize:
+ dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+ tokenizer = load_tokenizer(parsed_cfg)
+ # Load merged model with AutoGPTQ
+ merged_model = load_merged_model(parsed_cfg)
+
+ # Quantize & save
+ n_samples = 128
+ examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples)
+ quantize_and_save(parsed_cfg, merged_model, tokenizer, examples)
+ elif parsed_cli_args.push:
+ model, tokenizer = load_model_and_tokenizer(cfg=parsed_cfg, cli_args=parsed_cli_args)
+ push_model(cfg=parsed_cfg, model=model, tokenizer=tokenizer)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
+ # model, tokenizer = train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
+ # tokenizer = None
+ should_quantize = False
+
+ if should_quantize:
+ # Merge model
+ # do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
+ # do_merge_lora_model_and_tokenizer(cfg=parsed_cfg, model=model, tokenizer=tokenizer)
+ # new_cfg = parsed_cfg.copy()
+ # new_cfg['lora_model_dir'] = new_cfg['output_dir']
+ # new_cfg['load_in_8bit'] = False
+ # new_cfg['load_in_4bit'] = False
+
+ # new_cfg = DictDefault({
+ # **parsed_cfg,
+ # 'lora_model_dir': parsed_cfg['output_dir'],
+ # 'load_in_8bit': False,
+ # 'load_in_4bit': False,
+ # })
+ # lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
+ # do_merge_lora(cfg=new_cfg, cli_args=parsed_cli_args)
+
+ def log_gpu_memory():
+ print("GPU Memory:", torch.cuda.memory_allocated())
+
+ log_gpu_memory()
+ print(len(gc.get_referrers(model)))
+ print(sys.getrefcount(model))
+
+ # TODO: release old model from GPU memory
+ print(gc.collect())
+ del model
+ # del tokenizer
+ print(gc.collect())
+ torch.cuda.empty_cache()
+ print(gc.collect())
+
+ log_gpu_memory()
+
+ do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+ # Load merged model with AutoGPTQ
+ merged_model = load_merged_model(parsed_cfg)
+
+ # Quantize & save
+ n_samples = 128
+ examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples)
+ quantize_and_save(parsed_cfg, merged_model, tokenizer, examples)
+
if __name__ == "__main__":
diff --git a/scripts/quantize.py b/scripts/quantize.py
new file mode 100644
index 0000000000..10bddf6a7c
--- /dev/null
+++ b/scripts/quantize.py
@@ -0,0 +1,147 @@
+# pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
+
+# import debugpy
+# debugpy.listen(('0.0.0.0', 5678))
+# debugpy.wait_for_client()
+# debugpy.breakpoint()
+
+import json
+import random
+import time
+from pathlib import Path
+import logging
+import re
+
+import torch
+from datasets import load_dataset, Dataset
+from transformers import AutoTokenizer, LlamaTokenizer, TextGenerationPipeline
+from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+from axolotl.prompters import AlpacaPrompter
+from axolotl.utils.models import load_model, load_tokenizer
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.logging_config import configure_logging
+from axolotl.utils.dict import DictDefault
+# from scripts.finetune import load_cfg
+from finetune import load_cfg, get_merged_out_dir, do_merge_lora_model_and_tokenizer, load_datasets
+
+from axolotl.utils.quantize import load_merged_model, get_quantized_model, quantize_and_save, push_model, get_quantized_model_id, get_quantized_model_dir, get_examples_for_quantization
+
+configure_logging()
+LOG = logging.getLogger("axolotl.quantize")
+
+import debugpy
+debugpy.listen(('0.0.0.0', 5678))
+debugpy.wait_for_client()
+debugpy.breakpoint()
+
+class ProgressExtractingHandler(logging.Handler):
+ def emit(self, record):
+ log_entry = self.format(record)
+ progress_info = self.extract_progress(log_entry)
+ if progress_info:
+ print(f"Progress: {progress_info}")
+
+ @staticmethod
+ def extract_progress(log_entry):
+        # Example log line: [2023-09-11 07:20:37,502] [INFO] [auto_gptq.modeling._base.quantize:364] [PID:3962] [RANK:0] Quantizing self_attn.k_proj in layer 4/32...
+ match = re.search(r'layer (\d+/\d+)', log_entry)
+ return match.group(1) if match else None
+ # [2023-09-11 07:27:52,208] [INFO] [auto_gptq.modeling._utils.pack_model:129] [PID:3962] [RANK:0] model.layers.15.self_attn.o_proj
+
+handler = ProgressExtractingHandler()
+# logging.getLogger('auto_gptq.modeling._base.quantize').addHandler(handler)
+logger = logging.getLogger('auto_gptq.modeling._base.quantize')
+logger.setLevel(logging.DEBUG)
+logger.addHandler(handler)
+
+# logging.basicConfig(
+# format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.DEBUG, datefmt="%Y-%m-%d %H:%M:%S"
+# )
+
+# LOG.setLevel(logging.DEBUG)
+# handler = logging.StreamHandler()
+# formatter = logging.Formatter('%(asctime)s %(levelname)s [%(name)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+# handler.setFormatter(formatter)
+# LOG.addHandler(handler)
+
+print("Done importing...")
+
+## CHANGE BELOW ##
+# config_path: Path = Path("./examples/llama-2/lora.yml")
+config_path: Path = Path("./examples/llama-2/lora-short.yml")
+
+# pretrained_model_dir = "facebook/opt-125m"
+# quantized_model_dir = "opt-125m-4bit"
+dataset_name = "teknium/GPT4-LLM-Cleaned"
+# huggingface_username = "CHANGE_ME"
+## CHANGE ABOVE ##
+
+def main():
+ print("Starting...")
+ # return
+ # prompt = "<|prompt|>How can entrepreneurs start building their own communities even before launching their product?<|answer|>"
+
+ should_quantize = True
+ # tokenizer = get_tokenizer()
+
+ cfg = load_cfg(config_path)
+
+ cfg['lora_model_dir'] = cfg['output_dir']
+
+ LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
+ tokenizer = load_tokenizer(cfg)
+
+ if should_quantize:
+ print("Quantizing...")
+
+ print("Loading dataset...")
+ datasets = load_datasets(cfg=cfg, cli_args=TrainerCliArgs())
+ train_dataset = datasets.train_dataset
+ n_samples = 128
+ # # n_samples = 2
+ # examples = train_dataset.shuffle(seed=42).select(
+ # [
+ # random.randrange(0, len(train_dataset) - 1) # nosec
+ # for _ in range(n_samples)
+ # ]
+ # )
+
+ LOG.info("loading model and (optionally) peft_config...")
+ # model, peft_config = load_model(cfg, tokenizer, inference=True)
+ model = load_merged_model(cfg)
+ # model = get_model()
+
+ # examples = load_data(dataset_name, tokenizer, n_samples)
+
+ # print(examples)
+ # examples_for_quant = [
+ # {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
+ # for example in examples
+ # ]
+ # print(examples_for_quant)
+ examples_for_quant = get_examples_for_quantization(train_dataset, n_samples)
+
+ modelq = quantize_and_save(cfg, model, tokenizer, examples_for_quant)
+ else:
+ print("Loading quantized model...")
+ modelq = get_quantized_model(cfg)
+
+ push_model(cfg, modelq, tokenizer)
+
+if __name__ == "__main__":
+ main()
+
+
+# Load configure
+# Load dataset
+# Load tokenizer
+# Prepare database
+# Load previous model, final checkpoint
+
+
+# --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
+# accelerate launch ./scripts/finetune.py ./examples/llama-2/lora.yml --merge_lora --lora_model_dir="./lora-out" --load_in_8bit=False --load_in_4bit=False
+# CUDA_VISIBLE_DEVICES="1" accelerate launch ./scripts/finetune.py ./examples/llama-2/lora.yml --merge_lora --lora_model_dir="./lora-out" --load_in_8bit=False --load_in_4bit=False
+
+# HUB_MODEL_ID="Glavin001/llama-2-7b-alpaca_2k_test" accelerate launch ./scripts/quantize.py
+
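A quick sanity check (a sketch, not part of the diff) for the regex used by ProgressExtractingHandler above, run against the sample AutoGPTQ log line quoted in its comments:

```python
import re

sample = (
    "[2023-09-11 07:20:37,502] [INFO] [auto_gptq.modeling._base.quantize:364] "
    "[PID:3962] [RANK:0] Quantizing self_attn.k_proj in layer 4/32..."
)
match = re.search(r"layer (\d+/\d+)", sample)
print(match.group(1) if match else None)  # -> 4/32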
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index 62f2b1061a..ef439d3cf0 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ -9,6 +9,7 @@
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
+from axolotl.utils.quantize import get_quantized_model
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
@@ -25,6 +26,8 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
+ quantize: bool = field(default=False)
+ push: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
@@ -38,6 +41,8 @@ def load_model_and_tokenizer(
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model and (optionally) peft_config...")
- model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+ # model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+ # TEMP
+ model = get_quantized_model(cfg)
return model, tokenizer
diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py
index 8f473aa240..41a8e41103 100644
--- a/src/axolotl/logging_config.py
+++ b/src/axolotl/logging_config.py
@@ -61,6 +61,11 @@ def format(self, record):
"level": "DEBUG",
"propagate": False,
},
+ "auto_gptq": {
+ "handlers": ["color_console"],
+ "level": "DEBUG",
+ "propagate": False,
+ },
},
}
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index ef048082c1..d172d302d9 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -193,7 +193,7 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
- if cu_seqlens is not None and max_seqlen is not None:
+ if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
@@ -261,6 +261,8 @@ def flashattn_forward(
if attention_mask is not None
else None,
)
+ if q_unpad.dtype != kv_unpad.dtype:
+ kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 3f776537a5..3925c20d8c 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -15,6 +15,8 @@
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
+ GenerationConfig,
+ Trainer,
TrainerCallback,
TrainerControl,
TrainerState,
@@ -22,6 +24,7 @@
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
+import wandb
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
barrier,
@@ -323,3 +326,177 @@ def on_evaluate(
metrics[key] = val
return BenchEvalCallback
+
+
+def log_prediction_callback_factory(trainer: Trainer, tokenizer):
+ class LogPredictionCallback(TrainerCallback):
+ """Callback to log prediction values during each evaluation"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.logged = False
+
+ def on_evaluate(
+ self,
+ args: AxolotlTrainingArguments,
+ state: TrainerState,
+ control: TrainerControl,
+ train_dataloader,
+ eval_dataloader,
+ **kwargs,
+ ):
+ eval_table_size = self.cfg.eval_table_size
+
+ if eval_table_size <= 0:
+ return control
+
+ trainer.model.eval()
+ device = torch.device(self.cfg.device)
+
+ generation_config = GenerationConfig(
+ max_new_tokens=self.cfg.eval_table_max_new_tokens,
+ bos_token_id=tokenizer.bos_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id,
+ do_sample=False,
+ use_cache=True,
+ return_dict_in_generate=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ output_scores=False,
+ )
+
+        def logits_to_tokens(logits) -> torch.Tensor:
+ probabilities = torch.softmax(logits, dim=-1)
+ # Get the predicted token ids (the ones with the highest probability)
+ predicted_token_ids = torch.argmax(probabilities, dim=-1)
+ return predicted_token_ids
+
+ def find_ranges(lst):
+ ranges = []
+ start = 0
+ for i in range(1, len(lst)):
+ if lst[i] == 0:
+ ranges.append((start, i - 1))
+ start = i
+ end = len(lst) - 1
+ ranges.append((start, end))
+ return ranges
+
+ def log_table_from_dataloader(name: str, table_dataloader):
+ table = wandb.Table(
+ columns=[
+ "id",
+ "Prompt",
+ "Correct Completion",
+ "Predicted Completion (model.generate)",
+ "Predicted Completion (trainer.prediction_step)",
+ ]
+ )
+ row_index = 0
+
+ for batch in tqdm(table_dataloader):
+
+ if row_index > eval_table_size:
+ break
+
+ batch_labels = batch["labels"].to(device)
+ batch_input_ids = batch["input_ids"].to(device)
+
+ if "position_ids" in batch:
+ batch_pos_ids = batch["position_ids"].tolist()
+ else:
+ batch_pos_ids = [None] * len(batch["input_ids"])
+
+ (_, batch_logits, _) = trainer.prediction_step(
+ trainer.model,
+ batch,
+ prediction_loss_only=False,
+ )
+
+ prompt_token_ids_list = []
+ pred_step_token_ids_list = []
+ completion_token_ids_list = []
+
+ for input_ids_all, labels_all, pos_ids, logits in zip(
+ batch_input_ids, batch_labels, batch_pos_ids, batch_logits,
+ ):
+ if pos_ids is None:
+ pos_ranges = [(0, len(input_ids_all) - 1)]
+ else:
+ pos_ranges = find_ranges(pos_ids)
+
+ for pos_range in pos_ranges:
+ start, end = pos_range
+ if start == end:
+ continue
+
+ input_ids = input_ids_all[start : end + 1]
+ labels = labels_all[start : end + 1]
+
+ tokens_without_loss = labels == IGNORE_INDEX
+ tokens_with_loss = labels != IGNORE_INDEX
+ tokens_exclude_padding = input_ids != tokenizer.pad_token_id
+ prompt_token_includes = (
+ tokens_without_loss & tokens_exclude_padding
+ )
+
+ prompt_token_ids = input_ids[prompt_token_includes]
+ prompt_token_ids_list.append(prompt_token_ids)
+
+ completion_token_ids = input_ids[tokens_with_loss]
+ completion_token_ids_list.append(completion_token_ids)
+
+ pred_step_token_ids = logits_to_tokens(logits[start : end + 1])[tokens_with_loss]
+ pred_step_token_ids_list.append(pred_step_token_ids)
+
+ prompt_texts = tokenizer.batch_decode(
+ prompt_token_ids_list, skip_special_tokens=True
+ )
+ completion_texts = tokenizer.batch_decode(
+ completion_token_ids_list, skip_special_tokens=True
+ )
+ pred_step_texts = tokenizer.batch_decode(
+ pred_step_token_ids_list, skip_special_tokens=True
+ )
+
+ with torch.no_grad():
+ prompt_encoding = tokenizer(
+ prompt_texts, padding=True, return_tensors="pt"
+ ).to(self.cfg.device)
+ predictions = trainer.model.generate(
+ **prompt_encoding, generation_config=generation_config
+ )
+
+ prediction_all_tokens = predictions["sequences"].cpu().tolist()
+ prediction_without_prompt_tokens_list = []
+ for prompt_token_ids, prediction_tokens in zip(
+ prompt_token_ids_list, prediction_all_tokens
+ ):
+ prediction_without_prompt_tokens = prediction_tokens[
+ len(prompt_token_ids) :
+ ]
+ prediction_without_prompt_tokens_list.append(
+ prediction_without_prompt_tokens
+ )
+
+ predicted_texts = tokenizer.batch_decode(
+ prediction_without_prompt_tokens_list, skip_special_tokens=True
+ )
+
+ for prompt_text, completion_text, prediction_text, pred_step_text in zip(
+ prompt_texts, completion_texts, predicted_texts, pred_step_texts
+ ):
+ table.add_data(
+ row_index, prompt_text, completion_text, prediction_text, pred_step_text
+ )
+ row_index += 1
+
+ wandb.run.log({f"{name} - Predictions vs Ground Truth": table})
+
+ if is_main_process():
+ log_table_from_dataloader("Eval", eval_dataloader)
+
+ return control
+
+ return LogPredictionCallback
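A standalone illustration (not part of the diff) of the find_ranges helper used by the callback above: with sample packing, position_ids restart at 0 at the start of each packed sequence, so each (start, end) pair marks one original example within the packed row.

```python
def find_ranges(lst):
    # same logic as the helper inside log_table_from_dataloader above
    ranges = []
    start = 0
    for i in range(1, len(lst)):
        if lst[i] == 0:
            ranges.append((start, i - 1))
            start = i
    ranges.append((start, len(lst) - 1))
    return ranges

print(find_ranges([0, 1, 2, 0, 1, 0, 1, 2, 3]))  # -> [(0, 2), (3, 4), (5, 8)]
```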
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 6de807eab9..a4e68869d8 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -48,6 +48,8 @@ def normalize_config(cfg):
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+ cfg.eval_table_size = cfg.eval_table_size or 0
+ cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
@@ -82,6 +84,11 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device)
+ if os.environ.get("WANDB_PROJECT") and len(os.environ.get("WANDB_PROJECT", "")) > 0:
+ cfg.wandb_project = os.environ.get("WANDB_PROJECT")
+
+ if os.environ.get("HUB_MODEL_ID") and len(os.environ.get("HUB_MODEL_ID", "")) > 0:
+ cfg.hub_model_id = os.environ.get("HUB_MODEL_ID")
def validate_config(cfg):
if cfg.max_packed_sequence_len and cfg.sample_packing:
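A minimal standalone sketch (not calling axolotl's full normalize_config, which needs a complete config) of the override behaviour added above: non-empty WANDB_PROJECT / HUB_MODEL_ID environment variables take precedence over the YAML values, which is what makes invocations like the HUB_MODEL_ID=... command noted in scripts/quantize.py work.

```python
import os

from axolotl.utils.dict import DictDefault

cfg = DictDefault({"wandb_project": "from-yaml", "hub_model_id": None})

# mirrors the two new blocks in normalize_config
if os.environ.get("WANDB_PROJECT"):
    cfg.wandb_project = os.environ["WANDB_PROJECT"]
if os.environ.get("HUB_MODEL_ID"):
    cfg.hub_model_id = os.environ["HUB_MODEL_ID"]

print(cfg.wandb_project, cfg.hub_model_id)
```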
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2000b1aee8..eae08f7f70 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -296,10 +296,10 @@ def load_model(
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
- and cfg.sequence_len >= model.config.max_position_embeddings
+ and cfg.sequence_len > model.config.max_position_embeddings
):
LOG.warning(
- f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
+ f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
)
model.config.max_position_embeddings = cfg.sequence_len
diff --git a/src/axolotl/utils/quantize.py b/src/axolotl/utils/quantize.py
new file mode 100644
index 0000000000..39318836a9
--- /dev/null
+++ b/src/axolotl/utils/quantize.py
@@ -0,0 +1,139 @@
+# pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
+
+# import debugpy
+# debugpy.listen(('0.0.0.0', 5678))
+# debugpy.wait_for_client()
+# debugpy.breakpoint()
+
+import json
+import random
+import time
+from pathlib import Path
+import logging
+
+# import torch
+# from datasets import load_dataset, Dataset
+# from transformers import AutoTokenizer, LlamaTokenizer, TextGenerationPipeline
+from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+# from axolotl.prompters import AlpacaPrompter
+# from axolotl.utils.models import load_model, load_tokenizer
+# from axolotl.common.cli import TrainerCliArgs
+# from axolotl.logging_config import configure_logging
+from axolotl.utils.dict import DictDefault
+# from finetune import load_cfg, get_merged_out_dir, do_merge_lora_model_and_tokenizer
+
+# configure_logging()
+# LOG = logging.getLogger("axolotl")
+
+quantize_config = BaseQuantizeConfig(
+ bits=4, # quantize model to 4-bit
+ group_size=128, # it is recommended to set the value to 128
+    desc_act=False,  # setting to False can significantly speed up inference, at the cost of slightly worse perplexity
+)
+
+def get_merged_out_dir(cfg: DictDefault):
+ return Path(cfg.output_dir) / "merged"
+
+def load_merged_model(cfg: DictDefault):
+ print("Loading merged model...")
+
+ merged_out_dir = get_merged_out_dir(cfg)
+
+ # Check if the merged model exists
+ if not merged_out_dir.exists():
+ # If not, merge the model
+ raise FileNotFoundError("Merged model not found. Please ensure the model has been merged.")
+ # do_merge_lora_model_and_tokenizer(cfg=cfg, model=model, tokenizer=tokenizer)
+ # raise NotImplementedError("Merging model is not implemented yet.")
+
+    # load the un-quantized model; by default it is loaded into CPU memory
+ model = AutoGPTQForCausalLM.from_pretrained(merged_out_dir, quantize_config)
+ # model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
+ print("Model loaded.")
+ return model
+
+def get_quantized_model(cfg: DictDefault):
+ print("Loading quantized model...")
+ quantized_model_dir = get_quantized_model_dir(cfg)
+ model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
+ device="cuda:0",
+ use_safetensors=True,
+ inject_fused_attention=False, # WORKAROUND for https://github.com/PanQiWei/AutoGPTQ/issues/210
+ )
+ print("Model loaded.")
+ return model
+
+def quantize_and_save(cfg: DictDefault, model, tokenizer, examples_for_quant):
+ print("Quantize...")
+ start = time.time()
+    # quantize the model; examples should be a list of dicts whose only keys are "input_ids" and "attention_mask"
+ model.quantize(
+ examples_for_quant,
+ batch_size=1,
+ # batch_size=args.quant_batch_size,
+ # use_triton=args.use_triton,
+ # autotune_warmup_after_quantized=args.use_triton
+ )
+ end = time.time()
+ print(f"quantization took: {end - start: .4f}s")
+
+ # save quantized model
+ print("Saving quantized model...")
+ # model.save_quantized(quantized_model_dir)
+ quantized_model_dir = get_quantized_model_dir(cfg)
+ model.save_quantized(quantized_model_dir, use_safetensors=True)
+ print("Saving tokenizer...")
+ tokenizer.save_pretrained(quantized_model_dir)
+ print("Saved.")
+
+ # FIXME: Add fix to config.json
+ # "error": "handler: 'pad_token_id' \ntraceback: Traceback (most recent call last):\n File \"/usr/local/lib/python3.10/dist-packages/runpod/serverless/modules/job.py\", line 141, in run_job_generator\n for output_partial in job_output:\n File \"/data/handler.py\", line 107, in inference\n generator, default_settings = load_model()\n File \"/data/handler.py\", line 45, in load_model\n config = ExLlamaConfig(model_config_path) # create config from config.json\n File \"/data/exllama/model.py\", line 52, in __init__\n self.pad_token_id = read_config[\"pad_token_id\"]\nKeyError: 'pad_token_id'\n"
+
+ return model
+
+def push_model(cfg: DictDefault, model, tokenizer):
+# def push_model(model):
+ # push quantized model to Hugging Face Hub.
+    # to use use_auth_token=True, log in first via `huggingface-cli login`,
+    # or pass an explicit token with: use_auth_token="hf_xxxxxxx"
+ # (uncomment the following three lines to enable this feature)
+ # repo_id = f"YourUserName/{quantized_model_dir}"
+ print("Pushing to Huggingface hub...")
+ # repo_id = f"{huggingface_username}/{quantized_model_dir}"
+ repo_id = get_quantized_model_id(cfg)
+ pretrained_model_dir = cfg['base_model']
+ commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
+ # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, use_safetensors=True, safe_serialization=True)
+ # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, safe_serialization=True)
+ model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, use_safetensors=True)
+ tokenizer.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True)
+ print("Pushed.")
+
+def get_quantized_model_id(cfg: DictDefault):
+# def get_quantized_model_id(cfg: DictDefault, quantize_config):
+ # return f"{cfg.hub_model_id}-{quantize_config.bits}bits-gr{quantize_config.group_size}-desc_act{quantize_config.desc_act}"
+ if not cfg.hub_model_id:
+ raise ValueError("Missing hub_model_id in the configuration.")
+ return f"{cfg.hub_model_id}-GPTQ"
+
+def get_quantized_model_dir(cfg: DictDefault):
+# def get_quantized_model_dir(cfg: DictDefault, quantize_config):
+ if not cfg.output_dir:
+ raise ValueError("Missing output_dir in the configuration.")
+ p = Path(cfg.output_dir) / "quantized"
+    # strip a leading "./" so e.g. "./lora-out/quantized" becomes "lora-out/quantized"
+    return str(p).removeprefix("./")
+
+def get_examples_for_quantization(dataset, n_samples):
+ print("Loading dataset...")
+ examples = dataset.shuffle(seed=42).select(
+ [
+ random.randrange(0, len(dataset) - 1) # nosec
+ for _ in range(n_samples)
+ ]
+ )
+
+ examples_for_quant = [
+ {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
+ for example in examples
+ ]
+ return examples_for_quant
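An end-to-end usage sketch for the utilities above, assuming training and the LoRA merge have already produced <output_dir>/merged and that the config defines hub_model_id. It mirrors the flow in scripts/quantize.py and the --quantize branch of scripts/finetune.py (including the import from finetune, which assumes running from the scripts directory) rather than being a new API.

```python
from pathlib import Path

from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.models import load_tokenizer
from axolotl.utils.quantize import (
    get_examples_for_quantization,
    load_merged_model,
    push_model,
    quantize_and_save,
)

# load_cfg / load_datasets live in scripts/finetune.py, as used by scripts/quantize.py
from finetune import load_cfg, load_datasets

cfg = load_cfg(Path("examples/llama-2/lora.yml"))  # illustrative config path

tokenizer = load_tokenizer(cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=TrainerCliArgs())
examples = get_examples_for_quantization(dataset_meta.train_dataset, 128)

model = load_merged_model(cfg)                     # AutoGPTQ wrapper over the merged FP16 model
modelq = quantize_and_save(cfg, model, tokenizer, examples)
push_model(cfg, modelq, tokenizer)                 # optional: upload the -GPTQ repo to the Hub
```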
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index ece1bd9b69..d959c896f1 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -30,6 +30,7 @@
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
bench_eval_callback_factory,
+ log_prediction_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -703,6 +704,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
**trainer_kwargs,
)
+ if cfg.use_wandb and cfg.eval_table_size > 0:
+ LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
+ trainer.add_callback(LogPredictionCallback(cfg))
+
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))