Add diffllama (#34083)
* first adding diffllama

* add Diff Attention and other components, but still with errors

* complete making attention Diff-Attention

* fix some bugs that may have been caused by transformers-cli while adding the model

* fix a bug caused by forgetting the KV cache...

* Update src/transformers/models/diffllama/modeling_diffllama.py

You don't need to divide by 2 if we use the same number of attention heads as Llama; instead, you can just split in forward.

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

fit to the changed "num_heads // 2" placement

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

the new code is more meaningful than before

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

the new code is more meaningful than before

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

fit to the changed "num_heads // 2" placement

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

fix dividing by sqrt(self.head_dim) twice

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

fix dividing by sqrt(self.head_dim) twice

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

fit to the changed "num_heads // 2" placement,
and make it more visible

Co-authored-by: Minho Ryu <[email protected]>

* I found the Attention was still mis-implemented relative to the paper as of e072544.

* re-implemented

* adding groupnorm

Co-authored-by: Minho Ryu <[email protected]>

* align with transformers code style

Co-authored-by: Minho Ryu <[email protected]>

* fix typo

Co-authored-by: Minho Ryu <[email protected]>

* adding groupnorm

Co-authored-by: Minho Ryu <[email protected]>

* change SdpaAttention to DiffSdpaAttention

Co-authored-by: Minho Ryu <[email protected]>

* fix bug

* Update src/transformers/models/diffllama/modeling_diffllama.py

resolve "not same outputs" problem

Co-authored-by: Minho Ryu <[email protected]>

* fix bugs in the placement of "GroupNorm with scale", etc.

* Revert "fix bugs in the placement of "GroupNorm with scale", etc."

This reverts commit 26307d9.

* simplify multiple attention (matmul) operations into one by repeating value_states (see the sketch after this commit log)

Co-authored-by: Minho Ryu <[email protected]>

* simplify multiple attention (matmul) operations into one by repeating value_states

Co-authored-by: Minho Ryu <[email protected]>

* simplify multiple attention (matmul) operations into one by repeating value_states

Co-authored-by: Minho Ryu <[email protected]>

* remove missed type

* add diffllama model_doc

* apply make style/quality

* apply review comment about model

* apply review comment about test

* place diffllama alphabetically in src/transformers/__init__.py

* fix forgotten code

* Supports parameters that are not initialized with standard deviation 0 in the conventional method

* add DiffLlamaConfig to CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK on utils/check_config_docstrings.py

* remove unused property of config

* add to supported model list

* add to SDPA supported model list

* fix copyright, remove pretraining_tensor_parallel, and modify for initialization test

* remove unused imports, etc.

* empty commit

* empty commit

* empty commit

* apply modular transformers but with bugs

* revert prev commit

* create src/transformers/model/diffllama/modular_diffllama.py

* run utils/modular_model_converter.py

* empty commit

* leaner modular diffllama

* remove more and more in modular_diffllama.py

* remove more and more in modular_diffllama.py

* resolve missing docstring entries

* force reset

* convert modular

---------

Co-authored-by: Minho Ryu <[email protected]>
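
Below is a rough sketch of the two implementation ideas discussed in the commits above: splitting the heads in forward instead of halving num_heads, and fusing the two attention matmuls into one by repeating value_states. The function name, shapes, and the scalar `lam` are assumptions for illustration only, not the merged modeling_diffllama.py code.

```python
import torch
import torch.nn.functional as F


def diff_attention_fused(query, key, value, lam):
    # query, key: (batch, num_heads, q_len, head_dim) -- same head count as Llama;
    #             the two "differential" halves are split here in forward rather
    #             than configuring num_heads // 2 heads up front.
    # value:      (batch, num_heads // 2, kv_len, v_dim)
    # lam:        scalar weight of the second softmax map (learnable in the paper)
    q1, q2 = torch.chunk(query, 2, dim=1)
    k1, k2 = torch.chunk(key, 2, dim=1)
    scale = q1.shape[-1] ** -0.5
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)
    # Fuse (a1 @ value) - lam * (a2 @ value) into a single matmul by stacking the
    # attention weights along the key axis and repeating value along that same axis.
    attn_weights = torch.cat([a1, -lam * a2], dim=-1)  # (batch, heads // 2, q_len, 2 * kv_len)
    value = torch.cat([value, value], dim=-2)          # (batch, heads // 2, 2 * kv_len, v_dim)
    return attn_weights @ value
```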
weak-kajuma and bzantium authored Jan 7, 2025
1 parent ed73ae2 commit 96bf3d6
Showing 16 changed files with 3,249 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -386,6 +386,8 @@
title: DeBERTa-v2
- local: model_doc/dialogpt
title: DialoGPT
- local: model_doc/diffllama
title: DiffLlama
- local: model_doc/distilbert
title: DistilBERT
- local: model_doc/dpr
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -125,6 +125,7 @@ Flax), PyTorch, and/or TensorFlow.
| [DETA](model_doc/deta) ||||
| [DETR](model_doc/detr) ||||
| [DialoGPT](model_doc/dialogpt) ||||
| [DiffLlama](model_doc/diffllama) ||||
| [DiNAT](model_doc/dinat) ||||
| [DINOv2](model_doc/dinov2) ||||
| [DINOv2 with Registers](model_doc/dinov2_with_registers) ||||
59 changes: 59 additions & 0 deletions docs/source/en/model_doc/diffllama.md
@@ -0,0 +1,59 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# DiffLlama

## Overview

The DiffLlama model was proposed in [Differential Transformer](https://arxiv.org/abs/2410.05258), and its implementation was contributed by Kazuma Matsumoto.
This model combines the Llama architecture with the Differential Transformer's attention mechanism.

The abstract from the paper is the following:

*Transformer tends to overallocate attention to irrelevant context. In this work, we introduce Diff Transformer, which amplifies attention to the relevant context while canceling noise. Specifically, the differential attention mechanism calculates attention scores as the difference between two separate softmax attention maps. The subtraction cancels noise, promoting the emergence of sparse attention patterns. Experimental results on language modeling show that Diff Transformer outperforms Transformer in various settings of scaling up model size and training tokens. More intriguingly, it offers notable advantages in practical applications, such as long-context modeling, key information retrieval, hallucination mitigation, in-context learning, and reduction of activation outliers. By being less distracted by irrelevant context, Diff Transformer can mitigate hallucination in question answering and text summarization. For in-context learning, Diff Transformer not only enhances accuracy but is also more robust to order permutation, which was considered as a chronic robustness issue. The results position Diff Transformer as a highly effective and promising architecture to advance large language models.*
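
The core mechanism can be sketched in a few lines of PyTorch. This is a simplified illustration of the attention described above, not the exact `modeling_diffllama.py` implementation; in the paper, `lam` is a learnable, reparameterized scalar and the per-head outputs additionally pass through a normalization layer.

```python
import torch.nn.functional as F

def differential_attention(q1, q2, k1, k2, v, lam):
    """Attention weights are the difference of two softmax maps, weighted by lambda."""
    scale = q1.shape[-1] ** -0.5
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)
    return (a1 - lam * a2) @ v
```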

### Usage tips
The hyperparameters of this model are the same as those of the Llama model.
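
A minimal usage sketch follows. The checkpoint path is a placeholder rather than a released model, and the interface is the standard Transformers causal-LM API:

```python
from transformers import AutoTokenizer, DiffLlamaForCausalLM

# "path/to/diffllama-checkpoint" is a placeholder; substitute a real DiffLlama checkpoint.
tokenizer = AutoTokenizer.from_pretrained("path/to/diffllama-checkpoint")
model = DiffLlamaForCausalLM.from_pretrained("path/to/diffllama-checkpoint")

inputs = tokenizer("Differential attention reduces attention noise by", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```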


## DiffLlamaConfig

[[autodoc]] DiffLlamaConfig

## DiffLlamaModel

[[autodoc]] DiffLlamaModel
- forward

## DiffLlamaForCausalLM

[[autodoc]] DiffLlamaForCausalLM
- forward

## DiffLlamaForSequenceClassification

[[autodoc]] DiffLlamaForSequenceClassification
- forward

## DiffLlamaForQuestionAnswering

[[autodoc]] DiffLlamaForQuestionAnswering
- forward

## DiffLlamaForTokenClassification

[[autodoc]] DiffLlamaForTokenClassification
- forward
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -47,6 +47,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Cohere2](https://huggingface.co/docs/transformers/model_doc/cohere2#transformers.Cohere2Model)
* [GLM](https://huggingface.co/docs/transformers/model_doc/glm#transformers.GLMModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
@@ -237,6 +238,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [data2vec_vision](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecVisionModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
* [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
* [Dinov2_with_registers](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
20 changes: 20 additions & 0 deletions src/transformers/__init__.py
@@ -402,6 +402,7 @@
"models.depth_anything": ["DepthAnythingConfig"],
"models.detr": ["DetrConfig"],
"models.dialogpt": [],
"models.diffllama": ["DiffLlamaConfig"],
"models.dinat": ["DinatConfig"],
"models.dinov2": ["Dinov2Config"],
"models.dinov2_with_registers": ["Dinov2WithRegistersConfig"],
@@ -2145,6 +2146,16 @@
"DetrPreTrainedModel",
]
)
_import_structure["models.diffllama"].extend(
[
"DiffLlamaForCausalLM",
"DiffLlamaForQuestionAnswering",
"DiffLlamaForSequenceClassification",
"DiffLlamaForTokenClassification",
"DiffLlamaModel",
"DiffLlamaPreTrainedModel",
]
)
_import_structure["models.dinat"].extend(
[
"DinatBackbone",
@@ -5369,6 +5380,7 @@
)
from .models.depth_anything import DepthAnythingConfig
from .models.detr import DetrConfig
from .models.diffllama import DiffLlamaConfig
from .models.dinat import DinatConfig
from .models.dinov2 import Dinov2Config
from .models.dinov2_with_registers import Dinov2WithRegistersConfig
@@ -7017,6 +7029,14 @@
DetrModel,
DetrPreTrainedModel,
)
from .models.diffllama import (
DiffLlamaForCausalLM,
DiffLlamaForQuestionAnswering,
DiffLlamaForSequenceClassification,
DiffLlamaForTokenClassification,
DiffLlamaModel,
DiffLlamaPreTrainedModel,
)
from .models.dinat import (
DinatBackbone,
DinatForImageClassification,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -75,6 +75,7 @@
depth_anything,
detr,
dialogpt,
diffllama,
dinat,
dinov2,
dinov2_with_registers,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -92,6 +92,7 @@
("depth_anything", "DepthAnythingConfig"),
("deta", "DetaConfig"),
("detr", "DetrConfig"),
("diffllama", "DiffLlamaConfig"),
("dinat", "DinatConfig"),
("dinov2", "Dinov2Config"),
("dinov2_with_registers", "Dinov2WithRegistersConfig"),
@@ -403,6 +404,7 @@
("deta", "DETA"),
("detr", "DETR"),
("dialogpt", "DialoGPT"),
("diffllama", "DiffLlama"),
("dinat", "DiNAT"),
("dinov2", "DINOv2"),
("dinov2_with_registers", "DINOv2 with Registers"),
5 changes: 5 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -90,6 +90,7 @@
("deit", "DeiTModel"),
("deta", "DetaModel"),
("detr", "DetrModel"),
("diffllama", "DiffLlamaModel"),
("dinat", "DinatModel"),
("dinov2", "Dinov2Model"),
("dinov2_with_registers", "Dinov2WithRegistersModel"),
@@ -493,6 +494,7 @@
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForCausalLM"),
("dbrx", "DbrxForCausalLM"),
("diffllama", "DiffLlamaForCausalLM"),
("electra", "ElectraForCausalLM"),
("ernie", "ErnieForCausalLM"),
("falcon", "FalconForCausalLM"),
@@ -961,6 +963,7 @@
("data2vec-text", "Data2VecTextForSequenceClassification"),
("deberta", "DebertaForSequenceClassification"),
("deberta-v2", "DebertaV2ForSequenceClassification"),
("diffllama", "DiffLlamaForSequenceClassification"),
("distilbert", "DistilBertForSequenceClassification"),
("electra", "ElectraForSequenceClassification"),
("ernie", "ErnieForSequenceClassification"),
@@ -1056,6 +1059,7 @@
("data2vec-text", "Data2VecTextForQuestionAnswering"),
("deberta", "DebertaForQuestionAnswering"),
("deberta-v2", "DebertaV2ForQuestionAnswering"),
("diffllama", "DiffLlamaForQuestionAnswering"),
("distilbert", "DistilBertForQuestionAnswering"),
("electra", "ElectraForQuestionAnswering"),
("ernie", "ErnieForQuestionAnswering"),
@@ -1153,6 +1157,7 @@
("data2vec-text", "Data2VecTextForTokenClassification"),
("deberta", "DebertaForTokenClassification"),
("deberta-v2", "DebertaV2ForTokenClassification"),
("diffllama", "DiffLlamaForTokenClassification"),
("distilbert", "DistilBertForTokenClassification"),
("electra", "ElectraForTokenClassification"),
("ernie", "ErnieForTokenClassification"),
7 changes: 7 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -170,6 +170,13 @@
"DebertaV2TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"diffllama",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
(
"dpr",
27 changes: 27 additions & 0 deletions src/transformers/models/diffllama/__init__.py
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_diffllama import *
    from .modeling_diffllama import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)