From eabb86b16dc1ffa181b69f563c537ae74540045d Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Sun, 15 Dec 2024 21:28:37 +0100
Subject: [PATCH] no causal attention mask for the encoder

---
 .../models/moonshine/modular_moonshine.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py
index bea8370b7313ea..3d2c62487cf30f 100644
--- a/src/transformers/models/moonshine/modular_moonshine.py
+++ b/src/transformers/models/moonshine/modular_moonshine.py
@@ -1032,7 +1032,6 @@ def preprocess(self, input_features: torch.FloatTensor):
     def forward(
         self,
         input_features: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -1084,9 +1083,6 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
@@ -1097,15 +1093,15 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for encoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
+                    encoder_layer.__call__,
                     hidden_states,
-                    causal_mask,
+                    None,
                     position_ids,
                     past_key_values,
                     output_attentions,
@@ -1114,9 +1110,8 @@ def forward(
                     position_embeddings,
                 )
             else:
-                layer_outputs = decoder_layer(
+                layer_outputs = encoder_layer(
                     hidden_states,
-                    attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
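
The mask is dropped because this forward pass belongs to Moonshine's audio
encoder: every frame may attend to every other frame, and causal masking is
only needed in the autoregressive decoder. A minimal standalone sketch of
that distinction, assuming PyTorch >= 2.0 (for
torch.nn.functional.scaled_dot_product_attention) and using illustrative
tensor names rather than anything from this patch:

    import torch
    import torch.nn.functional as F

    # Toy query/key/value tensors, shaped (batch, heads, seq_len, head_dim);
    # illustrative only, not Moonshine's actual projections.
    q = torch.randn(1, 2, 4, 8)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)

    # Encoder-style self-attention: no mask, every position sees all others.
    encoder_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

    # Decoder-style self-attention: is_causal=True applies the
    # lower-triangular mask, so position i only attends to positions <= i.
    decoder_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)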