Skip to content

Commit

Permalink
Falcon embeddings (#1149) [skip docker]
Browse files Browse the repository at this point in the history
* also fix multipack for falcon and add smoke tests

* make sure to handle special tokens and added tokens for lora

* fix reference to model_type

* fix tests for falcon

* fix stray typo

* fixes for smoke tests
  • Loading branch information
winglian authored Jan 23, 2024
1 parent 0f77b8d commit e799e08
Show file tree
Hide file tree
Showing 10 changed files with 326 additions and 19 deletions.
2 changes: 1 addition & 1 deletion examples/falcon/config-7b-lora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,5 @@ fsdp:
fsdp_config:
special_tokens:
pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<"
bos_token: "<|endoftext|>"
eos_token: "<|endoftext|>"
2 changes: 1 addition & 1 deletion examples/falcon/config-7b-qlora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,5 @@ fsdp:
fsdp_config:
special_tokens:
pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<"
bos_token: "<|endoftext|>"
eos_token: "<|endoftext|>"
2 changes: 1 addition & 1 deletion examples/falcon/config-7b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,5 @@ fsdp:
fsdp_config:
special_tokens:
pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<"
bos_token: "<|endoftext|>"
eos_token: "<|endoftext|>"
12 changes: 12 additions & 0 deletions src/axolotl/monkeypatch/falcon/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
Patches to support multipack for falcon
"""
import transformers

from axolotl.monkeypatch.utils import get_unpad_data


def replace_falcon_attn_with_multipack_flash_attn():
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
2 changes: 2 additions & 0 deletions src/axolotl/utils/lora_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ def get_linear_embedding_layers(model_type):
return ["embd.wte", "lm_head.linear"]
if model_type == "gpt_neox":
return ["embed_in", "embed_out"]
if model_type == "falcon":
return ["word_embeddings", "lm_head"]
return ["embed_tokens", "lm_head"]
37 changes: 23 additions & 14 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,14 @@ def load_model(
LOG.info("patching mixtral with flash attention")
replace_mixtral_attn_with_multipack_flash_attn()

if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.falcon import (
replace_falcon_attn_with_multipack_flash_attn,
)

LOG.info("patching falcon with flash attention")
replace_falcon_attn_with_multipack_flash_attn()

if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.qwen2 import (
replace_qwen2_attn_with_multipack_flash_attn,
Expand Down Expand Up @@ -434,18 +442,13 @@ def load_model(
if not cfg.sample_packing:
if cfg.s2_attention:
pass
if (
cfg.is_llama_derived_model
or cfg.is_falcon_derived_model
or cfg.is_mistral_derived_model
or model_config.model_type in ["mixtral", "qwen2"]
):
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
# most other models support flash attention, we can define exceptions as they come up
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
if model_config.model_type in ["mixtral", "qwen2"]:
if model_config.model_type in ["mixtral", "qwen2", "falcon"]:
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
Expand All @@ -461,7 +464,11 @@ def load_model(
model_config.fused_dense = True

try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
if (
model_config.model_type == "llama"
and not cfg.trust_remote_code
and not cfg.gptq
):
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
Expand Down Expand Up @@ -755,8 +762,10 @@ def find_all_linear_names(model):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])

if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
embedding_modules = get_linear_embedding_layers(model.config.model_type)
output_embedding = embedding_modules[1]
if output_embedding in lora_module_names: # needed for 16-bit
lora_module_names.remove(output_embedding)

return list(lora_module_names)

Expand Down
6 changes: 6 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")

if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column")
train_dataset = train_dataset.remove_columns("token_type_ids")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("token_type_ids")

train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
Expand Down
112 changes: 112 additions & 0 deletions tests/e2e/patched/test_falcon_samplepack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
E2E tests for falcon
"""

import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestFalconPatched(unittest.TestCase):
"""
Test case for Falcon models
"""

@with_temp_dir
def test_qlora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"lora_modules_to_save": ["word_embeddings", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {
"bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.1,
"special_tokens": {
"bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
4 changes: 2 additions & 2 deletions tests/e2e/patched/test_mixtral_samplepack.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_qlora(self, temp_dir):
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"load_in_4bit": True,
"adapter": "qlora",
Expand All @@ -57,7 +58,6 @@ def test_qlora(self, temp_dir):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"sample_packing": True,
"bf16": "auto",
}
)
Expand All @@ -76,6 +76,7 @@ def test_ft(self, temp_dir):
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.1,
"special_tokens": {},
Expand All @@ -95,7 +96,6 @@ def test_ft(self, temp_dir):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"sample_packing": True,
"bf16": "auto",
}
)
Expand Down
Loading

0 comments on commit e799e08

Please sign in to comment.