-
Notifications
You must be signed in to change notification settings - Fork 27.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add OLMo November 2024 #34551
Merged
ArthurZucker
merged 26 commits into
huggingface:main
from
2015aroras:shanea/add-olmo1124
Nov 18, 2024
Merged
Add OLMo November 2024 #34551
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
6e747c2
Add model skeletion with transformers-cli add-new-model-like
2015aroras a80ffd1
Convert config to modular, add rms_norm_eps, delete clip_qkv
2015aroras ffa794e
Convert model to modular, add RMSNorm
2015aroras 75d38f0
Add flash attention with qk norm and no qkv clipping
2015aroras dbd880d
Add decoder layer with RMSNorm after attention/feedforward layers
2015aroras 06c9c44
Add base and causal model
2015aroras b73f6d3
Add converter improvements from OLMo repo
2015aroras c8d9411
Update weight loading in OLMo to HF converter
2015aroras 4e3da14
Set correct default for rms_norm_eps
2015aroras 87d54bb
Set correct pipeline_model_mapping in test
2015aroras b7939d2
Run make fixup
2015aroras d39587f
Fix model type
2015aroras 30c20f6
Re-run modular conversion
2015aroras cdce157
Manually set config docs to fix build errors
2015aroras 3a9c61c
Convert olmo-1124 to olmo_1124 to fix flash attention docs errors
2015aroras 949648e
Start updating tests
2015aroras 0217f40
Update tests
2015aroras 1bdaa05
Copy upstream test_eager_matches_sdpa_inference_1_bfloat16 changes to…
2015aroras 0b1f2bf
Rename input_layernorm and post_attention_layernorm to reflect their …
2015aroras 9e7c77d
Use correct tokenizer
2015aroras 11f67eb
Remove test unsupported by GPT2 tokenizer
2015aroras 0c2a264
Create GenerationConfig outside of from_pretrained call
2015aroras a22d936
Use simpler init file structure
2015aroras a3cca57
Add explicit __all__ to support simplified init
2015aroras 82a75c2
Make safetensor serialization the default
2015aroras bfd2e63
Update OLMo November 2024 docs
2015aroras File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
<!--Copyright 2024 The HuggingFace Team. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
the License. You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. | ||
|
||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | ||
rendered properly in your Markdown viewer. | ||
|
||
--> | ||
|
||
# OLMo November 2024 | ||
|
||
## Overview | ||
|
||
The OLMo November 2024 model is a successor of the OLMo model, which was proposed in | ||
[OLMo: Accelerating the Science of Language Models](https://arxiv.org/abs/2402.00838). | ||
|
||
The architectural changes from the original OLMo model to this model are: | ||
|
||
- RMSNorm is used instead of standard layer norm. | ||
- Norm is applied to attention queries and keys. | ||
- Norm is applied after attention/feedforward layers rather than before. | ||
|
||
This model was contributed by [shanearora](https://huggingface.co/shanearora). | ||
The original code can be found [here](https://github.com/allenai/OLMo/tree/main/olmo). | ||
|
||
|
||
## Olmo1124Config | ||
|
||
[[autodoc]] Olmo1124Config | ||
|
||
## Olmo1124Model | ||
|
||
[[autodoc]] Olmo1124Model | ||
- forward | ||
|
||
## Olmo1124ForCausalLM | ||
|
||
[[autodoc]] Olmo1124ForCausalLM | ||
- forward |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -177,6 +177,7 @@ | |
nougat, | ||
nystromformer, | ||
olmo, | ||
olmo_1124, | ||
olmoe, | ||
omdet_turbo, | ||
oneformer, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import TYPE_CHECKING | ||
|
||
from ...utils import _LazyModule | ||
from ...utils.import_utils import define_import_structure | ||
|
||
|
||
if TYPE_CHECKING: | ||
from .configuration_olmo_1124 import * | ||
from .modeling_olmo_1124 import * | ||
else: | ||
import sys | ||
|
||
_file = globals()["__file__"] | ||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) |
166 changes: 166 additions & 0 deletions
166
src/transformers/models/olmo_1124/configuration_olmo_1124.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 | ||
# This file was automatically generated from src/transformers/models/olmo_1124/modular_olmo_1124.py. | ||
# Do NOT edit this file manually as any edits will be overwritten by the generation of | ||
# the file from the modular. If any change should be done, please apply the change to the | ||
# modular_olmo_1124.py file directly. One of our CI enforces this. | ||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 | ||
|
||
from ...configuration_utils import PretrainedConfig | ||
|
||
|
||
class Olmo1124Config(PretrainedConfig): | ||
r""" | ||
This is the configuration class to store the configuration of a [`Olmo1124Model`]. It is used to instantiate an OLMo November 2024 | ||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the | ||
defaults will yield a similar configuration to that of the [allenai/Olmo1124-7B-hf](https://huggingface.co/allenai/Olmo1124-7B-hf). | ||
|
||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | ||
documentation from [`PretrainedConfig`] for more information. | ||
|
||
|
||
Args: | ||
vocab_size (`int`, *optional*, defaults to 50304): | ||
Vocabulary size of the Olmo1124 model. Defines the number of different tokens that can be represented by the | ||
`inputs_ids` passed when calling [`Olmo1124Model`] | ||
hidden_size (`int`, *optional*, defaults to 4096): | ||
Dimension of the hidden representations. | ||
intermediate_size (`int`, *optional*, defaults to 11008): | ||
Dimension of the MLP representations. | ||
num_hidden_layers (`int`, *optional*, defaults to 32): | ||
Number of hidden layers in the Transformer decoder. | ||
num_attention_heads (`int`, *optional*, defaults to 32): | ||
Number of attention heads for each attention layer in the Transformer decoder. | ||
num_key_value_heads (`int`, *optional*): | ||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If | ||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if | ||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When | ||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed | ||
by meanpooling all the original heads within that group. For more details checkout [this | ||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to | ||
`num_attention_heads`. | ||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): | ||
The non-linear activation function (function or string) in the decoder. | ||
max_position_embeddings (`int`, *optional*, defaults to 2048): | ||
The maximum sequence length that this model might ever be used with. | ||
initializer_range (`float`, *optional*, defaults to 0.02): | ||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | ||
use_cache (`bool`, *optional*, defaults to `True`): | ||
Whether or not the model should return the last key/values attentions (not used by all models). Only | ||
relevant if `config.is_decoder=True`. | ||
pad_token_id (`int`, *optional*, defaults to 1): | ||
Padding token id. | ||
bos_token_id (`int`, *optional*): | ||
Beginning of stream token id. | ||
eos_token_id (`int`, *optional*, defaults to 50279): | ||
End of stream token id. | ||
tie_word_embeddings (`bool`, *optional*, defaults to `False`): | ||
Whether to tie weight embeddings | ||
rope_theta (`float`, *optional*, defaults to 10000.0): | ||
The base period of the RoPE embeddings. | ||
rope_scaling (`Dict`, *optional*): | ||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling | ||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is | ||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update | ||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how | ||
these scaling strategies behave: | ||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an | ||
experimental feature, subject to breaking API changes in future versions. | ||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): | ||
Whether to use a bias in the query, key, value and output projection layers during self-attention. | ||
attention_dropout (`float`, *optional*, defaults to 0.0): | ||
The dropout ratio for the attention probabilities. | ||
rms_norm_eps (`float`, *optional*, defaults to 1e-05): | ||
The epsilon used by the rms normalization layers. | ||
|
||
```python | ||
>>> from transformers import Olmo1124Model, Olmo1124Config | ||
|
||
>>> # Initializing a Olmo November 2024 7B style configuration | ||
>>> configuration = Olmo1124Config() | ||
|
||
>>> # Initializing a model from the Olmo November 2024 7B style configuration | ||
>>> model = Olmo1124Model(configuration) | ||
|
||
>>> # Accessing the model configuration | ||
>>> configuration = model.config | ||
``` | ||
""" | ||
|
||
model_type = "olmo_1124" | ||
keys_to_ignore_at_inference = ["past_key_values"] | ||
|
||
def __init__( | ||
self, | ||
vocab_size=50304, | ||
hidden_size=4096, | ||
intermediate_size=11008, | ||
num_hidden_layers=32, | ||
num_attention_heads=32, | ||
num_key_value_heads=None, | ||
hidden_act="silu", | ||
max_position_embeddings=2048, | ||
initializer_range=0.02, | ||
use_cache=True, | ||
pad_token_id=1, | ||
bos_token_id=None, | ||
eos_token_id=50279, | ||
tie_word_embeddings=False, | ||
rope_theta=10000.0, | ||
rope_scaling=None, | ||
attention_bias=False, | ||
attention_dropout=0.0, | ||
rms_norm_eps=1e-5, | ||
**kwargs, | ||
): | ||
super().__init__( | ||
pad_token_id=pad_token_id, | ||
bos_token_id=bos_token_id, | ||
eos_token_id=eos_token_id, | ||
tie_word_embeddings=tie_word_embeddings, | ||
**kwargs, | ||
) | ||
self.vocab_size = vocab_size | ||
self.max_position_embeddings = max_position_embeddings | ||
self.hidden_size = hidden_size | ||
self.intermediate_size = intermediate_size | ||
self.num_hidden_layers = num_hidden_layers | ||
self.num_attention_heads = num_attention_heads | ||
|
||
# for backward compatibility | ||
if num_key_value_heads is None: | ||
num_key_value_heads = num_attention_heads | ||
|
||
self.num_key_value_heads = num_key_value_heads | ||
self.hidden_act = hidden_act | ||
self.initializer_range = initializer_range | ||
self.use_cache = use_cache | ||
self.rope_theta = rope_theta | ||
self.rope_scaling = rope_scaling | ||
self._rope_scaling_validation() | ||
self.attention_bias = attention_bias | ||
self.attention_dropout = attention_dropout | ||
|
||
self.rms_norm_eps = rms_norm_eps | ||
|
||
def _rope_scaling_validation(self): | ||
""" | ||
Validate the `rope_scaling` configuration. | ||
""" | ||
if self.rope_scaling is None: | ||
return | ||
|
||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: | ||
raise ValueError( | ||
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" | ||
) | ||
rope_scaling_type = self.rope_scaling.get("type", None) | ||
rope_scaling_factor = self.rope_scaling.get("factor", None) | ||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: | ||
raise ValueError( | ||
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" | ||
) | ||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: | ||
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") | ||
|
||
|
||
__all__ = ["Olmo1124Config"] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💘 super clear, love this!