Re-run modular conversion
2015aroras committed Nov 4, 2024
1 parent 029e843 commit 4349938
Showing 2 changed files with 29 additions and 36 deletions.
16 changes: 7 additions & 9 deletions src/transformers/models/olmo_1124/configuration_olmo_1124.py
@@ -5,7 +5,6 @@
# modular_olmo_1124.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨


from ...configuration_utils import PretrainedConfig


@@ -116,6 +115,13 @@ def __init__(
rms_norm_eps=1e-5,
**kwargs,
):
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
@@ -139,14 +145,6 @@ def __init__(

self.rms_norm_eps = rms_norm_eps

-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )

def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
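The change above moves the `super().__init__(...)` call to the top of `Olmo1124Config.__init__`, so `PretrainedConfig` handles the token-id and `tie_word_embeddings` kwargs before the model-specific attributes are assigned. A minimal usage sketch (not part of the commit; the import path is assumed from the file path above):

from transformers.models.olmo_1124.configuration_olmo_1124 import Olmo1124Config

config = Olmo1124Config(rms_norm_eps=1e-5)
print(config.rms_norm_eps)          # 1e-5, assigned after the super().__init__ call
print(config.tie_word_embeddings)   # handled by PretrainedConfig.__init__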
49 changes: 22 additions & 27 deletions src/transformers/models/olmo_1124/modeling_olmo_1124.py
@@ -8,18 +8,14 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
-from ...modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
@@ -36,6 +32,11 @@
from ...modeling_flash_attention_utils import _flash_attention_forward


+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "Olmo1124Config"


class Olmo1124RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
@@ -56,9 +57,6 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


-logger = logging.get_logger(__name__)


# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo1124
# TODO(joao): add me back asap :)
class Olmo1124RotaryEmbedding(nn.Module):
@@ -205,9 +203,9 @@ def __init__(self, config: Olmo1124Config, layer_idx: Optional[int] = None):
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
-        self._init_rope()
self.q_norm = Olmo1124RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps)
self.k_norm = Olmo1124RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps)
+        self._init_rope()

def _init_rope(self):
if self.config.rope_scaling is None:
@@ -297,21 +295,6 @@ def forward(
return attn_output, attn_weights, past_key_value


-class Olmo1124MLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
-
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class Olmo1124FlashAttention2(Olmo1124Attention):
"""
Olmo1124 flash attention module. This module inherits from `Olmo1124Attention` as the weights of the module stays
@@ -496,6 +479,21 @@ def forward(
return attn_output, None, past_key_value


+class Olmo1124MLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

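# Illustration (not part of this diff): Olmo1124MLP is a gated feed-forward block,
# computing down_proj(act_fn(gate_proj(x)) * up_proj(x)). A hypothetical usage,
# assuming `config` defines hidden_size, intermediate_size and hidden_act:
#   mlp = Olmo1124MLP(config)
#   y = mlp(torch.randn(2, 8, config.hidden_size))  # output shape (2, 8, hidden_size)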

OLMO_1124_ATTENTION_CLASSES = {
"eager": Olmo1124Attention,
"flash_attention_2": Olmo1124FlashAttention2,
@@ -621,9 +619,6 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()


-_CONFIG_FOR_DOC = "Olmo1124Config"


OLMO_1124_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
