Commit message:

* support for mamba
* more mamba fixes
* use fork for mamba kwargs fix
* grad checkpointing doesn't work
* fix extras for mamba
* mamba loss fix
* use fp32 and remove verbose logging
* mamba fixes
* fix collator for mamba
* set model_type on training_args
* don't save safetensors for mamba
* update mamba config to disable safetensor checkpoints, install for tests
* no evals for mamba tests
* handle save_pretrained
* handle unused safetensors arg
Showing 12 changed files with 447 additions and 24 deletions.
@@ -0,0 +1,61 @@
base_model: state-spaces/mamba-2.8b
model_type: MambaLMHeadModel
tokenizer_type: AutoTokenizer
tokenizer_config: EleutherAI/gpt-neox-20b

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./out

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 5e-5

train_on_inputs: false
group_by_length: true

bf16: true
fp16: false
tf32: true

gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:

warmup_steps: 10
eval_steps:
eval_table_size:
eval_table_max_new_tokens: 128
save_steps: 0.25
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
tokens:
save_safetensors: False
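
This hunk appears to be the new example training config for Mamba (the file path is not shown in this capture). Two settings tie directly to the rest of the commit: `model_type: MambaLMHeadModel` selects the Mamba code path, and `save_safetensors: False` keeps checkpoints out of safetensors, likely because safetensors refuses to serialize the tied lm_head/embedding tensors; the custom `save_pretrained` further down writes a plain `torch.save` checkpoint instead. A minimal sanity-check sketch, assuming the file is saved as `examples/mamba/config.yml` (an assumed path) and PyYAML is installed:

# Hypothetical check -- the config path is an assumption, not shown in the diff.
import yaml

with open("examples/mamba/config.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Mamba checkpoints in this commit are written with torch.save rather than safetensors.
assert cfg["save_safetensors"] is False
assert cfg["model_type"] == "MambaLMHeadModel"
assert cfg["gradient_checkpointing"] is False  # noted as not working for Mamba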
@@ -51,5 +51,8 @@ def parse_requirements():
        "deepspeed": [
            "deepspeed",
        ],
        "mamba-ssm": [
            "mamba-ssm==1.0.1",
        ],
    },
)
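
This hunk extends the optional dependencies (presumably in setup.py, inferred from the `parse_requirements()` context line) with a pinned `mamba-ssm==1.0.1` extra, so the CUDA kernel package is only pulled in on demand, e.g. via `pip install -e ".[mamba-ssm]"` on a machine with a CUDA-enabled PyTorch build.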
@@ -0,0 +1,12 @@
"""
Modeling module for Mamba models
"""


def fix_mamba_attn_for_loss():
    from mamba_ssm.models import mixer_seq_simple

    from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed

    mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
    return mixer_seq_simple.MambaLMHeadModel  # pylint: disable=invalid-name
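
`fix_mamba_attn_for_loss()` monkey-patches mamba_ssm so that `mixer_seq_simple.MambaLMHeadModel` points at the loss-aware reimplementation below, and returns the patched class. A rough usage sketch, assuming the patch is applied before the model is instantiated (the exact call site in axolotl's model loader is not part of this excerpt):

# Sketch only: the loader wiring shown here is an assumption, not this commit's code.
import torch

from axolotl.models.mamba import fix_mamba_attn_for_loss

MambaLMHeadModel = fix_mamba_attn_for_loss()  # patched class with labels/loss support
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-2.8b",  # base_model from the example config above
    dtype=torch.bfloat16,
)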
@@ -0,0 +1,42 @@
"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig


class MambaConfig(PretrainedConfig):
    """
    modeling configuration for state space model/mamba
    """

    model_type = "mamba"

    def __init__(
        self,
        vocab_size=50280,
        d_model=2560,
        n_layer=64,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        pad_vocab_size_multiple=8,
        pad_token_id=50277,
        bos_token_id=0,
        eos_token_id=0,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layer = n_layer
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
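
MambaConfig wraps the mamba_ssm hyperparameters (d_model, n_layer, rms_norm, fused_add_norm, ...) in a transformers `PretrainedConfig` so the model plugs into the HF training stack; the token ids and `tie_word_embeddings` are forwarded to the base class. A small sketch of how it behaves (the values just restate the defaults above, which appear sized for the 2.8B checkpoint):

from axolotl.models.mamba.configuration_mamba import MambaConfig

config = MambaConfig()
print(config.model_type)               # -> "mamba"
print(config.d_model, config.n_layer)  # -> 2560 64
config.save_pretrained("./mamba-config")  # standard PretrainedConfig JSON serialization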
@@ -0,0 +1,128 @@
# pylint: skip-file
import os
from collections import namedtuple
from functools import partial
from typing import Optional, Union

import torch
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from torch import nn
from torch.nn import CrossEntropyLoss

from axolotl.models.mamba.configuration_mamba import MambaConfig


class MambaLMHeadModel(nn.Module, GenerationMixin):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        initializer_cfg=None,
        pad_vocab_size_multiple: int = 1,
        device=None,
        dtype=None,
        **backbone_kwargs,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.config = MambaConfig(
            vocab_size=vocab_size,
            d_model=d_model,
            n_layer=n_layer,
            pad_vocab_size_multiple=pad_vocab_size_multiple,
        )
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            initializer_cfg=initializer_cfg,
            **backbone_kwargs,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        labels=None,
        **kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            logits = lm_logits
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
            return CausalLMOutput(logits=lm_logits, loss=loss)

        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        state_dict: Optional[dict] = None,
        safe_serialization: Optional[bool] = None,  # pylint: disable=unused-argument
    ):
        if state_dict is None:
            state_dict = self.state_dict()
        torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config = load_config_hf(pretrained_model_name)
        model = cls(**config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
        )
        return model
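
The main differences from upstream mamba_ssm's MambaLMHeadModel are that `forward()` accepts `labels` (and absorbs extra collator kwargs via `**kwargs`), returning the shifted-token cross-entropy so HF Trainer gets a loss back, while `save_pretrained()` writes a plain `pytorch_model.bin` since safetensors is disabled for these checkpoints. A minimal smoke-test sketch, assuming a CUDA device and mamba-ssm installed (the tiny sizes are illustrative only, not from this commit):

import torch

from axolotl.models.mamba.modeling_mamba import MambaLMHeadModel

# Hypothetical tiny dimensions; the real config uses d_model=2560, n_layer=64.
model = MambaLMHeadModel(d_model=256, n_layer=2, vocab_size=1000).cuda()
input_ids = torch.randint(0, 1000, (1, 16), device="cuda")

out = model(input_ids=input_ids, labels=input_ids)  # labels trigger the loss branch
print(out.logits.shape)  # (1, 16, 1000)
print(out.loss)          # scalar cross-entropy over shifted tokens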
The remaining changed files from this commit are not included in this capture.