Skip to content

Commit

Permalink
support for mamba (axolotl-ai-cloud#915)
Browse files Browse the repository at this point in the history
* support for mamba

* more mamba fixes

* use fork for mamba kwargs fix

* grad checkpointing doesn't work

* fix extras for mamaba

* mamba loss fix

* use fp32 and remove verbose logging

* mamba fixes

* fix collator for mamba

* set model_type on training_args

* don't save safetensors for mamba

* update mamba config to disable safetensor checkpooints, install for tests

* no evals for mamba tests

* handle save_pretrained

* handle unused safetensors arg
  • Loading branch information
winglian authored Dec 9, 2023
1 parent 465b77d commit 1c544c8
Show file tree
Hide file tree
Showing 12 changed files with 447 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ jobs:
run: |
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
pip3 uninstall -y transformers accelerate
pip3 install -U -e .[flash-attn]
pip3 install -U -e .[flash-attn,mamba-ssm]
pip3 install -r requirements-tests.txt
- name: Run e2e tests
Expand Down
61 changes: 61 additions & 0 deletions examples/mamba/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
base_model: state-spaces/mamba-2.8b
model_type: MambaLMHeadModel
tokenizer_type: AutoTokenizer
tokenizer_config: EleutherAI/gpt-neox-20b

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./out

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 5e-5

train_on_inputs: false
group_by_length: true

bf16: true
fp16: false
tf32: true

gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:

warmup_steps: 10
eval_steps:
eval_table_size:
eval_table_max_new_tokens: 128
save_steps: 0.25
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
tokens:
save_safetensors: False
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,8 @@ def parse_requirements():
"deepspeed": [
"deepspeed",
],
"mamba-ssm": [
"mamba-ssm==1.0.1",
],
},
)
55 changes: 48 additions & 7 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
MambaDataCollator,
)
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

Expand All @@ -49,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
Extend the base TrainingArguments for axolotl helpers
"""

model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
Expand Down Expand Up @@ -285,6 +291,32 @@ def compute_loss(self, model, inputs, return_outputs=False):
return super().compute_loss(model, inputs, return_outputs=return_outputs)


class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""

def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits

labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()

loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)

return lm_loss


class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
Expand Down Expand Up @@ -462,6 +494,8 @@ def _get_trainer_cls(self):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
return AxolotlTrainer

def build(self, total_num_steps):
Expand Down Expand Up @@ -529,7 +563,7 @@ def build(self, total_num_steps):
if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy

if self.cfg.save_safetensors:
if self.cfg.save_safetensors is not None:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors

if self.cfg.sample_packing_eff_est:
Expand Down Expand Up @@ -677,6 +711,7 @@ def build(self, total_num_steps):
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
Expand Down Expand Up @@ -731,11 +766,7 @@ def build(self, total_num_steps):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
data_collator=self.build_collator(**data_collator_kwargs),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
Expand All @@ -755,3 +786,13 @@ def build(self, total_num_steps):
] = self.cfg.micro_batch_size

return trainer

def build_collator(self, **kwargs):
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)

return BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**kwargs,
)
12 changes: 12 additions & 0 deletions src/axolotl/models/mamba/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
Modeling module for Mamba models
"""


def fix_mamba_attn_for_loss():
from mamba_ssm.models import mixer_seq_simple

from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed

mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name
42 changes: 42 additions & 0 deletions src/axolotl/models/mamba/configuration_mamba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig


class MambaConfig(PretrainedConfig):
"""
modeling configuration for state space model/mamba
"""

model_type = "mamba"

def __init__(
self,
vocab_size=50280,
d_model=2560,
n_layer=64,
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=8,
pad_token_id=50277,
bos_token_id=0,
eos_token_id=0,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layer = n_layer
self.rms_norm = rms_norm
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.pad_vocab_size_multiple = pad_vocab_size_multiple
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
128 changes: 128 additions & 0 deletions src/axolotl/models/mamba/modeling_mamba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# pylint: skip-file
import os
from collections import namedtuple
from functools import partial
from typing import Optional, Union

import torch
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from torch import nn
from torch.nn import CrossEntropyLoss

from axolotl.models.mamba.configuration_mamba import MambaConfig


class MambaLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
d_model: int,
n_layer: int,
vocab_size: int,
initializer_cfg=None,
pad_vocab_size_multiple: int = 1,
device=None,
dtype=None,
**backbone_kwargs,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (
vocab_size % pad_vocab_size_multiple
)
self.config = MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layer,
pad_vocab_size_multiple=pad_vocab_size_multiple,
)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
vocab_size=vocab_size,
initializer_cfg=initializer_cfg,
**backbone_kwargs,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

# Initialize weights and apply final processing
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()

def tie_weights(self):
self.lm_head.weight = self.backbone.embedding.weight

def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(
batch_size, max_seqlen, dtype=dtype, **kwargs
)

def forward(
self,
input_ids,
position_ids=None,
inference_params=None,
num_last_tokens=0,
labels=None,
**kwargs,
):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, inference_params=inference_params)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)

CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)

loss = None
if labels is not None:
logits = lm_logits
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
print(loss)
return CausalLMOutput(logits=lm_logits, loss=loss)

else:
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
):
if state_dict is None:
state_dict = self.state_dict()
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))

@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config = load_config_hf(pretrained_model_name)
model = cls(**config, device=device, dtype=dtype, **kwargs)
model.load_state_dict(
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
)
return model
6 changes: 4 additions & 2 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def train(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)

model.config.use_cache = False
if hasattr(model, "config"):
model.config.use_cache = False

# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
Expand All @@ -92,7 +93,8 @@ def train(
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
model.config.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
model.config.save_pretrained(str(Path(cfg.output_dir)))

# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
Expand Down
Loading

0 comments on commit 1c544c8

Please sign in to comment.