diff --git a/src/axolotl/models/mamba/__init__.py b/src/axolotl/models/mamba/__init__.py
index 247c1d184b..6bea5b6b13 100644
--- a/src/axolotl/models/mamba/__init__.py
+++ b/src/axolotl/models/mamba/__init__.py
@@ -2,6 +2,21 @@
 Modeling module for Mamba models
 """
 
+import importlib.util
+
+
+def check_mamba_ssm_installed():
+    """
+    Raise an ImportError with an install hint if `mamba_ssm` cannot be found.
+
+    The lookup is done with `importlib.util.find_spec("mamba_ssm")`.
+    """
+    mamba_ssm_spec = importlib.util.find_spec("mamba_ssm")
+    if mamba_ssm_spec is None:
+        raise ImportError(
+            "MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm,flash-attn]`"
+        )
+
 
 def fix_mamba_attn_for_loss():
     from mamba_ssm.models import mixer_seq_simple
@@ -10,3 +25,7 @@ def fix_mamba_attn_for_loss():
 
     mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
     return mixer_seq_simple.MambaLMHeadModel  # pylint: disable=invalid-name
+
+
+# Runs when this package is imported, so a missing mamba_ssm is reported up front.
+check_mamba_ssm_installed()