diff --git a/requirements.txt b/requirements.txt index d1fdccaf77..c7f5c9433c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,7 +33,7 @@ tensorboard python-dotenv==1.0.1 autoawq>=0.2.5 triton>=2.3.0 -liger-kernel==0.4.0 +liger-kernel==0.4.1 mamba-ssm==1.2.0.post1 diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index a64d748c67..fda98e469f 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -23,6 +23,7 @@ import sys from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb @@ -82,7 +83,9 @@ def pre_model_load(self, cfg): if cfg.liger_glu_activation: modeling_jamba.JambaMLP = LigerSwiGLUMLP if cfg.liger_cross_entropy: - modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy if cfg.liger_fused_linear_cross_entropy: modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward elif cfg.model_config_type == "deepseek_v2": @@ -106,6 +109,8 @@ def pre_model_load(self, cfg): if cfg.liger_glu_activation: modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward if cfg.liger_cross_entropy: + # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses + # nn.CrossEntropyLoss in the forward method. modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward