-
Notifications
You must be signed in to change notification settings - Fork 27.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* first adding diffllama * add Diff Attention and other but still with errors * complate make attention Diff-Attention * fix some bugs which may be caused by transformer-cli while adding model * fix a bug caused by forgetting KV cache... * Update src/transformers/models/diffllama/modeling_diffllama.py You don't need to divide by 2 if we use same number of attention heads as llama. instead you can just split in forward. Co-authored-by: Minho Ryu <[email protected]> * Update src/transformers/models/diffllama/modeling_diffllama.py fit to changeing "num_heads // 2" place Co-authored-by: Minho Ryu <[email protected]> * Update src/transformers/models/diffllama/modeling_diffllama.py new codes are more meaningful than before Co-authored-by: Minho Ryu <[email protected]> * Update src/transformers/models/diffllama/modeling_diffllama.py new codes are more meaningful than before Co-authored-by: Minho Ryu <[email protected]> * Update src/transformers/models/diffllama/modeling_diffllama.py fit to changeing "num_heads // 2" place Co-authored-by: Minho Ryu <[email protected]> * Update src/transformers/models/diffllama/modeling_diffllama.py fix 2times divide by sqrt(self.head_dim) Co-authored-by: Minho Ryu <[email protected]> * Update src/transformers/models/diffllama/modeling_diffllama.py fix 2times divide by sqrt(self.head_dim) Co-authored-by: Minho Ryu <[email protected]> * Update src/transformers/models/diffllama/modeling_diffllama.py fit to changeing "num_heads // 2" place. and more visible Co-authored-by: Minho Ryu <[email protected]> * I found Attention missed implemented from paper still on e072544. * re-implemented * adding groupnorm Co-authored-by: Minho Ryu <[email protected]> * align with transformers code style Co-authored-by: Minho Ryu <[email protected]> * fix typo Co-authored-by: Minho Ryu <[email protected]> * adding groupnorm Co-authored-by: Minho Ryu <[email protected]> * change SdpaAttention to DiffSdpaAttention Co-authored-by: Minho Ryu <[email protected]> * fix bug * Update src/transformers/models/diffllama/modeling_diffllama.py resolve "not same outputs" problem Co-authored-by: Minho Ryu <[email protected]> * fix bugs of places of "GroupNorm with scale" and etc * Revert "fix bugs of places of "GroupNorm with scale" and etc" This reverts commit 26307d9. * simplify multiple of attention (matmul) operations into one by repeating value_states Co-authored-by: Minho Ryu <[email protected]> * simplify multiple of attention (matmul) operations into one by repeating value_states Co-authored-by: Minho Ryu <[email protected]> * simplify multiple of attention (matmul) operations into one by repeating value_states Co-authored-by: Minho Ryu <[email protected]> * remove missed type * add diffllama model_doc * apply make style/quality * apply review comment about model * apply review comment about test * place diffllama alphabetically on the src/transformers/__init__.py * fix forgot code * Supports parameters that are not initialized with standard deviation 0 in the conventional method * add DiffLlamaConfig to CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK on utils/check_config_docstrings.py * remove unused property of config * add to supported model list * add to spda supported model list * fix copyright, remove pretraining_tensor_parallel, and modify for initialization test * remove unused import and etc. * empty commit * empty commit * empty commit * apply modular transformers but with bugs * revert prev commit * create src/transformers/model/diffllama/modular_diffllama.py * run utils/modular_model_converter.py * empty commit * leaner modular diffllama * remove more and more in modular_diffllama.pt * remove more and more in modular_diffllama.pt * resolve missing docstring entries * force reset * convert modular --------- Co-authored-by: Minho Ryu <[email protected]>
- Loading branch information
1 parent
ed73ae2
commit 96bf3d6
Showing
16 changed files
with
3,249 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
<!--Copyright 2024 The HuggingFace Team. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
the License. You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. | ||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | ||
rendered properly in your Markdown viewer. | ||
--> | ||
|
||
# DiffLlama | ||
|
||
## Overview | ||
|
||
The DiffLlama model was proposed in [Differential Transformer](https://arxiv.org/abs/2410.05258) by Kazuma Matsumoto and . | ||
This model is combine Llama model and Differential Transformer's Attention. | ||
|
||
The abstract from the paper is the following: | ||
|
||
*Transformer tends to overallocate attention to irrelevant context. In this work, we introduce Diff Transformer, which amplifies attention to the relevant context while canceling noise. Specifically, the differential attention mechanism calculates attention scores as the difference between two separate softmax attention maps. The subtraction cancels noise, promoting the emergence of sparse attention patterns. Experimental results on language modeling show that Diff Transformer outperforms Transformer in various settings of scaling up model size and training tokens. More intriguingly, it offers notable advantages in practical applications, such as long-context modeling, key information retrieval, hallucination mitigation, in-context learning, and reduction of activation outliers. By being less distracted by irrelevant context, Diff Transformer can mitigate hallucination in question answering and text summarization. For in-context learning, Diff Transformer not only enhances accuracy but is also more robust to order permutation, which was considered as a chronic robustness issue. The results position Diff Transformer as a highly effective and promising architecture to advance large language models.* | ||
|
||
### Usage tips | ||
The hyperparameters of this model is the same as Llama model. | ||
|
||
|
||
## DiffLlamaConfig | ||
|
||
[[autodoc]] DiffLlamaConfig | ||
|
||
## DiffLlamaModel | ||
|
||
[[autodoc]] DiffLlamaModel | ||
- forward | ||
|
||
## DiffLlamaForCausalLM | ||
|
||
[[autodoc]] DiffLlamaForCausalLM | ||
- forward | ||
|
||
## DiffLlamaForSequenceClassification | ||
|
||
[[autodoc]] DiffLlamaForSequenceClassification | ||
- forward | ||
|
||
## DiffLlamaForQuestionAnswering | ||
|
||
[[autodoc]] DiffLlamaForQuestionAnswering | ||
- forward | ||
|
||
## DiffLlamaForTokenClassification | ||
|
||
[[autodoc]] DiffLlamaForTokenClassification | ||
- forward |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,6 +75,7 @@ | |
depth_anything, | ||
detr, | ||
dialogpt, | ||
diffllama, | ||
dinat, | ||
dinov2, | ||
dinov2_with_registers, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import TYPE_CHECKING | ||
|
||
from ...utils import _LazyModule | ||
from ...utils.import_utils import define_import_structure | ||
|
||
|
||
if TYPE_CHECKING: | ||
from .configuration_diffllama import * | ||
from .modeling_diffllama import * | ||
else: | ||
import sys | ||
|
||
_file = globals()["__file__"] | ||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) |
Oops, something went wrong.