Skip to content

Commit

Permalink
beta support for multipack with gemmoe: (#1402)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Mar 14, 2024
1 parent 7fafd9b commit 9a1db1e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
18 changes: 17 additions & 1 deletion src/axolotl/monkeypatch/multipack.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""multipack patching for v2 of sample packing"""
import importlib

import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled

from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
Expand All @@ -12,11 +15,12 @@
"falcon",
"phi",
"gemma",
"gemmoe",
"starcoder2",
]


def patch_for_multipack(model_type):
def patch_for_multipack(model_type, model_name=None):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
Expand All @@ -43,3 +47,15 @@ def patch_for_multipack(model_type):
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemmoe":
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_gemmoe to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_gemmoe", ".modeling_gemmoe"
)
modeling_gemmoe = importlib.import_module(module_name)
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
2 changes: 1 addition & 1 deletion src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def load_model(
and cfg.flash_attention
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type)
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block

Expand Down

0 comments on commit 9a1db1e

Please sign in to comment.