diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py
index ca728b654e8c21..40e728850d9cb8 100644
--- a/src/transformers/models/diffllama/modular_diffllama.py
+++ b/src/transformers/models/diffllama/modular_diffllama.py
@@ -24,15 +24,21 @@ from ...cache_utils import Cache, StaticCache
 from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     is_flash_attn_greater_or_equal_2_10,
     logging,
 )
 from ..gemma.modeling_gemma import GemmaForCausalLM
 from ..llama.modeling_llama import (
+    LlamaDecoderLayer,
     LlamaForQuestionAnswering,
     LlamaForSequenceClassification,
     LlamaForTokenClassification,
+    LlamaModel,
+    LlamaPreTrainedModel,
+    LlamaRMSNorm,
+    LlamaRotaryEmbedding,
     apply_rotary_pos_emb,
     repeat_kv,
 )
@@ -430,6 +436,8 @@ class DiffLlamaForTokenClassification(LlamaForTokenClassification):
 __all__ = [
+    "DiffLlamaPreTrainedModel",
+    "DiffLlamaModel",
     "DiffLlamaForCausalLM",
     "DiffLlamaForSequenceClassification",
     "DiffLlamaForQuestionAnswering",
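
Context for the change above (not part of the diff): in Transformers' modular files, a model declares its classes by subclassing another model's classes, and `modeling_diffllama.py` is then generated from `modular_diffllama.py`. The sketch below is a minimal, hypothetical illustration of the pattern the new `LlamaModel`/`LlamaPreTrainedModel` imports and the added `__all__` entries point to; the class bodies are placeholders rather than the actual DiffLlama implementation, and absolute imports are used so the snippet is self-contained.

```python
# Hypothetical sketch (not from the diff): the modular-Transformers subclassing
# pattern that the newly imported Llama classes and __all__ entries correspond to.
# Class bodies are placeholders; the real modular_diffllama.py overrides behavior
# (e.g. its attention variants) on top of the inherited Llama components.
from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel


class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
    # Inherits weight initialization and checkpoint-loading behavior
    # from LlamaPreTrainedModel.
    pass


class DiffLlamaModel(LlamaModel):
    # Inherits the decoder stack from LlamaModel.
    pass


# Exporting the classes mirrors the __all__ additions in the diff.
__all__ = ["DiffLlamaPreTrainedModel", "DiffLlamaModel"]
```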