mamba fixes
winglian committed Dec 7, 2023
1 parent b7f34d6 commit 4daa0bd
Showing 8 changed files with 257 additions and 76 deletions.
12 changes: 5 additions & 7 deletions examples/mamba/config.yml
@@ -1,4 +1,4 @@
-base_model: state-spaces/mamba-130m
+base_model: state-spaces/mamba-2.8b
model_type: MambaLMHeadModel
tokenizer_type: AutoTokenizer
tokenizer_config: EleutherAI/gpt-neox-20b
@@ -25,11 +25,11 @@ wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
-micro_batch_size: 2
+micro_batch_size: 1
num_epochs: 4
-optimizer: adamw_bnb_8bit
+optimizer: paged_adamw_8bit
lr_scheduler: cosine
-learning_rate: 3e-7
+learning_rate: 5e-5

train_on_inputs: false
group_by_length: false
@@ -57,6 +57,4 @@ weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
-  bos_token: "<s>"
-  eos_token: "</s>"
-  unk_token: "<unk>"
+  pad_token: "<|endoftext|>"
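Besides pointing at the larger base model and the higher learning rate, the notable change is the special_tokens block: the Llama-style BOS/EOS/UNK entries are dropped in favor of a single pad token that the GPT-NeoX tokenizer already defines. A quick sanity check, not part of the commit, sketching why no embedding resize is needed:

from transformers import AutoTokenizer

# Hypothetical check: "<|endoftext|>" already exists in the EleutherAI/gpt-neox-20b
# vocabulary, so registering it as pad_token does not grow the vocab and the
# model's embedding table can stay untouched.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
vocab_before = len(tokenizer)
num_added = tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
assert num_added == 0 and len(tokenizer) == vocab_before
print(tokenizer.pad_token, tokenizer.pad_token_id)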
51 changes: 41 additions & 10 deletions src/axolotl/core/trainer_builder.py
@@ -31,7 +31,10 @@
bench_eval_callback_factory,
log_prediction_callback_factory,
)
-from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.collators import (
+    BatchSamplerDataCollatorForSeq2Seq,
+    MambaDataCollator,
+)
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

@@ -49,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
Extend the base TrainingArguments for axolotl helpers
"""

model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
@@ -282,10 +288,29 @@ def compute_loss(self, model, inputs, return_outputs=False):
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
-        loss = super().compute_loss(model, inputs, return_outputs=return_outputs)
-        if loss.numel() > 1:
-            loss = loss.mean()
-        return loss
+        if self.args.model_type == "mamba":
+            return self.compute_mamba_loss(model, inputs, return_outputs=return_outputs)
+        return super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+    def compute_mamba_loss(
+        self,
+        model,
+        inputs,
+        return_outputs=False,  # pylint: disable=unused-argument
+    ):
+        input_ids = inputs.pop("input_ids")
+        lm_logits = model(input_ids).logits
+
+        labels = input_ids.to(lm_logits.device)
+        shift_logits = lm_logits[:, :-1, :].contiguous()
+        labels = labels[:, 1:].contiguous()
+
+        loss_fct = torch.nn.CrossEntropyLoss()
+        lm_loss = loss_fct(
+            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+        )
+
+        return lm_loss


class OneCycleLRSchedulerTrainer(AxolotlTrainer):
@@ -734,11 +759,7 @@ def build(self, total_num_steps):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
-            data_collator=BatchSamplerDataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **data_collator_kwargs,
-            ),
+            data_collator=self.build_collator(**data_collator_kwargs),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
@@ -758,3 +779,13 @@ def build(self, total_num_steps):
] = self.cfg.micro_batch_size

return trainer

def build_collator(self, **kwargs):
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)

return BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**kwargs,
)
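The new Mamba path in compute_loss sidesteps the usual Hugging Face loss machinery: compute_mamba_loss scores the logits at positions 0..n-2 against the input tokens at positions 1..n-1, which is the standard causal-LM shift, while build_collator swaps in MambaDataCollator whenever model_config_type is "mamba". A standalone sketch of that shift with dummy tensors (the sizes are illustrative only):

import torch

batch_size, seq_len, vocab_size = 2, 8, 50280  # illustrative shapes
lm_logits = torch.randn(batch_size, seq_len, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# Same shift as compute_mamba_loss: the logit at position t predicts token t + 1.
shift_logits = lm_logits[:, :-1, :].contiguous()
shift_labels = input_ids[:, 1:].contiguous()
loss = torch.nn.functional.cross_entropy(
    shift_logits.view(-1, vocab_size), shift_labels.view(-1)
)
print(loss)  # scalar; roughly log(50280) ≈ 10.8 for random logits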
51 changes: 7 additions & 44 deletions src/axolotl/models/mamba/__init__.py
@@ -1,49 +1,12 @@
-# pylint: skip-file
-
-from collections import namedtuple
-
-from torch.nn import CrossEntropyLoss
+"""
+Modeling module for Mamba models
+"""


def fix_mamba_attn_for_loss():
-    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
-
-    MambaLMHeadModel.forward = mamba_forward
-    return MambaLMHeadModel  # pylint: disable=invalid-name
-
-
-def mamba_forward(
-    self,
-    input_ids,
-    position_ids=None,
-    inference_params=None,
-    num_last_tokens=0,
-    labels=None,
-):
-    """
-    "position_ids" is just to be compatible with Transformer generation. We don't use it.
-    num_last_tokens: if > 0, only return the logits for the last n tokens
-    """
-    hidden_states = self.backbone(input_ids, inference_params=inference_params)
-    if num_last_tokens > 0:
-        hidden_states = hidden_states[:, -num_last_tokens:]
-    lm_logits = self.lm_head(hidden_states)
-
-    loss = None
-    if labels is not None:
-        logits = lm_logits
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLoss()
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
-        return (loss,)
-
-    else:
-        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
-        return CausalLMOutput(logits=lm_logits)
+    from mamba_ssm.models import mixer_seq_simple
+
+    from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
+
+    mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
+    return mixer_seq_simple.MambaLMHeadModel  # pylint: disable=invalid-name
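The module no longer monkeypatches forward on the upstream class; it swaps the whole class for the reimplementation in modeling_mamba.py. A minimal sketch of how a caller might apply the patch before loading weights (the checkpoint name, device, and dtype are illustrative, not taken from this commit):

import torch

from axolotl.models.mamba import fix_mamba_attn_for_loss

# Replace mamba_ssm's MambaLMHeadModel with the fixed class, then load weights
# through the class's own from_pretrained helper (defined in modeling_mamba.py below).
MambaLMHeadModel = fix_mamba_attn_for_loss()
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-2.8b", device="cuda", dtype=torch.bfloat16
)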
42 changes: 42 additions & 0 deletions src/axolotl/models/mamba/configuration_mamba.py
@@ -0,0 +1,42 @@
"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig


class MambaConfig(PretrainedConfig):
"""
modeling configuration for state space model/mamba
"""

model_type = "mamba"

def __init__(
self,
vocab_size=50280,
d_model=2560,
n_layer=64,
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=8,
pad_token_id=50277,
bos_token_id=0,
eos_token_id=0,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layer = n_layer
self.rms_norm = rms_norm
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.pad_vocab_size_multiple = pad_vocab_size_multiple
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
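The defaults mirror the 2.8b checkpoint used in the example config; note the config only records pad_vocab_size_multiple, the actual vocabulary padding happens in MambaLMHeadModel below. A short usage sketch (the smaller, 130m-style dimensions are assumptions, verify against the upstream config.json):

from axolotl.models.mamba.configuration_mamba import MambaConfig

# Assumed small-model dimensions; only d_model/n_layer/vocab_size differ
# from the 2.8b defaults above.
config = MambaConfig(d_model=768, n_layer=24, vocab_size=50277)
print(config.model_type)        # "mamba" - the value the trainer's compute_loss checks for
print(config.to_json_string())  # serializes like any other PretrainedConfig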
116 changes: 116 additions & 0 deletions src/axolotl/models/mamba/modeling_mamba.py
@@ -0,0 +1,116 @@
# pylint: skip-file

from collections import namedtuple
from functools import partial

from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from torch import nn
from torch.nn import CrossEntropyLoss

from axolotl.models.mamba.configuration_mamba import MambaConfig


class MambaLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
d_model: int,
n_layer: int,
vocab_size: int,
initializer_cfg=None,
pad_vocab_size_multiple: int = 1,
device=None,
dtype=None,
**backbone_kwargs,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (
vocab_size % pad_vocab_size_multiple
)
self.config = MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layer,
pad_vocab_size_multiple=pad_vocab_size_multiple,
)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
vocab_size=vocab_size,
initializer_cfg=initializer_cfg,
**backbone_kwargs,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

# Initialize weights and apply final processing
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()

def tie_weights(self):
self.lm_head.weight = self.backbone.embedding.weight

def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(
batch_size, max_seqlen, dtype=dtype, **kwargs
)

def forward(
self,
input_ids,
position_ids=None,
inference_params=None,
num_last_tokens=0,
labels=None,
**kwargs,
):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, inference_params=inference_params)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)

        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
            return CausalLMOutput(logits=lm_logits, loss=loss)

        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config = load_config_hf(pretrained_model_name)
model = cls(**config, device=device, dtype=dtype, **kwargs)
model.load_state_dict(
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
)
return model
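A rough smoke test of the class, assuming mamba_ssm and its CUDA kernels are installed and a GPU is available; the tiny dimensions are made up so it runs without downloading weights. With labels supplied, forward returns a loss alongside the logits:

import torch

from axolotl.models.mamba.modeling_mamba import MambaLMHeadModel

# Tiny, randomly initialized model; sizes are illustrative, not a real checkpoint.
model = MambaLMHeadModel(d_model=64, n_layer=2, vocab_size=96, device="cuda")
input_ids = torch.randint(0, 96, (2, 16), device="cuda")

out = model(input_ids)                         # CausalLMOutput(logits=...)
out_with_loss = model(input_ids, labels=input_ids)
print(out.logits.shape, out_with_loss.loss.item())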
34 changes: 33 additions & 1 deletion src/axolotl/utils/collators.py
@@ -2,12 +2,16 @@
DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Sequence, Union

import numpy as np
import torch
import transformers
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

IGNORE_INDEX = -100


@dataclass
class DataCollatorForSeq2Seq:
@@ -146,3 +150,31 @@ def __call__(self, features, return_tensors=None):
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)


@dataclass
class MambaDataCollator:
"""
Collator for State Space Models (Mamba)
"""

tokenizer: transformers.PreTrainedTokenizer

def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[torch.LongTensor(instance[key]) for instance in instances]
for key in ("input_ids", "labels")
)
        # pad_sequence expects a list of tensors; the LongTensors built above are
        # padded up to the longest sequence in the batch
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )

return {
"input_ids": input_ids,
"labels": labels,
}
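The collator only touches pad_token_id on the tokenizer, so a lightweight stand-in is enough to show its behavior on a ragged batch (the stand-in object and token ids below are illustrative, not from the commit):

from types import SimpleNamespace

from axolotl.utils.collators import MambaDataCollator

# Stand-in object: MambaDataCollator only reads tokenizer.pad_token_id.
collator = MambaDataCollator(tokenizer=SimpleNamespace(pad_token_id=0))
batch = collator(
    [
        {"input_ids": [5, 6, 7], "labels": [-100, 6, 7]},
        {"input_ids": [8, 9], "labels": [8, 9]},
    ]
)
print(batch["input_ids"])  # second row padded with pad_token_id (0)
print(batch["labels"])     # second row padded with IGNORE_INDEX (-100)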