From 96be9139ba9bf53e5943d29d7e88310c76cb5e5f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 21:01:42 -0500 Subject: [PATCH] Falcon embeddings (#1149) [skip docker] * also fix multipack for falcon and add smoke tests * make sure to handle special tokens and added tokens for lora * fix reference to model_type * fix tests for falcon * fix stray typo * fixes for smoke tests --- examples/falcon/config-7b-lora.yml | 2 +- examples/falcon/config-7b-qlora.yml | 2 +- examples/falcon/config-7b.yml | 2 +- src/axolotl/monkeypatch/falcon/__init__.py | 12 ++ src/axolotl/utils/lora_embeddings.py | 2 + src/axolotl/utils/models.py | 37 +++-- src/axolotl/utils/trainer.py | 6 + tests/e2e/patched/test_falcon_samplepack.py | 112 +++++++++++++ tests/e2e/patched/test_mixtral_samplepack.py | 4 +- tests/e2e/test_falcon.py | 166 +++++++++++++++++++ 10 files changed, 326 insertions(+), 19 deletions(-) create mode 100644 src/axolotl/monkeypatch/falcon/__init__.py create mode 100644 tests/e2e/patched/test_falcon_samplepack.py create mode 100644 tests/e2e/test_falcon.py diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml index 7cdbb6cef8..ff713d7d13 100644 --- a/examples/falcon/config-7b-lora.yml +++ b/examples/falcon/config-7b-lora.yml @@ -60,5 +60,5 @@ fsdp: fsdp_config: special_tokens: pad_token: "<|endoftext|>" - bos_token: ">>ABSTRACT<<" + bos_token: "<|endoftext|>" eos_token: "<|endoftext|>" diff --git a/examples/falcon/config-7b-qlora.yml b/examples/falcon/config-7b-qlora.yml index d93806dfc8..c6c71ac895 100644 --- a/examples/falcon/config-7b-qlora.yml +++ b/examples/falcon/config-7b-qlora.yml @@ -89,5 +89,5 @@ fsdp: fsdp_config: special_tokens: pad_token: "<|endoftext|>" - bos_token: ">>ABSTRACT<<" + bos_token: "<|endoftext|>" eos_token: "<|endoftext|>" diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml index 722ab07404..6082ee87eb 100644 --- a/examples/falcon/config-7b.yml +++ b/examples/falcon/config-7b.yml @@ -60,5 +60,5 @@ fsdp: fsdp_config: special_tokens: pad_token: "<|endoftext|>" - bos_token: ">>ABSTRACT<<" + bos_token: "<|endoftext|>" eos_token: "<|endoftext|>" diff --git a/src/axolotl/monkeypatch/falcon/__init__.py b/src/axolotl/monkeypatch/falcon/__init__.py new file mode 100644 index 0000000000..dc6e526f66 --- /dev/null +++ b/src/axolotl/monkeypatch/falcon/__init__.py @@ -0,0 +1,12 @@ +""" +Patches to support multipack for falcon +""" +import transformers + +from axolotl.monkeypatch.utils import get_unpad_data + + +def replace_falcon_attn_with_multipack_flash_attn(): + transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) diff --git a/src/axolotl/utils/lora_embeddings.py b/src/axolotl/utils/lora_embeddings.py index b5d2f7cc94..d9fe35eb81 100644 --- a/src/axolotl/utils/lora_embeddings.py +++ b/src/axolotl/utils/lora_embeddings.py @@ -11,4 +11,6 @@ def get_linear_embedding_layers(model_type): return ["embd.wte", "lm_head.linear"] if model_type == "gpt_neox": return ["embed_in", "embed_out"] + if model_type == "falcon": + return ["word_embeddings", "lm_head"] return ["embed_tokens", "lm_head"] diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fb4caa6d85..d75926952f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -334,6 +334,14 @@ def load_model( LOG.info("patching mixtral with flash attention") replace_mixtral_attn_with_multipack_flash_attn() + if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing: + from axolotl.monkeypatch.falcon import ( + replace_falcon_attn_with_multipack_flash_attn, + ) + + LOG.info("patching falcon with flash attention") + replace_falcon_attn_with_multipack_flash_attn() + if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing: from axolotl.monkeypatch.qwen2 import ( replace_qwen2_attn_with_multipack_flash_attn, @@ -434,18 +442,13 @@ def load_model( if not cfg.sample_packing: if cfg.s2_attention: pass - if ( - cfg.is_llama_derived_model - or cfg.is_falcon_derived_model - or cfg.is_mistral_derived_model - or model_config.model_type in ["mixtral", "qwen2"] - ): - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) + # most other models support flash attention, we can define exceptions as they come up + model_kwargs["attn_implementation"] = "flash_attention_2" + model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) else: - if model_config.model_type in ["mixtral", "qwen2"]: + if model_config.model_type in ["mixtral", "qwen2", "falcon"]: model_kwargs["attn_implementation"] = "flash_attention_2" model_config._attn_implementation = ( # pylint: disable=protected-access "flash_attention_2" @@ -461,7 +464,11 @@ def load_model( model_config.fused_dense = True try: - if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: + if ( + model_config.model_type == "llama" + and not cfg.trust_remote_code + and not cfg.gptq + ): from transformers import LlamaForCausalLM model = LlamaForCausalLM.from_pretrained( @@ -755,8 +762,10 @@ def find_all_linear_names(model): names = name.split(".") lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - if "lm_head" in lora_module_names: # needed for 16-bit - lora_module_names.remove("lm_head") + embedding_modules = get_linear_embedding_layers(model.config.model_type) + output_embedding = embedding_modules[1] + if output_embedding in lora_module_names: # needed for 16-bit + lora_module_names.remove(output_embedding) return list(lora_module_names) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index b8235d3cf8..2dec90eb79 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -124,6 +124,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): if eval_dataset: eval_dataset = eval_dataset.remove_columns("attention_mask") + if cfg.model_config_type == "falcon": + LOG.info("dropping token_type_ids column") + train_dataset = train_dataset.remove_columns("token_type_ids") + if eval_dataset: + eval_dataset = eval_dataset.remove_columns("token_type_ids") + train_dataset = train_dataset.filter( drop_long, num_proc=cfg.dataset_processes, diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py new file mode 100644 index 0000000000..ae6a497391 --- /dev/null +++ b/tests/e2e/patched/test_falcon_samplepack.py @@ -0,0 +1,112 @@ +""" +E2E tests for falcon +""" + +import logging +import os +import unittest +from pathlib import Path + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from ..utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestFalconPatched(unittest.TestCase): + """ + Test case for Falcon models + """ + + @with_temp_dir + def test_qlora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "illuin/tiny-random-FalconForCausalLM", + "flash_attention": True, + "sample_packing": True, + "sequence_len": 2048, + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 16, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "lora_modules_to_save": ["word_embeddings", "lm_head"], + "val_set_size": 0.1, + "special_tokens": { + "bos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_ft(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "illuin/tiny-random-FalconForCausalLM", + "flash_attention": True, + "sample_packing": True, + "sequence_len": 2048, + "val_set_size": 0.1, + "special_tokens": { + "bos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "pytorch_model.bin").exists() diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index 30c53103ed..4c05113f55 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -32,6 +32,7 @@ def test_qlora(self, temp_dir): "base_model": "hf-internal-testing/Mixtral-tiny", "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", "flash_attention": True, + "sample_packing": True, "sequence_len": 2048, "load_in_4bit": True, "adapter": "qlora", @@ -57,7 +58,6 @@ def test_qlora(self, temp_dir): "max_steps": 20, "save_steps": 10, "eval_steps": 10, - "sample_packing": True, "bf16": "auto", } ) @@ -76,6 +76,7 @@ def test_ft(self, temp_dir): "base_model": "hf-internal-testing/Mixtral-tiny", "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", "flash_attention": True, + "sample_packing": True, "sequence_len": 2048, "val_set_size": 0.1, "special_tokens": {}, @@ -95,7 +96,6 @@ def test_ft(self, temp_dir): "max_steps": 20, "save_steps": 10, "eval_steps": 10, - "sample_packing": True, "bf16": "auto", } ) diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py new file mode 100644 index 0000000000..c76699a7c8 --- /dev/null +++ b/tests/e2e/test_falcon.py @@ -0,0 +1,166 @@ +""" +E2E tests for falcon +""" + +import logging +import os +import unittest +from pathlib import Path + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestFalcon(unittest.TestCase): + """ + Test case for falcon + """ + + @with_temp_dir + def test_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "illuin/tiny-random-FalconForCausalLM", + "flash_attention": True, + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 32, + "lora_alpha": 64, + "lora_dropout": 0.05, + "lora_target_linear": True, + "lora_modules_to_save": [ + "word_embeddings", + "lm_head", + ], + "val_set_size": 0.1, + "special_tokens": { + "bos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_lora_added_vocab(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "illuin/tiny-random-FalconForCausalLM", + "flash_attention": True, + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 32, + "lora_alpha": 64, + "lora_dropout": 0.05, + "lora_target_linear": True, + "lora_modules_to_save": [ + "word_embeddings", + "lm_head", + ], + "val_set_size": 0.1, + "special_tokens": { + "bos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>", + }, + "tokens": [ + "<|im_start|>", + "<|im_end|>", + ], + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_ft(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "illuin/tiny-random-FalconForCausalLM", + "flash_attention": True, + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": { + "bos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "pytorch_model.bin").exists()