diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..3f60f05d36 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,37 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Remote Attach", + "type": "python", + "request": "attach", + "connect": { + "host": "0.0.0.0", + "port": 5678 + }, + "pathMappings": [ + { + "localRoot": "${workspaceFolder}", + "remoteRoot": "/workspace/axolotl/" + } + ], + "justMyCode": false + }, + { + "name": "train", + "type": "python", + "request": "launch", + "module": "accelerate.commands.launch", + "args": [ + "${workspaceFolder}/scripts/finetune.py", + // "${file}", + "${workspaceFolder}/examples/llama-2/tiny-random.yml", + ], // other args comes after train.py + "console": "integratedTerminal", + // "env": {"CUDA_LAUNCH_BLOCKING": "1"} + }, + ] +} diff --git a/README.md b/README.md index 775592efe6..4189165577 100644 --- a/README.md +++ b/README.md @@ -703,7 +703,7 @@ Pass the appropriate flag to the train command: Add below flag to train command above ```bash ---merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False +--merge_lora --lora_model_dir="./completed-model" ``` If you run out of CUDA memory, you can try to merge in system RAM with diff --git a/docker-compose.yaml b/docker-compose.yaml index a16be726cf..6708dbf6a0 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,9 +1,10 @@ # version: '3.8' services: axolotl: - build: - context: . - dockerfile: ./docker/Dockerfile + # build: + # context: . + # dockerfile: ./docker/Dockerfile + image: winglian/axolotl:main-py3.10-cu118-2.0.1 volumes: - .:/workspace/axolotl - ~/.cache/huggingface/:/root/.cache/huggingface/ @@ -15,6 +16,8 @@ services: - GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME} - GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL} - WANDB_API_KEY=${WANDB_API_KEY} + ports: + - "5678:5678" deploy: resources: reservations: diff --git a/examples/llama-2/llama-68.yml b/examples/llama-2/llama-68.yml new file mode 100644 index 0000000000..85616aba90 --- /dev/null +++ b/examples/llama-2/llama-68.yml @@ -0,0 +1,70 @@ +base_model: JackFram/llama-68m +base_model_config: JackFram/llama-68m + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./lora-out + +# sequence_len: 4096 +sequence_len: 2048 +sample_packing: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 3 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 20 +eval_table_size: 5 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: 
"" diff --git a/examples/llama-2/lora-short.yml b/examples/llama-2/lora-short.yml new file mode 100644 index 0000000000..bd2b51b962 --- /dev/null +++ b/examples/llama-2/lora-short.yml @@ -0,0 +1,70 @@ +base_model: meta-llama/Llama-2-7b-hf +base_model_config: meta-llama/Llama-2-7b-hf +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: last_run_prepared +# val_set_size: 0.01 +val_set_size: 0.001 +output_dir: ./lora-out + +sequence_len: 4096 +sample_packing: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +# num_epochs: 3 +# num_epochs: 1 +num_epochs: 0.1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 20 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index a54799b408..2438b0d884 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -27,7 +27,7 @@ lora_dropout: 0.05 lora_target_linear: true lora_fan_in_fan_out: -wandb_project: +wandb_project: test-issue-490-7b-2 wandb_entity: wandb_watch: wandb_run_id: @@ -56,6 +56,8 @@ flash_attention: true warmup_steps: 10 eval_steps: 20 +eval_table_size: 5 +eval_table_max_new_tokens: 128 save_steps: debug: deepspeed: diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index dd029859ed..ef20d9fbe3 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -58,6 +58,7 @@ flash_attention: true warmup_steps: 10 eval_steps: 20 +eval_table_size: 5 save_steps: debug: deepspeed: diff --git a/examples/llama-2/tiny-llama.yml b/examples/llama-2/tiny-llama.yml new file mode 100644 index 0000000000..a53c9c831b --- /dev/null +++ b/examples/llama-2/tiny-llama.yml @@ -0,0 +1,69 @@ +base_model: PY007/TinyLlama-1.1B-step-50K-105b +base_model_config: PY007/TinyLlama-1.1B-step-50K-105b + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./lora-out + +sequence_len: 4096 +sample_packing: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 3 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 20 +eval_table_size: 5 
+save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/examples/llama-2/tiny-puffed-llama.yml b/examples/llama-2/tiny-puffed-llama.yml new file mode 100644 index 0000000000..ac02b7b27b --- /dev/null +++ b/examples/llama-2/tiny-puffed-llama.yml @@ -0,0 +1,71 @@ +base_model: PY007/TinyLlama-1.1B-step-50K-105b +base_model_config: PY007/TinyLlama-1.1B-step-50K-105b + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + # - path: mhenrichsen/alpaca_2k_test + # type: alpaca + - path: LDJnr/Puffin + type: sharegpt:chat +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./lora-tiny-puffed-out + +sequence_len: 2048 +sample_packing: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 4 +num_epochs: 2 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 20 +eval_table_size: 10 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/scripts/finetune.py b/scripts/finetune.py index b998edc798..bcf828202c 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -1,5 +1,7 @@ """Prepare and train a model on a dataset. 
Can also infer from a model or merge lora""" +import gc + import importlib import logging import os @@ -27,6 +29,7 @@ from axolotl.utils.models import load_tokenizer from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.wandb import setup_wandb_env_vars +from axolotl.utils.quantize import get_examples_for_quantization, load_merged_model, push_model, quantize_and_save project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") @@ -57,27 +60,44 @@ def get_multi_line_input() -> Optional[str]: # instruction = pathlib.Path("/proc/self/fd/0").read_text() return instruction +def get_merged_out_dir(cfg: DictDefault): + return Path(cfg.output_dir) / "merged" -def do_merge_lora( +def do_merge_lora_model_and_tokenizer( *, cfg: DictDefault, - cli_args: TrainerCliArgs, + model, + tokenizer, ): - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) safe_serialization = cfg.save_safetensors is True LOG.info("running merge of LoRA with base model") model = model.merge_and_unload() model.to(dtype=torch.float16) + merged_out_dir = str(get_merged_out_dir(cfg)) + if cfg.local_rank == 0: LOG.info("saving merged model") model.save_pretrained( - str(Path(cfg.output_dir) / "merged"), + merged_out_dir, safe_serialization=safe_serialization, ) - tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) + tokenizer.save_pretrained(merged_out_dir) +def do_merge_lora( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, +): + new_cfg = DictDefault({ + **cfg, + 'lora_model_dir': cfg.get('lora_model_dir', cfg['output_dir']), + 'load_in_8bit': False, + 'load_in_4bit': False, + }) + model, tokenizer = load_model_and_tokenizer(cfg=new_cfg, cli_args=cli_args) + do_merge_lora_model_and_tokenizer(cfg=new_cfg, model=model, tokenizer=tokenizer) def shard( *, @@ -135,6 +155,9 @@ def do_inference( batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) print("=" * 40) + print(prompt) + print("=" * 20) + model.eval() with torch.no_grad(): generation_config = GenerationConfig( @@ -267,11 +290,74 @@ def do_cli(config: Path = Path("examples/"), **kwargs): do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) elif parsed_cli_args.shard: shard(cfg=parsed_cfg, cli_args=parsed_cli_args) + elif parsed_cli_args.quantize: + dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + + tokenizer = load_tokenizer(parsed_cfg) + # Load merged model with AutoGPTQ + merged_model = load_merged_model(parsed_cfg) + + # Quantize & save + n_samples = 128 + examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples) + quantize_and_save(parsed_cfg, merged_model, tokenizer, examples) + elif parsed_cli_args.push: + model, tokenizer = load_model_and_tokenizer(cfg=parsed_cfg, cli_args=parsed_cli_args) + push_model(cfg=parsed_cfg, model=model, tokenizer=tokenizer) else: dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) if parsed_cli_args.prepare_ds_only: return + # model, tokenizer = train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) + # tokenizer = None + should_quantize = False + + if should_quantize: + # Merge model + # do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) + # do_merge_lora_model_and_tokenizer(cfg=parsed_cfg, model=model, tokenizer=tokenizer) + # new_cfg = parsed_cfg.copy() + # new_cfg['lora_model_dir'] = new_cfg['output_dir'] + # new_cfg['load_in_8bit'] = 
False + # new_cfg['load_in_4bit'] = False + + # new_cfg = DictDefault({ + # **parsed_cfg, + # 'lora_model_dir': parsed_cfg['output_dir'], + # 'load_in_8bit': False, + # 'load_in_4bit': False, + # }) + # lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False + # do_merge_lora(cfg=new_cfg, cli_args=parsed_cli_args) + + def log_gpu_memory(): + print("GPU Memory:", torch.cuda.memory_allocated()) + + log_gpu_memory() + print(len(gc.get_referrers(model))) + print(sys.getrefcount(model)) + + # TODO: release old model from GPU memory + print(gc.collect()) + del model + # del tokenizer + print(gc.collect()) + torch.cuda.empty_cache() + print(gc.collect()) + + log_gpu_memory() + + do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) + + # Load merged model with AutoGPTQ + merged_model = load_merged_model(parsed_cfg) + + # Quantize & save + n_samples = 128 + examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples) + quantize_and_save(parsed_cfg, merged_model, tokenizer, examples) + if __name__ == "__main__": diff --git a/scripts/quantize.py b/scripts/quantize.py new file mode 100644 index 0000000000..10bddf6a7c --- /dev/null +++ b/scripts/quantize.py @@ -0,0 +1,147 @@ +# pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ + +# import debugpy +# debugpy.listen(('0.0.0.0', 5678)) +# debugpy.wait_for_client() +# debugpy.breakpoint() + +import json +import random +import time +from pathlib import Path +import logging +import re + +import torch +from datasets import load_dataset, Dataset +from transformers import AutoTokenizer, LlamaTokenizer, TextGenerationPipeline +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +from axolotl.prompters import AlpacaPrompter +from axolotl.utils.models import load_model, load_tokenizer +from axolotl.common.cli import TrainerCliArgs +from axolotl.logging_config import configure_logging +from axolotl.utils.dict import DictDefault +# from scripts.finetune import load_cfg +from finetune import load_cfg, get_merged_out_dir, do_merge_lora_model_and_tokenizer, load_datasets + +from axolotl.utils.quantize import load_merged_model, get_quantized_model, quantize_and_save, push_model, get_quantized_model_id, get_quantized_model_dir, get_examples_for_quantization + +configure_logging() +LOG = logging.getLogger("axolotl.quantize") + +import debugpy +debugpy.listen(('0.0.0.0', 5678)) +debugpy.wait_for_client() +debugpy.breakpoint() + +class ProgressExtractingHandler(logging.Handler): + def emit(self, record): + log_entry = self.format(record) + progress_info = self.extract_progress(log_entry) + if progress_info: + print(f"Progress: {progress_info}") + + @staticmethod + def extract_progress(log_entry): +# [2023-09-11 07:20:37,502] [INFO] [auto_gptq.modeling._base.quantize:364] [PID:3962] [RANK:0] Quantizing self_attn.k_proj in layer 4/32... 
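+        # Pull the "4/32"-style layer counter out of auto_gptq's per-layer log
+        # records (sample entry above); returns None for records without one.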
+ match = re.search(r'layer (\d+/\d+)', log_entry) + return match.group(1) if match else None + # [2023-09-11 07:27:52,208] [INFO] [auto_gptq.modeling._utils.pack_model:129] [PID:3962] [RANK:0] model.layers.15.self_attn.o_proj + +handler = ProgressExtractingHandler() +# logging.getLogger('auto_gptq.modeling._base.quantize').addHandler(handler) +logger = logging.getLogger('auto_gptq.modeling._base.quantize') +logger.setLevel(logging.DEBUG) +logger.addHandler(handler) + +# logging.basicConfig( +# format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.DEBUG, datefmt="%Y-%m-%d %H:%M:%S" +# ) + +# LOG.setLevel(logging.DEBUG) +# handler = logging.StreamHandler() +# formatter = logging.Formatter('%(asctime)s %(levelname)s [%(name)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') +# handler.setFormatter(formatter) +# LOG.addHandler(handler) + +print("Done importing...") + +## CHANGE BELOW ## +# config_path: Path = Path("./examples/llama-2/lora.yml") +config_path: Path = Path("./examples/llama-2/lora-short.yml") + +# pretrained_model_dir = "facebook/opt-125m" +# quantized_model_dir = "opt-125m-4bit" +dataset_name = "teknium/GPT4-LLM-Cleaned" +# huggingface_username = "CHANGE_ME" +## CHANGE ABOVE + +def main(): + print("Starting...") + # return + # prompt = "<|prompt|>How can entrepreneurs start building their own communities even before launching their product?<|answer|>" + + should_quantize = True + # tokenizer = get_tokenizer() + + cfg = load_cfg(config_path) + + cfg['lora_model_dir'] = cfg['output_dir'] + + LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") + tokenizer = load_tokenizer(cfg) + + if should_quantize: + print("Quantizing...") + + print("Loading dataset...") + datasets = load_datasets(cfg=cfg, cli_args=TrainerCliArgs()) + train_dataset = datasets.train_dataset + n_samples = 128 + # # n_samples = 2 + # examples = train_dataset.shuffle(seed=42).select( + # [ + # random.randrange(0, len(train_dataset) - 1) # nosec + # for _ in range(n_samples) + # ] + # ) + + LOG.info("loading model and (optionally) peft_config...") + # model, peft_config = load_model(cfg, tokenizer, inference=True) + model = load_merged_model(cfg) + # model = get_model() + + # examples = load_data(dataset_name, tokenizer, n_samples) + + # print(examples) + # examples_for_quant = [ + # {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]} + # for example in examples + # ] + # print(examples_for_quant) + examples_for_quant = get_examples_for_quantization(train_dataset, n_samples) + + modelq = quantize_and_save(cfg, model, tokenizer, examples_for_quant) + else: + print("Loading quantized model...") + modelq = get_quantized_model(cfg) + + push_model(cfg, modelq, tokenizer) + +if __name__ == "__main__": + main() + + +# Load configure +# Load dataset +# Load tokenizer +# Prepare database +# Load previous model, final checkpoint + + +# --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False +# accelerate launch ./scripts/finetune.py ./examples/llama-2/lora.yml --merge_lora --lora_model_dir="./lora-out" --load_in_8bit=False --load_in_4bit=False +# CUDA_VISIBLE_DEVICES="1" accelerate launch ./scripts/finetune.py ./examples/llama-2/lora.yml --merge_lora --lora_model_dir="./lora-out" --load_in_8bit=False --load_in_4bit=False + +# HUB_MODEL_ID="Glavin001/llama-2-7b-alpaca_2k_test" accelerate launch ./scripts/quantize.py + diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 62f2b1061a..ef439d3cf0 100644 
--- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -9,6 +9,7 @@ from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer +from axolotl.utils.quantize import get_quantized_model configure_logging() LOG = logging.getLogger("axolotl.common.cli") @@ -25,6 +26,8 @@ class TrainerCliArgs: debug_num_examples: int = field(default=5) inference: bool = field(default=False) merge_lora: bool = field(default=False) + quantize: bool = field(default=False) + push: bool = field(default=False) prepare_ds_only: bool = field(default=False) prompter: Optional[str] = field(default=None) shard: bool = field(default=False) @@ -38,6 +41,8 @@ def load_model_and_tokenizer( LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) LOG.info("loading model and (optionally) peft_config...") - model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + # model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + # TEMP + model = get_quantized_model(cfg) return model, tokenizer diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py index 8f473aa240..41a8e41103 100644 --- a/src/axolotl/logging_config.py +++ b/src/axolotl/logging_config.py @@ -61,6 +61,11 @@ def format(self, record): "level": "DEBUG", "propagate": False, }, + "auto_gptq": { + "handlers": ["color_console"], + "level": "DEBUG", + "propagate": False, + }, }, } diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index ef048082c1..d172d302d9 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -193,7 +193,7 @@ def flashattn_forward( # only on first autoregressive step q,k,v have same seqlen is_causal = key_states.shape == query_states.shape - if cu_seqlens is not None and max_seqlen is not None: + if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: # special handling using sample packing qkv = torch.stack( [query_states, key_states, value_states], dim=2 @@ -261,6 +261,8 @@ def flashattn_forward( if attention_mask is not None else None, ) + if q_unpad.dtype != kv_unpad.dtype: + kv_unpad = kv_unpad.to(q_unpad.dtype) output_unpad = flash_attn_varlen_kvpacked_func( q_unpad, kv_unpad, diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 3f776537a5..3925c20d8c 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -15,6 +15,8 @@ from optimum.bettertransformer import BetterTransformer from tqdm import tqdm from transformers import ( + GenerationConfig, + Trainer, TrainerCallback, TrainerControl, TrainerState, @@ -22,6 +24,7 @@ ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy +import wandb from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.distributed import ( barrier, @@ -323,3 +326,177 @@ def on_evaluate( metrics[key] = val return BenchEvalCallback + + +def log_prediction_callback_factory(trainer: Trainer, tokenizer): + class LogPredictionCallback(TrainerCallback): + """Callback to log prediction values during each evaluation""" + + def __init__(self, cfg): + self.cfg = cfg + self.logged = False + + def on_evaluate( + self, + args: AxolotlTrainingArguments, + state: TrainerState, + control: TrainerControl, + train_dataloader, + eval_dataloader, + **kwargs, + ): + eval_table_size = 
self.cfg.eval_table_size + + if eval_table_size <= 0: + return control + + trainer.model.eval() + device = torch.device(self.cfg.device) + + generation_config = GenerationConfig( + max_new_tokens=self.cfg.eval_table_max_new_tokens, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + + def logits_to_tokens(logits) -> str: + probabilities = torch.softmax(logits, dim=-1) + # Get the predicted token ids (the ones with the highest probability) + predicted_token_ids = torch.argmax(probabilities, dim=-1) + return predicted_token_ids + + def find_ranges(lst): + ranges = [] + start = 0 + for i in range(1, len(lst)): + if lst[i] == 0: + ranges.append((start, i - 1)) + start = i + end = len(lst) - 1 + ranges.append((start, end)) + return ranges + + def log_table_from_dataloader(name: str, table_dataloader): + table = wandb.Table( + columns=[ + "id", + "Prompt", + "Correct Completion", + "Predicted Completion (model.generate)", + "Predicted Completion (trainer.prediction_step)", + ] + ) + row_index = 0 + + for batch in tqdm(table_dataloader): + + if row_index > eval_table_size: + break + + batch_labels = batch["labels"].to(device) + batch_input_ids = batch["input_ids"].to(device) + + if "position_ids" in batch: + batch_pos_ids = batch["position_ids"].tolist() + else: + batch_pos_ids = [None] * len(batch["input_ids"]) + + (_, batch_logits, _) = trainer.prediction_step( + trainer.model, + batch, + prediction_loss_only=False, + ) + + prompt_token_ids_list = [] + pred_step_token_ids_list = [] + completion_token_ids_list = [] + + for input_ids_all, labels_all, pos_ids, logits in zip( + batch_input_ids, batch_labels, batch_pos_ids, batch_logits, + ): + if pos_ids is None: + pos_ranges = [(0, len(input_ids_all) - 1)] + else: + pos_ranges = find_ranges(pos_ids) + + for pos_range in pos_ranges: + start, end = pos_range + if start == end: + continue + + input_ids = input_ids_all[start : end + 1] + labels = labels_all[start : end + 1] + + tokens_without_loss = labels == IGNORE_INDEX + tokens_with_loss = labels != IGNORE_INDEX + tokens_exclude_padding = input_ids != tokenizer.pad_token_id + prompt_token_includes = ( + tokens_without_loss & tokens_exclude_padding + ) + + prompt_token_ids = input_ids[prompt_token_includes] + prompt_token_ids_list.append(prompt_token_ids) + + completion_token_ids = input_ids[tokens_with_loss] + completion_token_ids_list.append(completion_token_ids) + + pred_step_token_ids = logits_to_tokens(logits[start : end + 1])[tokens_with_loss] + pred_step_token_ids_list.append(pred_step_token_ids) + + prompt_texts = tokenizer.batch_decode( + prompt_token_ids_list, skip_special_tokens=True + ) + completion_texts = tokenizer.batch_decode( + completion_token_ids_list, skip_special_tokens=True + ) + pred_step_texts = tokenizer.batch_decode( + pred_step_token_ids_list, skip_special_tokens=True + ) + + with torch.no_grad(): + prompt_encoding = tokenizer( + prompt_texts, padding=True, return_tensors="pt" + ).to(self.cfg.device) + predictions = trainer.model.generate( + **prompt_encoding, generation_config=generation_config + ) + + prediction_all_tokens = predictions["sequences"].cpu().tolist() + prediction_without_prompt_tokens_list = [] + for prompt_token_ids, prediction_tokens in zip( + prompt_token_ids_list, prediction_all_tokens + ): + prediction_without_prompt_tokens = 
prediction_tokens[ + len(prompt_token_ids) : + ] + prediction_without_prompt_tokens_list.append( + prediction_without_prompt_tokens + ) + + predicted_texts = tokenizer.batch_decode( + prediction_without_prompt_tokens_list, skip_special_tokens=True + ) + + for prompt_text, completion_text, prediction_text, pred_step_text in zip( + prompt_texts, completion_texts, predicted_texts, pred_step_texts + ): + table.add_data( + row_index, prompt_text, completion_text, prediction_text, pred_step_text + ) + row_index += 1 + + wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) + + if is_main_process(): + log_table_from_dataloader("Eval", eval_dataloader) + + return control + + return LogPredictionCallback diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 6de807eab9..a4e68869d8 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -48,6 +48,8 @@ def normalize_config(cfg): ) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + cfg.eval_table_size = cfg.eval_table_size or 0 + cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128 choose_device(cfg) cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 if cfg.ddp: @@ -82,6 +84,11 @@ def normalize_config(cfg): log_gpu_memory_usage(LOG, "baseline", cfg.device) + if os.environ.get("WANDB_PROJECT") and len(os.environ.get("WANDB_PROJECT", "")) > 0: + cfg.wandb_project = os.environ.get("WANDB_PROJECT") + + if os.environ.get("HUB_MODEL_ID") and len(os.environ.get("HUB_MODEL_ID", "")) > 0: + cfg.hub_model_id = os.environ.get("HUB_MODEL_ID") def validate_config(cfg): if cfg.max_packed_sequence_len and cfg.sample_packing: diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 2000b1aee8..eae08f7f70 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -296,10 +296,10 @@ def load_model( if ( hasattr(model.config, "max_position_embeddings") and model.config.max_position_embeddings - and cfg.sequence_len >= model.config.max_position_embeddings + and cfg.sequence_len > model.config.max_position_embeddings ): LOG.warning( - f"increasing model.config.max_position_embeddings to {cfg.sequence_len}" + f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}" ) model.config.max_position_embeddings = cfg.sequence_len diff --git a/src/axolotl/utils/quantize.py b/src/axolotl/utils/quantize.py new file mode 100644 index 0000000000..39318836a9 --- /dev/null +++ b/src/axolotl/utils/quantize.py @@ -0,0 +1,139 @@ +# pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ + +# import debugpy +# debugpy.listen(('0.0.0.0', 5678)) +# debugpy.wait_for_client() +# debugpy.breakpoint() + +import json +import random +import time +from pathlib import Path +import logging + +# import torch +# from datasets import load_dataset, Dataset +# from transformers import AutoTokenizer, LlamaTokenizer, TextGenerationPipeline +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +# from axolotl.prompters import AlpacaPrompter +# from axolotl.utils.models import load_model, load_tokenizer +# from axolotl.common.cli import TrainerCliArgs +# from axolotl.logging_config import configure_logging +from axolotl.utils.dict import DictDefault +# from finetune import load_cfg, get_merged_out_dir, do_merge_lora_model_and_tokenizer + +# configure_logging() +# LOG = logging.getLogger("axolotl") + +quantize_config = 
BaseQuantizeConfig( + bits=4, # quantize model to 4-bit + group_size=128, # it is recommended to set the value to 128 + desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad +) + +def get_merged_out_dir(cfg: DictDefault): + return Path(cfg.output_dir) / "merged" + +def load_merged_model(cfg: DictDefault): + print("Loading merged model...") + + merged_out_dir = get_merged_out_dir(cfg) + + # Check if the merged model exists + if not merged_out_dir.exists(): + # If not, merge the model + raise FileNotFoundError("Merged model not found. Please ensure the model has been merged.") + # do_merge_lora_model_and_tokenizer(cfg=cfg, model=model, tokenizer=tokenizer) + # raise NotImplementedError("Merging model is not implemented yet.") + + # load un-quantized model, by default, the model will always be loaded into CPU memory + model = AutoGPTQForCausalLM.from_pretrained(merged_out_dir, quantize_config) + # model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) + print("Model loaded.") + return model + +def get_quantized_model(cfg: DictDefault): + print("Loading quantized model...") + quantized_model_dir = get_quantized_model_dir(cfg) + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, + device="cuda:0", + use_safetensors=True, + inject_fused_attention=False, # WORKAROUND for https://github.com/PanQiWei/AutoGPTQ/issues/210 + ) + print("Model loaded.") + return model + +def quantize_and_save(cfg: DictDefault, model, tokenizer, examples_for_quant): + print("Quantize...") + start = time.time() + # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask" + model.quantize( + examples_for_quant, + batch_size=1, + # batch_size=args.quant_batch_size, + # use_triton=args.use_triton, + # autotune_warmup_after_quantized=args.use_triton + ) + end = time.time() + print(f"quantization took: {end - start: .4f}s") + + # save quantized model + print("Saving quantized model...") + # model.save_quantized(quantized_model_dir) + quantized_model_dir = get_quantized_model_dir(cfg) + model.save_quantized(quantized_model_dir, use_safetensors=True) + print("Saving tokenizer...") + tokenizer.save_pretrained(quantized_model_dir) + print("Saved.") + + # FIXME: Add fix to config.json + # "error": "handler: 'pad_token_id' \ntraceback: Traceback (most recent call last):\n File \"/usr/local/lib/python3.10/dist-packages/runpod/serverless/modules/job.py\", line 141, in run_job_generator\n for output_partial in job_output:\n File \"/data/handler.py\", line 107, in inference\n generator, default_settings = load_model()\n File \"/data/handler.py\", line 45, in load_model\n config = ExLlamaConfig(model_config_path) # create config from config.json\n File \"/data/exllama/model.py\", line 52, in __init__\n self.pad_token_id = read_config[\"pad_token_id\"]\nKeyError: 'pad_token_id'\n" + + return model + +def push_model(cfg: DictDefault, model, tokenizer): +# def push_model(model): + # push quantized model to Hugging Face Hub. + # to use use_auth_token=True, Login first via huggingface-cli login. 
+ # or pass explcit token with: use_auth_token="hf_xxxxxxx" + # (uncomment the following three lines to enable this feature) + # repo_id = f"YourUserName/{quantized_model_dir}" + print("Pushing to Huggingface hub...") + # repo_id = f"{huggingface_username}/{quantized_model_dir}" + repo_id = get_quantized_model_id(cfg) + pretrained_model_dir = cfg['base_model'] + commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}" + # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, use_safetensors=True, safe_serialization=True) + # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, safe_serialization=True) + model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, use_safetensors=True) + tokenizer.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True) + print("Pushed.") + +def get_quantized_model_id(cfg: DictDefault): +# def get_quantized_model_id(cfg: DictDefault, quantize_config): + # return f"{cfg.hub_model_id}-{quantize_config.bits}bits-gr{quantize_config.group_size}-desc_act{quantize_config.desc_act}" + if not cfg.hub_model_id: + raise ValueError("Missing hub_model_id in the configuration.") + return f"{cfg.hub_model_id}-GPTQ" + +def get_quantized_model_dir(cfg: DictDefault): +# def get_quantized_model_dir(cfg: DictDefault, quantize_config): + if not cfg.output_dir: + raise ValueError("Missing output_dir in the configuration.") + p = Path(cfg.output_dir) / "quantized" + return str(p).lstrip('./') + +def get_examples_for_quantization(dataset, n_samples): + print("Loading dataset...") + examples = dataset.shuffle(seed=42).select( + [ + random.randrange(0, len(dataset) - 1) # nosec + for _ in range(n_samples) + ] + ) + + examples_for_quant = [ + {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]} + for example in examples + ] + return examples_for_quant diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index ece1bd9b69..d959c896f1 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -30,6 +30,7 @@ SaveBetterTransformerModelCallback, SavePeftModelCallback, bench_eval_callback_factory, + log_prediction_callback_factory, ) from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.dataloader import MultipackDistributedDataloader @@ -703,6 +704,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ **trainer_kwargs, ) + if cfg.use_wandb and cfg.eval_table_size > 0: + LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer) + trainer.add_callback(LogPredictionCallback(cfg)) + if cfg.do_bench_eval: trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
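A note on the prediction-table callback added in `src/axolotl/utils/callbacks.py`: with sample packing enabled, a single batch row concatenates several training examples and `position_ids` restart at 0 at every example boundary, so `find_ranges` is what recovers the per-example spans before prompts and completions are decoded. A small self-contained check of that behaviour (the function body is copied from the callback above; the example data is made up):

```python
def find_ranges(lst):
    # A new span starts wherever the packed position id resets to 0.
    ranges = []
    start = 0
    for i in range(1, len(lst)):
        if lst[i] == 0:
            ranges.append((start, i - 1))
            start = i
    ranges.append((start, len(lst) - 1))
    return ranges

# Three packed examples of lengths 3, 2 and 4 in one row.
position_ids = [0, 1, 2, 0, 1, 0, 1, 2, 3]
assert find_ranges(position_ids) == [(0, 2), (3, 4), (5, 8)]
```

When packing is off there are no `position_ids` in the batch, and the callback falls back to treating the whole row as a single span.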
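The new `--quantize` and `--push` flags on `TrainerCliArgs`, together with the helpers in `src/axolotl/utils/quantize.py`, sketch out a train → merge → GPTQ-quantize → push pipeline; `do_merge_lora` now forces `load_in_8bit`/`load_in_4bit` off and defaults `lora_model_dir` to `output_dir` when unset, which is why the README no longer passes those flags. Below is a rough sketch of how the pieces compose, assuming a config whose `output_dir` already holds a trained LoRA and whose `hub_model_id` is set (or injected via the new `HUB_MODEL_ID` environment variable); only the imported helper names come from this diff, the rest is illustrative:

```python
from pathlib import Path

from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.models import load_tokenizer
from axolotl.utils.quantize import (
    get_examples_for_quantization,
    load_merged_model,
    push_model,
    quantize_and_save,
)
# scripts/finetune.py must be importable (e.g. run from the scripts/ directory).
from finetune import do_merge_lora, load_cfg, load_datasets

cfg = load_cfg(Path("examples/llama-2/lora.yml"))
cli_args = TrainerCliArgs()

# 1. Merge the trained LoRA adapter into the base model -> <output_dir>/merged
do_merge_lora(cfg=cfg, cli_args=cli_args)

# 2. Reload the merged weights with AutoGPTQ and pick calibration samples.
merged_model = load_merged_model(cfg)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
examples = get_examples_for_quantization(dataset_meta.train_dataset, 128)

# 3. Quantize, save to <output_dir>/quantized, then push "<hub_model_id>-GPTQ".
tokenizer = load_tokenizer(cfg)
quantized_model = quantize_and_save(cfg, merged_model, tokenizer, examples)
push_model(cfg, quantized_model, tokenizer)
```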
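The `.vscode/launch.json` "Python: Remote Attach" entry, the new `5678:5678` port mapping in `docker-compose.yaml`, and the `debugpy` block at the top of `scripts/quantize.py` are one workflow: the script inside the container blocks until VS Code attaches, with `pathMappings` translating `${workspaceFolder}` to `/workspace/axolotl/`. A minimal listener-side sketch, using the same host/port values as the diff:

```python
import debugpy

# Bind on all interfaces inside the container so the host-side
# "Python: Remote Attach" configuration can reach port 5678.
debugpy.listen(("0.0.0.0", 5678))
print("Waiting for debugger to attach on port 5678 ...")
debugpy.wait_for_client()  # blocks until VS Code attaches
debugpy.breakpoint()       # pause here once the client is connected
```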