From 35a7e4999221248dfe91821853bf5204681ed78e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 4 Dec 2023 14:49:40 -0800
Subject: [PATCH 01/15] support for mamba

---
 examples/mamba/config.yml   | 61 +++++++++++++++++++++++++++++++++++++
 setup.py                    |  3 ++
 src/axolotl/utils/models.py | 36 +++++++++++++++++++---
 3 files changed, 95 insertions(+), 5 deletions(-)
 create mode 100644 examples/mamba/config.yml

diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
new file mode 100644
index 0000000000..2f8bbcdb94
--- /dev/null
+++ b/examples/mamba/config.yml
@@ -0,0 +1,61 @@
+base_model: state-spaces/mamba-2.8b
+model_type: MambaLMHeadModel
+tokenizer_type: AutoTokenizer
+tokenizer_config: EleutherAI/gpt-neox-20b
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./out
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.000005
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention:
+
+warmup_steps: 10
+eval_steps: 0.05
+eval_table_size:
+eval_table_max_new_tokens: 128
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: ""
+  eos_token: ""
+  unk_token: ""
diff --git a/setup.py b/setup.py
index e3ee54350b..a531b1fa4a 100644
--- a/setup.py
+++ b/setup.py
@@ -51,5 +51,8 @@ def parse_requirements():
         "deepspeed": [
             "deepspeed",
         ],
+        "mamba": [
+            "git+https://github.com/state-spaces/mamba.git",
+        ],
     },
 )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 40a0a89474..536b0c81ee 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -4,6 +4,7 @@
 import os
 from typing import Optional, Tuple  # noqa: F401
 
+import addict
 import bitsandbytes as bnb
 import torch
 import transformers
@@ -52,6 +53,12 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
 def load_model_config(cfg):
     model_config_name = cfg.base_model_config or cfg.base_model
     trust_remote_code = cfg.trust_remote_code is True
+    if "state-spaces/mamba" in model_config_name:
+        return addict.Dict(
+            {
+                "model_type": "mamba",
+            }
+        )
     model_config = AutoConfig.from_pretrained(
         model_config_name, trust_remote_code=trust_remote_code
     )
@@ -333,6 +340,18 @@ def load_model(
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             **model_kwargs,
         )
+    elif model_type == "MambaLMHeadModel":
+        from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+        del model_kwargs["torch_dtype"]
+        del model_kwargs["device_map"]
+
+        model = MambaLMHeadModel.from_pretrained(
+            base_model,
+            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+            **model_kwargs,
+        )
     elif model_type and not cfg.trust_remote_code:
         if cfg.gptq:
             model = AutoModelForCausalLM.from_pretrained(
@@ -392,13 +411,17 @@ def load_model(
             if cfg.resize_token_embeddings_to_32x
             else len(tokenizer)
         )
-        if model.get_input_embeddings().num_embeddings < embeddings_len:
+        if (
+            hasattr(model, "get_input_embeddings")
+            and model.get_input_embeddings().num_embeddings < embeddings_len
+        ):
             model.resize_token_embeddings(embeddings_len)
         else:
             model.tie_weights()
 
     if (
-        hasattr(model.config, "max_position_embeddings")
+        hasattr(model, "config")
+        and hasattr(model.config, "max_position_embeddings")
         and model.config.max_position_embeddings
         and cfg.sequence_len > model.config.max_position_embeddings
     ):
@@ -408,14 +431,16 @@ def load_model(
         model.config.max_position_embeddings = cfg.sequence_len
 
     if (
-        hasattr(model.config, "bos_token_id")
+        hasattr(model, "config")
+        and hasattr(model.config, "bos_token_id")
         and model.config.bos_token_id
         and model.config.bos_token_id != tokenizer.bos_token_id
    ):
         model.config.bos_token_id = tokenizer.bos_token_id
 
     if (
-        hasattr(model.config, "eos_token_id")
+        hasattr(model, "config")
+        and hasattr(model.config, "eos_token_id")
         and model.config.eos_token_id
         and model.config.eos_token_id != tokenizer.eos_token_id
     ):
         model.config.eos_token_id = tokenizer.eos_token_id
@@ -480,7 +505,8 @@ def load_model(
             requires_grad.append(f"{name}: {param.requires_grad}")
         if len(requires_grad) == 0:
             LOG.warning("there are no parameters that require gradient updates")
-    model.config.use_cache = False
+    if hasattr(model, "config"):
+        model.config.use_cache = False
 
     if cfg.flash_optimum:
         model = BetterTransformer.transform(model)
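
Note: for orientation, a minimal stand-alone sketch of what the new MambaLMHeadModel branch above boils down to when loading a checkpoint. This is illustrative only and not part of the patch; it assumes the mamba-ssm package is installed and that its from_pretrained accepts device/dtype keywords (the keywords this series ends up passing).

    import torch
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    # Mamba checkpoints are loaded through mamba_ssm rather than transformers'
    # AutoModelForCausalLM, so HF-style kwargs such as device_map and torch_dtype
    # are stripped from model_kwargs before the call.
    model = MambaLMHeadModel.from_pretrained(
        "state-spaces/mamba-2.8b",
        device="cuda",
        dtype=torch.bfloat16,
    )
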
From af3a9dd5d2238c08097732e734f2a43d65b2c6aa Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 4 Dec 2023 15:35:25 -0800
Subject: [PATCH 02/15] more mamba fixes

---
 examples/mamba/config.yml           |  6 +--
 src/axolotl/core/trainer_builder.py |  5 ++-
 src/axolotl/train.py                |  6 ++-
 src/axolotl/utils/models.py         |  8 ++--
 src/axolotl/utils/trainer.py        | 17 ++++++--
 tests/e2e/test_mamba.py             | 64 +++++++++++++++++++++++++++++
 6 files changed, 93 insertions(+), 13 deletions(-)
 create mode 100644 tests/e2e/test_mamba.py

diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 2f8bbcdb94..5b8ff2b928 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -1,4 +1,4 @@
-base_model: state-spaces/mamba-2.8b
+base_model: state-spaces/mamba-130m
 model_type: MambaLMHeadModel
 tokenizer_type: AutoTokenizer
 tokenizer_config: EleutherAI/gpt-neox-20b
@@ -15,8 +15,8 @@ val_set_size: 0.05
 output_dir: ./out
 
 sequence_len: 2048
-sample_packing: true
-pad_to_sequence_len: true
+sample_packing: false
+pad_to_sequence_len: false
 
 wandb_project:
 wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index d166691f10..ce54824c69 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -282,7 +282,10 @@ def compute_loss(self, model, inputs, return_outputs=False):
         # outputs = model(**inputs)
         # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
         # return (loss, outputs) if return_outputs else loss
-        return super().compute_loss(model, inputs, return_outputs=return_outputs)
+        loss = super().compute_loss(model, inputs, return_outputs=return_outputs)
+        if loss.numel() > 1:
+            loss = loss.mean()
+        return loss
 
 
 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 60c76b1b0f..022d230cbb 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -82,7 +82,8 @@ def train(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
     )
 
-    model.config.use_cache = False
+    if hasattr(model, "config"):
+        model.config.use_cache = False
 
     # go ahead and presave, so we have the adapter config available to inspect
     if peft_config:
@@ -92,7 +93,8 @@ def train(
     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
     tokenizer.save_pretrained(str(Path(cfg.output_dir)))
-    model.config.save_pretrained(str(Path(cfg.output_dir)))
+    if hasattr(model, "config"):
+        model.config.save_pretrained(str(Path(cfg.output_dir)))
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 536b0c81ee..4b20a11a4a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -341,15 +341,17 @@ def load_model(
             **model_kwargs,
         )
     elif model_type == "MambaLMHeadModel":
+        # FIXME this is janky at best and hacked together to make it work
         from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 
+        model_kwargs["dtype"] = model_kwargs["torch_dtype"]
+        model_kwargs["device"] = torch.cuda.current_device()
         del model_kwargs["torch_dtype"]
         del model_kwargs["device_map"]
+        del model_kwargs["max_memory"]
 
         model = MambaLMHeadModel.from_pretrained(
             base_model,
-            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             **model_kwargs,
         )
     elif model_type and not cfg.trust_remote_code:
@@ -446,7 +448,7 @@ def load_model(
     ):
         model.config.eos_token_id = tokenizer.eos_token_id
 
-    if model.device.type == "cuda":
+    if hasattr(model, "device") and model.device.type == "cuda":
         log_gpu_memory_usage(LOG, "after model load", model.device)
 
     # make sure these are fp32 per Ramesh et al. (2021)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 469f6d8865..22a3ab8c5a 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -131,13 +131,20 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
         )
 
     # Phi doesn't want the attention_mask feature when training
-    if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
-        cfg.is_mistral_derived_model and cfg.flash_attention
+    if (
+        "CodeGenTokenizer" in tokenizer.__class__.__name__
+        or (cfg.is_mistral_derived_model and cfg.flash_attention)
+        or cfg.model_config_type == "mamba"
     ):
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
             eval_dataset = eval_dataset.remove_columns("attention_mask")
 
+    if cfg.model_config_type == "mamba":
+        train_dataset = train_dataset.remove_columns("labels")
+        if eval_dataset:
+            eval_dataset = eval_dataset.remove_columns("labels")
+
     return train_dataset, eval_dataset
@@ -153,7 +160,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
         if update:
             cfg.total_num_tokens = total_num_tokens
 
-    if not cfg.total_supervised_tokens:
+    skip_estimates = cfg.model_config_type == "mamba"
+
+    if not skip_estimates and not cfg.total_supervised_tokens:
         total_supervised_tokens = (
             train_dataset.data.column("labels")
             .to_pandas()
@@ -167,7 +176,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
         if update:
             cfg.total_supervised_tokens = total_supervised_tokens
 
-    if cfg.sample_packing:
+    if not skip_estimates and cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
         # flash attention with position ids fails
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
new file mode 100644
index 0000000000..9f688339d7
--- /dev/null
+++ b/tests/e2e/test_mamba.py
@@ -0,0 +1,64 @@
+"""
+E2E tests for mamba
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMamba(unittest.TestCase):
+    """
+    Test case for Mamba models
+    """
+
+    @with_temp_dir
+    def test_fft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "state-spaces/mamba-130m",
+                "model_type": "MambaLMHeadModel",
+                "tokenizer_type": "AutoTokenizer",
+                "tokenizer_config": "EleutherAI/gpt-neox-20b",
+                "flash_attention": False,
+                "sequence_len": 1024,
+                "load_in_8bit": False,
+                "val_set_size": 0.1,
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "gradient_checkpointing": False,
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()

From 16998fd5412347031d6dc392ac9f51c5bcee6385 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 4 Dec 2023 15:38:03 -0800
Subject: [PATCH 03/15] use fork for mamba kwargs fix

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index a531b1fa4a..93bad75ab7 100644
--- a/setup.py
+++ b/setup.py
@@ -52,7 +52,7 @@ def parse_requirements():
             "deepspeed",
         ],
         "mamba": [
-            "git+https://github.com/state-spaces/mamba.git",
+            "git+https://github.com/OpenAccess-AI-Collective/mamba.git@model-kwargs"
         ],
     },
 )

From 4b642923fd688e73717f064a9bd539d316d19ba4 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 4 Dec 2023 15:38:55 -0800
Subject: [PATCH 04/15] grad checkpointing doesn't work

---
 examples/mamba/config.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 5b8ff2b928..1dd9b03d0b 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -37,7 +37,7 @@ bf16: true
 fp16: false
 tf32: false
 
-gradient_checkpointing: true
+gradient_checkpointing: false
 early_stopping_patience:
 resume_from_checkpoint:
 local_rank:

From d90325ca0e488138597a413d54c8777d047871af Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 4 Dec 2023 15:48:17 -0800
Subject: [PATCH 05/15] fix extras for mamba

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 93bad75ab7..5d32c8b283 100644
--- a/setup.py
+++ b/setup.py
@@ -52,7 +52,7 @@ def parse_requirements():
             "deepspeed",
         ],
         "mamba": [
-            "git+https://github.com/OpenAccess-AI-Collective/mamba.git@model-kwargs"
+            "mamba-ssm @ git+https://github.com/OpenAccess-AI-Collective/mamba.git@model-kwargs",
         ],
     },
 )
From 28617fe9153fcb23707beeb1548c22b49b566309 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 5 Dec 2023 17:13:43 -0800
Subject: [PATCH 06/15] mamba loss fix

---
 examples/mamba/config.yml            |  2 +-
 src/axolotl/models/mamba/__init__.py | 50 ++++++++++++++++++++++++++++
 src/axolotl/utils/models.py          |  3 +-
 3 files changed, 53 insertions(+), 2 deletions(-)
 create mode 100644 src/axolotl/models/mamba/__init__.py

diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 1dd9b03d0b..918373a4f6 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -29,7 +29,7 @@ micro_batch_size: 2
 num_epochs: 4
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
-learning_rate: 0.000005
+learning_rate: 3e-7
 
 train_on_inputs: false
 group_by_length: false
diff --git a/src/axolotl/models/mamba/__init__.py b/src/axolotl/models/mamba/__init__.py
new file mode 100644
index 0000000000..e97a1375e2
--- /dev/null
+++ b/src/axolotl/models/mamba/__init__.py
@@ -0,0 +1,50 @@
+# pylint: skip-file
+
+from collections import namedtuple
+
+from torch.nn import CrossEntropyLoss
+
+
+def fix_mamba_attn_for_loss():
+    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+    MambaLMHeadModel.forward = mamba_forward
+    return MambaLMHeadModel  # pylint: disable=invalid-name
+
+
+def mamba_forward(
+    self,
+    input_ids,
+    position_ids=None,
+    inference_params=None,
+    num_last_tokens=0,
+    labels=None,
+):
+    """
+    "position_ids" is just to be compatible with Transformer generation. We don't use it.
+    num_last_tokens: if > 0, only return the logits for the last n tokens
+    """
+    hidden_states = self.backbone(input_ids, inference_params=inference_params)
+    if num_last_tokens > 0:
+        hidden_states = hidden_states[:, -num_last_tokens:]
+    lm_logits = self.lm_head(hidden_states)
+
+    loss = None
+    if labels is not None:
+        logits = lm_logits
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+        print(loss, shift_logits, shift_logits.dtype, shift_labels, shift_labels.dtype)
+        return (loss,)
+
+    else:
+        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+        return CausalLMOutput(logits=lm_logits)
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 4b20a11a4a..e3752c2eb6 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -22,6 +22,7 @@
     PreTrainedTokenizerBase,
 )
 
+from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.dict import DictDefault
@@ -342,7 +343,7 @@ def load_model(
         )
     elif model_type == "MambaLMHeadModel":
         # FIXME this is janky at best and hacked together to make it work
-        from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+        MambaLMHeadModel = fix_mamba_attn_for_loss()
 
         model_kwargs["dtype"] = model_kwargs["torch_dtype"]
         model_kwargs["device"] = torch.cuda.current_device()
         del model_kwargs["torch_dtype"]
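
Note: the loss that mamba_forward adds above is the standard causal-LM objective — logits at position t are scored against the token at position t+1. A minimal stand-alone sketch of that shifted cross-entropy (tensor shapes and names here are illustrative, not part of the patch):

    import torch
    from torch.nn import CrossEntropyLoss

    logits = torch.randn(2, 8, 50280)          # (batch, seq_len, vocab_size)
    labels = torch.randint(0, 50280, (2, 8))   # (batch, seq_len)

    # Drop the last logit and the first label so position t predicts token t+1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss = CrossEntropyLoss()(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )
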
From b7f34d6c6e3a547e323bb2e44a93d3d085191bfd Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 6 Dec 2023 01:30:21 -0800
Subject: [PATCH 07/15] use fp32 and remove verbose logging

---
 examples/mamba/config.yml            | 5 +++--
 src/axolotl/models/mamba/__init__.py | 1 -
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 918373a4f6..8ed17c5bb2 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -33,9 +33,10 @@ learning_rate: 3e-7
 
 train_on_inputs: false
 group_by_length: false
-bf16: true
+
+bf16: false
 fp16: false
-tf32: false
+tf32: true
 
 gradient_checkpointing: false
 early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
diff --git a/src/axolotl/models/mamba/__init__.py b/src/axolotl/models/mamba/__init__.py
index e97a1375e2..7f395ff4c6 100644
--- a/src/axolotl/models/mamba/__init__.py
+++ b/src/axolotl/models/mamba/__init__.py
@@ -42,7 +42,6 @@ def mamba_forward(
         # Enable model parallelism
         shift_labels = shift_labels.to(shift_logits.device)
         loss = loss_fct(shift_logits, shift_labels)
-        print(loss, shift_logits, shift_logits.dtype, shift_labels, shift_labels.dtype)
         return (loss,)
 
     else:
         CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])

From 4daa0bd7fb9d210856a6b73691648f9c2b8b7191 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 7 Dec 2023 12:04:19 -0800
Subject: [PATCH 08/15] mamba fixes

---
 examples/mamba/config.yml                  |  12 +-
 src/axolotl/core/trainer_builder.py        |  51 ++++++--
 src/axolotl/models/mamba/__init__.py       |  51 ++------
 .../models/mamba/configuration_mamba.py    |  42 +++++++
 src/axolotl/models/mamba/modeling_mamba.py | 116 ++++++++++++++++++
 src/axolotl/utils/collators.py             |  34 ++++-
 src/axolotl/utils/models.py                |  22 ++--
 src/axolotl/utils/trainer.py               |   5 -
 8 files changed, 257 insertions(+), 76 deletions(-)
 create mode 100644 src/axolotl/models/mamba/configuration_mamba.py
 create mode 100644 src/axolotl/models/mamba/modeling_mamba.py

diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 8ed17c5bb2..67595a7501 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -1,4 +1,4 @@
-base_model: state-spaces/mamba-130m
+base_model: state-spaces/mamba-2.8b
 model_type: MambaLMHeadModel
 tokenizer_type: AutoTokenizer
 tokenizer_config: EleutherAI/gpt-neox-20b
@@ -25,11 +25,11 @@ wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 4
-micro_batch_size: 2
+micro_batch_size: 1
 num_epochs: 4
-optimizer: adamw_bnb_8bit
+optimizer: paged_adamw_8bit
 lr_scheduler: cosine
-learning_rate: 3e-7
+learning_rate: 5e-5
 
 train_on_inputs: false
 group_by_length: false
@@ -57,6 +57,4 @@ weight_decay: 0.0
 fsdp:
 fsdp_config:
 special_tokens:
-  bos_token: ""
-  eos_token: ""
-  unk_token: ""
+  pad_token: "<|endoftext|>"
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index ce54824c69..de27b4283d 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -31,7 +31,10 @@
     bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
-from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.collators import (
+    BatchSamplerDataCollatorForSeq2Seq,
+    MambaDataCollator,
+)
 from axolotl.utils.samplers import MultipackBatchSampler
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
 
@@ -49,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     Extend the base TrainingArguments for axolotl helpers
     """
 
+    model_type: Optional[str] = field(
+        default=None, metadata={"help": "HF model configuration model_type."}
+    )
     lr_quadratic_warmup: bool = field(
         default=False,
         metadata={"help": "Use quadratic warmup for cosine scheduling."},
     )
@@ -282,7 +288,29 @@ def compute_loss(self, model, inputs, return_outputs=False):
         # outputs = model(**inputs)
         # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
         # return (loss, outputs) if return_outputs else loss
-        loss = super().compute_loss(model, inputs, return_outputs=return_outputs)
-        if loss.numel() > 1:
-            loss = loss.mean()
-        return loss
+        if self.args.model_type == "mamba":
+            return self.compute_mamba_loss(model, inputs, return_outputs=return_outputs)
+        return super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+    def compute_mamba_loss(
+        self,
+        model,
+        inputs,
+        return_outputs=False,  # pylint: disable=unused-argument
+    ):
+        input_ids = inputs.pop("input_ids")
+        lm_logits = model(input_ids).logits
+
+        labels = input_ids.to(lm_logits.device)
+        shift_logits = lm_logits[:, :-1, :].contiguous()
+        labels = labels[:, 1:].contiguous()
+
+        loss_fct = torch.nn.CrossEntropyLoss()
+        lm_loss = loss_fct(
+            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+        )
+
+        return lm_loss
 
 
 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
@@ -734,11 +759,7 @@ def build(self, total_num_steps):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=BatchSamplerDataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **data_collator_kwargs,
-            ),
+            data_collator=self.build_collator(**data_collator_kwargs),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -758,3 +779,13 @@ def build(self, total_num_steps):
             ] = self.cfg.micro_batch_size
 
         return trainer
+
+    def build_collator(self, **kwargs):
+        if self.cfg.model_config_type == "mamba":
+            return MambaDataCollator(tokenizer=self.tokenizer)
+
+        return BatchSamplerDataCollatorForSeq2Seq(
+            self.tokenizer,
+            return_tensors="pt",
+            **kwargs,
+        )
diff --git a/src/axolotl/models/mamba/__init__.py b/src/axolotl/models/mamba/__init__.py
index 7f395ff4c6..247c1d184b 100644
--- a/src/axolotl/models/mamba/__init__.py
+++ b/src/axolotl/models/mamba/__init__.py
@@ -1,49 +1,12 @@
-# pylint: skip-file
-
-from collections import namedtuple
-
-from torch.nn import CrossEntropyLoss
+"""
+Modeling module for Mamba models
+"""
 
 
 def fix_mamba_attn_for_loss():
-    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
-
-    MambaLMHeadModel.forward = mamba_forward
-    return MambaLMHeadModel  # pylint: disable=invalid-name
-
-
-def mamba_forward(
-    self,
-    input_ids,
-    position_ids=None,
-    inference_params=None,
-    num_last_tokens=0,
-    labels=None,
-):
-    """
-    "position_ids" is just to be compatible with Transformer generation. We don't use it.
-    num_last_tokens: if > 0, only return the logits for the last n tokens
-    """
-    hidden_states = self.backbone(input_ids, inference_params=inference_params)
-    if num_last_tokens > 0:
-        hidden_states = hidden_states[:, -num_last_tokens:]
-    lm_logits = self.lm_head(hidden_states)
+    from mamba_ssm.models import mixer_seq_simple
 
-    loss = None
-    if labels is not None:
-        logits = lm_logits
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLoss()
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
-        return (loss,)
+    from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
 
-    else:
-        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
-        return CausalLMOutput(logits=lm_logits)
+    mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
+    return mixer_seq_simple.MambaLMHeadModel  # pylint: disable=invalid-name
diff --git a/src/axolotl/models/mamba/configuration_mamba.py b/src/axolotl/models/mamba/configuration_mamba.py
new file mode 100644
index 0000000000..5160ee8d7e
--- /dev/null
+++ b/src/axolotl/models/mamba/configuration_mamba.py
@@ -0,0 +1,42 @@
+"""
+HF Transformers MambaConfig
+"""
+from transformers import PretrainedConfig
+
+
+class MambaConfig(PretrainedConfig):
+    """
+    modeling configuration for state space model/mamba
+    """
+
+    model_type = "mamba"
+
+    def __init__(
+        self,
+        vocab_size=50280,
+        d_model=2560,
+        n_layer=64,
+        rms_norm=True,
+        residual_in_fp32=True,
+        fused_add_norm=True,
+        pad_vocab_size_multiple=8,
+        pad_token_id=50277,
+        bos_token_id=0,
+        eos_token_id=0,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.n_layer = n_layer
+        self.rms_norm = rms_norm
+        self.residual_in_fp32 = residual_in_fp32
+        self.fused_add_norm = fused_add_norm
+        self.pad_vocab_size_multiple = pad_vocab_size_multiple
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
diff --git a/src/axolotl/models/mamba/modeling_mamba.py b/src/axolotl/models/mamba/modeling_mamba.py
new file mode 100644
index 0000000000..2c2fc24787
--- /dev/null
+++ b/src/axolotl/models/mamba/modeling_mamba.py
@@ -0,0 +1,116 @@
+# pylint: skip-file
+
+from collections import namedtuple
+from functools import partial
+
+from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
+from mamba_ssm.utils.generation import GenerationMixin
+from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from axolotl.models.mamba.configuration_mamba import MambaConfig
+
+
+class MambaLMHeadModel(nn.Module, GenerationMixin):
+    def __init__(
+        self,
+        d_model: int,
+        n_layer: int,
+        vocab_size: int,
+        initializer_cfg=None,
+        pad_vocab_size_multiple: int = 1,
+        device=None,
+        dtype=None,
+        **backbone_kwargs,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        if vocab_size % pad_vocab_size_multiple != 0:
+            vocab_size += pad_vocab_size_multiple - (
+                vocab_size % pad_vocab_size_multiple
+            )
+        self.config = MambaConfig(
+            vocab_size=vocab_size,
+            d_model=d_model,
+            n_layer=n_layer,
+            pad_vocab_size_multiple=pad_vocab_size_multiple,
+        )
+        self.backbone = MixerModel(
+            d_model=d_model,
+            n_layer=n_layer,
+            vocab_size=vocab_size,
+            initializer_cfg=initializer_cfg,
+            **backbone_kwargs,
+            **factory_kwargs,
+        )
+        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
+
+        # Initialize weights and apply final processing
+        self.apply(
+            partial(
+                _init_weights,
+                n_layer=n_layer,
+                **(initializer_cfg if initializer_cfg is not None else {}),
+            )
+        )
+        self.tie_weights()
+
+    def tie_weights(self):
+        self.lm_head.weight = self.backbone.embedding.weight
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.backbone.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
+
+    def forward(
+        self,
+        input_ids,
+        position_ids=None,
+        inference_params=None,
+        num_last_tokens=0,
+        labels=None,
+        **kwargs,
+    ):
+        """
+        "position_ids" is just to be compatible with Transformer generation. We don't use it.
+        num_last_tokens: if > 0, only return the logits for the last n tokens
+        """
+        hidden_states = self.backbone(input_ids, inference_params=inference_params)
+        if num_last_tokens > 0:
+            hidden_states = hidden_states[:, -num_last_tokens:]
+        lm_logits = self.lm_head(hidden_states)
+
+        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+        return CausalLMOutput(logits=lm_logits)
+
+        loss = None
+        if labels is not None:
+            logits = lm_logits
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+            CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
+            print(loss)
+            return CausalLMOutput(logits=lm_logits, loss=loss)
+
+        else:
+            CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+            return CausalLMOutput(logits=lm_logits)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
+        config = load_config_hf(pretrained_model_name)
+        model = cls(**config, device=device, dtype=dtype, **kwargs)
+        model.load_state_dict(
+            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
+        )
+        return model
diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py
index ffae3f2631..303a0af0f3 100644
--- a/src/axolotl/utils/collators.py
+++ b/src/axolotl/utils/collators.py
@@ -2,12 +2,16 @@
 DataCollator for axolotl to pad labels and position_ids for packed sequences
 """
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Sequence, Union
 
 import numpy as np
+import torch
+import transformers
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy
 
+IGNORE_INDEX = -100
+
 
 @dataclass
 class DataCollatorForSeq2Seq:
@@ -146,3 +150,31 @@ def __call__(self, features, return_tensors=None):
             chunked_data[feature] = np.concatenate(arrays)
         features = [chunked_data]
         return super().__call__(features, return_tensors=return_tensors)
+
+
+@dataclass
+class MambaDataCollator:
+    """
+    Collator for State Space Models (Mamba)
+    """
+
+    tokenizer: transformers.PreTrainedTokenizer
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids, labels = tuple(
+            [torch.LongTensor(instance[key]) for instance in instances]
+            for key in ("input_ids", "labels")
+        )
+        input_ids = torch.nn.utils.rnn.pad_sequence(
+            torch.Tensor(input_ids),
+            batch_first=True,
+            padding_value=self.tokenizer.pad_token_id,
+        )
+        labels = torch.nn.utils.rnn.pad_sequence(
+            torch.Tensor(labels), batch_first=True, padding_value=IGNORE_INDEX
+        )
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+        }
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index e3752c2eb6..27a5885070 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -54,15 +54,19 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
 def load_model_config(cfg):
     model_config_name = cfg.base_model_config or cfg.base_model
     trust_remote_code = cfg.trust_remote_code is True
-    if "state-spaces/mamba" in model_config_name:
-        return addict.Dict(
-            {
-                "model_type": "mamba",
-            }
+    try:
+        model_config = AutoConfig.from_pretrained(
+            model_config_name, trust_remote_code=trust_remote_code
         )
-    model_config = AutoConfig.from_pretrained(
-        model_config_name, trust_remote_code=trust_remote_code
-    )
+    except ValueError as err:
+        if "mamba" in model_config_name:
+            return addict.Dict(
+                {
+                    "model_type": "mamba",
+                }
+            )
+        raise err
+
     if cfg.model_config:
         for key, val in cfg.model_config.items():
             setattr(model_config, key, val)
@@ -343,7 +347,7 @@ def load_model(
         )
     elif model_type == "MambaLMHeadModel":
         # FIXME this is janky at best and hacked together to make it work
-        MambaLMHeadModel = fix_mamba_attn_for_loss()
+        MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
 
         model_kwargs["dtype"] = model_kwargs["torch_dtype"]
         model_kwargs["device"] = torch.cuda.current_device()
         del model_kwargs["torch_dtype"]
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 22a3ab8c5a..590861cc0b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -140,11 +140,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
         if eval_dataset:
             eval_dataset = eval_dataset.remove_columns("attention_mask")
 
-    if cfg.model_config_type == "mamba":
-        train_dataset = train_dataset.remove_columns("labels")
-        if eval_dataset:
-            eval_dataset = eval_dataset.remove_columns("labels")
-
     return train_dataset, eval_dataset

From 54a207ad73dfcfeccaa192df926d6b40f798c5fc Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 7 Dec 2023 12:12:52 -0800
Subject: [PATCH 09/15] fix collator for mamba

---
 setup.py                       | 4 ++--
 src/axolotl/utils/collators.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/setup.py b/setup.py
index 5d32c8b283..0e1b3d6b91 100644
--- a/setup.py
+++ b/setup.py
@@ -51,8 +51,8 @@ def parse_requirements():
         "deepspeed": [
             "deepspeed",
         ],
-        "mamba": [
-            "mamba-ssm @ git+https://github.com/OpenAccess-AI-Collective/mamba.git@model-kwargs",
+        "mamba-ssm": [
+            "mamba-ssm==1.0.1",
         ],
     },
 )
diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py
index 303a0af0f3..0f0eb5a95a 100644
--- a/src/axolotl/utils/collators.py
+++ b/src/axolotl/utils/collators.py
@@ -166,12 +166,12 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
             for key in ("input_ids", "labels")
         )
         input_ids = torch.nn.utils.rnn.pad_sequence(
-            torch.Tensor(input_ids),
+            input_ids,
             batch_first=True,
             padding_value=self.tokenizer.pad_token_id,
         )
         labels = torch.nn.utils.rnn.pad_sequence(
-            torch.Tensor(labels), batch_first=True, padding_value=IGNORE_INDEX
+            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
 
         return {
             "input_ids": input_ids,
             "labels": labels,
         }
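
Note: the collator fix above matters because torch.nn.utils.rnn.pad_sequence expects a list of tensors, not a tensor built from a ragged list. A minimal stand-alone sketch of the padding behavior MambaDataCollator ends up with — input_ids padded with the tokenizer's pad token, labels padded with -100 so padded positions are ignored by the loss (function and variable names here are illustrative, not part of the patch):

    import torch

    IGNORE_INDEX = -100  # CrossEntropyLoss skips positions carrying this label

    def pad_batch(instances, pad_token_id):
        # instances: list of dicts, each with "input_ids" and "labels" token lists
        input_ids = [torch.LongTensor(x["input_ids"]) for x in instances]
        labels = [torch.LongTensor(x["labels"]) for x in instances]
        return {
            "input_ids": torch.nn.utils.rnn.pad_sequence(
                input_ids, batch_first=True, padding_value=pad_token_id
            ),
            "labels": torch.nn.utils.rnn.pad_sequence(
                labels, batch_first=True, padding_value=IGNORE_INDEX
            ),
        }
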
From 563251a9e66f647ad3d7ec7a7f0e1b76a34d2d87 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 7 Dec 2023 12:16:12 -0800
Subject: [PATCH 10/15] set model_type on training_args

---
 src/axolotl/core/trainer_builder.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index de27b4283d..063a78dfd4 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -705,6 +705,7 @@ def build(self, total_num_steps):
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
+        training_arguments_kwargs["model_type"] = self.cfg.model_config_type
 
         training_args = (
             AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
                 **training_arguments_kwargs,
             )

From cdc7b6914e3201d87e204f770be8686004506b99 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 7 Dec 2023 12:39:36 -0800
Subject: [PATCH 11/15] don't save safetensors for mamba

---
 examples/mamba/config.yml           |  5 +++--
 src/axolotl/core/trainer_builder.py | 14 ++++++++++----
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 67595a7501..3b9b15b3a5 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -11,8 +11,9 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.0
 output_dir: ./out
+save_safetensors: false
 
 sequence_len: 2048
 sample_packing: false
 pad_to_sequence_len: false
@@ -47,7 +48,7 @@ xformers_attention:
 flash_attention:
 
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps:
 eval_table_size:
 eval_table_max_new_tokens: 128
 save_steps:
 debug:
 deepspeed:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 063a78dfd4..1b037420c5 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -288,11 +288,15 @@ def compute_loss(self, model, inputs, return_outputs=False):
         # outputs = model(**inputs)
         # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
         # return (loss, outputs) if return_outputs else loss
-        if self.args.model_type == "mamba":
-            return self.compute_mamba_loss(model, inputs, return_outputs=return_outputs)
         return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
-    def compute_mamba_loss(
+
+class AxolotlMambaTrainer(AxolotlTrainer):
+    """
+    Mamba specific trainer to handle loss calculation
+    """
+
+    def compute_loss(
         self,
         model,
         inputs,
         return_outputs=False,  # pylint: disable=unused-argument
     ):
@@ -490,6 +494,8 @@ def _get_trainer_cls(self):
             return OneCycleLRSchedulerTrainer
         if self.cfg.relora_steps:
             return ReLoRATrainer
+        if self.cfg.model_config_type == "mamba":
+            return AxolotlMambaTrainer
         return AxolotlTrainer
 
     def build(self, total_num_steps):
@@ -557,7 +563,7 @@ def build(self, total_num_steps):
         if self.cfg.hub_strategy:
             training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
 
-        if self.cfg.save_safetensors:
+        if self.cfg.save_safetensors is not None:
             training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
 
         if self.cfg.sample_packing_eff_est:

From 046cd3b300b1c877dbf4e5be4078d8870be1e815 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 9 Dec 2023 09:14:31 -0500
Subject: [PATCH 12/15] update mamba config to disable safetensor checkpoints, install for tests

---
 .github/workflows/tests.yml |  2 +-
 examples/mamba/config.yml   | 12 ++++++------
 tests/e2e/test_mamba.py     |  1 +
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 9103126ce1..ad2cb428b0 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -73,7 +73,7 @@ jobs:
         run: |
           pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
           pip3 uninstall -y transformers accelerate
-          pip3 install -U -e .[flash-attn]
+          pip3 install -U -e .[flash-attn,mamba-ssm]
           pip3 install -r requirements-tests.txt
 
       - name: Run e2e tests
diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 3b9b15b3a5..c2e5a851f8 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -13,7 +13,6 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./out
-save_safetensors: false
 
 sequence_len: 2048
 sample_packing: false
@@ -27,15 +26,15 @@ wandb_log_model:
 
 gradient_accumulation_steps: 4
 micro_batch_size: 1
-num_epochs: 4
+num_epochs: 2
 optimizer: paged_adamw_8bit
 lr_scheduler: cosine
 learning_rate: 5e-5
 
 train_on_inputs: false
-group_by_length: false
+group_by_length: true
 
-bf16: false
+bf16: true
 fp16: false
 tf32: true
 
@@ -51,11 +50,12 @@ warmup_steps: 10
 eval_steps:
 eval_table_size:
 eval_table_max_new_tokens: 128
-save_steps:
+save_steps: 0.25
 debug:
 deepspeed:
 weight_decay: 0.0
 fsdp:
 fsdp_config:
 special_tokens:
-  pad_token: "<|endoftext|>"
+tokens:
+save_safetensors: False
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
index 9f688339d7..f65530e877 100644
--- a/tests/e2e/test_mamba.py
+++ b/tests/e2e/test_mamba.py
@@ -54,6 +54,7 @@ def test_fft(self, temp_dir):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "save_safetensors": False,
             }
         )
         normalize_config(cfg)

From 937d8e825408918ec010be776591d3af503cc322 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 9 Dec 2023 11:12:03 -0500
Subject: [PATCH 13/15] no evals for mamba tests

---
 tests/e2e/test_mamba.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
index f65530e877..463b0ddac0 100644
--- a/tests/e2e/test_mamba.py
+++ b/tests/e2e/test_mamba.py
@@ -36,7 +36,7 @@ def test_fft(self, temp_dir):
                 "flash_attention": False,
                 "sequence_len": 1024,
                 "load_in_8bit": False,
-                "val_set_size": 0.1,
+                "val_set_size": 0.0,
                 "datasets": [
                     {
                         "path": "mhenrichsen/alpaca_2k_test",
                         "type": "alpaca",
                     },
@@ -53,7 +53,7 @@ def test_fft(self, temp_dir):
                 "lr_scheduler": "cosine",
                 "max_steps": 20,
                 "save_steps": 10,
-                "eval_steps": 10,
+                "eval_steps": None,
                 "save_safetensors": False,
             }
         )

From 2df422eda465ec3a0c8e11425d709b23015fc9e6 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 9 Dec 2023 11:37:08 -0500
Subject: [PATCH 14/15] handle save_pretrained

---
 src/axolotl/models/mamba/modeling_mamba.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/models/mamba/modeling_mamba.py b/src/axolotl/models/mamba/modeling_mamba.py
index 2c2fc24787..5e3c2219f0 100644
--- a/src/axolotl/models/mamba/modeling_mamba.py
+++ b/src/axolotl/models/mamba/modeling_mamba.py
@@ -1,8 +1,10 @@
 # pylint: skip-file
-
+import os
 from collections import namedtuple
 from functools import partial
+from typing import Optional, Union
 
+import torch
 from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
 from mamba_ssm.utils.generation import GenerationMixin
 from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
 from axolotl.models.mamba.configuration_mamba import MambaConfig
@@ -106,6 +108,15 @@ def forward(
             CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
             return CausalLMOutput(logits=lm_logits)
 
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        state_dict: Optional[dict] = None,
+    ):
+        if state_dict is None:
+            state_dict = self.state_dict()
+        torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
         config = load_config_hf(pretrained_model_name)
         model = cls(**config, device=device, dtype=dtype, **kwargs)
         model.load_state_dict(
             load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
         )
         return model

From 52d46b47f2b8c8f9ddefb3038ba869161b01e94b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 9 Dec 2023 11:59:11 -0500
Subject: [PATCH 15/15] handle unused safetensors arg

---
 src/axolotl/models/mamba/modeling_mamba.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/axolotl/models/mamba/modeling_mamba.py b/src/axolotl/models/mamba/modeling_mamba.py
index 5e3c2219f0..70e9c88c88 100644
--- a/src/axolotl/models/mamba/modeling_mamba.py
+++ b/src/axolotl/models/mamba/modeling_mamba.py
@@ -112,6 +112,7 @@ def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
         state_dict: Optional[dict] = None,
+        safe_serialization: Optional[bool] = None,  # pylint: disable=unused-argument
     ):
         if state_dict is None:
             state_dict = self.state_dict()
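
Note: because save_pretrained above writes a plain torch checkpoint (the safetensors argument is accepted but ignored), the finetuned weights land in pytorch_model.bin, which is also what the e2e test asserts. A minimal reload sketch under that assumption (the paths and the already-constructed model object are illustrative):

    import os
    import torch

    output_dir = "./out"  # matches output_dir in examples/mamba/config.yml
    state_dict = torch.load(
        os.path.join(output_dir, "pytorch_model.bin"), map_location="cpu"
    )
    # model: a MambaLMHeadModel built with a matching MambaConfig
    model.load_state_dict(state_dict)
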