
Commit

no causal attention mask for the encoder
Eustache Le Bihan committed Dec 15, 2024
1 parent d0ed917 commit eabb86b
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions src/transformers/models/moonshine/modular_moonshine.py
@@ -1032,7 +1032,6 @@ def preprocess(self, input_features: torch.FloatTensor):
     def forward(
         self,
         input_features: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -1084,9 +1083,6 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
@@ -1097,15 +1093,15 @@
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for encoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    encoder_layer.__call__,
                     hidden_states,
-                    causal_mask,
+                    None,
                     position_ids,
                     past_key_values,
                     output_attentions,
@@ -1114,9 +1110,8 @@
                     position_embeddings,
                 )
             else:
-                layer_outputs = decoder_layer(
+                layer_outputs = encoder_layer(
                     hidden_states,
-                    attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
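
Note: the change drops the causal mask because the Moonshine encoder attends bidirectionally; only a decoder needs causal masking. The snippet below is a minimal, self-contained sketch, not the Moonshine implementation (tensor shapes and variable names are illustrative assumptions), contrasting the two behaviors with PyTorch's scaled_dot_product_attention.

import torch
import torch.nn.functional as F

# Illustrative shapes (assumptions, not Moonshine's real dimensions).
batch, heads, seq_len, head_dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Decoder-style self-attention: each position attends only to itself
# and earlier positions.
causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Encoder-style self-attention (what this commit switches to): no mask,
# so every position attends to every other position, which is why the
# encoder layers above now receive attention_mask=None.
bidirectional_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

print(causal_out.shape, bidirectional_out.shape)  # torch.Size([1, 2, 4, 8]) twice

Passing None, instead of building a mask via _update_causal_mask, keeps every audio frame visible to every other frame, which is the expected behavior for an encoder.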
