From 8df7b888ff27d93a879415b7348031bf49001e21 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 14 Mar 2024 15:52:23 -0400
Subject: [PATCH] beta support for multipack with gemmoe: (#1402)

---
 src/axolotl/monkeypatch/multipack.py | 18 +++++++++++++++++-
 src/axolotl/utils/models.py          |  2 +-
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 964b41f707..fbcaf7a668 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -1,6 +1,9 @@
 """multipack patching for v2 of sample packing"""
+import importlib
 
 import transformers
+from accelerate import init_empty_weights
+from transformers import AutoConfig, AutoModelForCausalLM
 from transformers.integrations import is_deepspeed_zero3_enabled
 
 from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
@@ -12,11 +15,12 @@
     "falcon",
     "phi",
     "gemma",
+    "gemmoe",
     "starcoder2",
 ]
 
 
-def patch_for_multipack(model_type):
+def patch_for_multipack(model_type, model_name=None):
     if model_type == "mixtral":
         transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
@@ -43,3 +47,15 @@ def patch_for_multipack(model_type):
         transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
+    elif model_type == "gemmoe":
+        model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        # we need to load the model here in order for modeling_gemmoe to be available
+        with init_empty_weights():
+            AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+        module_name = model_config.__class__.__module__.replace(
+            ".configuration_gemmoe", ".modeling_gemmoe"
+        )
+        modeling_gemmoe = importlib.import_module(module_name)
+        modeling_gemmoe._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 53201c9968..fce7b20a7a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -429,7 +429,7 @@ def load_model(
         and cfg.flash_attention
         and cfg.sample_packing
     ):
-        patch_for_multipack(cfg.model_config_type)
+        patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
     elif cfg.is_llama_derived_model:
         # Modify all llama derived models in one block
 