diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py
index 4a1e4aff081..763ad97b1e7 100644
--- a/src/transformers/models/mllama/modeling_mllama.py
+++ b/src/transformers/models/mllama/modeling_mllama.py
@@ -1608,9 +1608,8 @@ def forward(
         hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim)
 
         # Collect intermediate layer outputs from encoder output
-        all_intermediate_hidden_states = output[1]
+        all_intermediate_hidden_states = [output[1][i] for i in self.intermediate_layers_indices]
         intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1)
-        intermediate_hidden_states = intermediate_hidden_states[..., self.intermediate_layers_indices]
 
         # Remove padding from intermediate hidden states
         intermediate_hidden_states = intermediate_hidden_states.reshape(
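The change selects the wanted encoder layers from `output[1]` before stacking, instead of stacking every layer's hidden state along a new trailing axis and then slicing it. Both orderings yield the same tensor, but the old one materializes a copy of all layers only to discard most of them. A minimal, self-contained sketch of the equivalence (the layer count, tensor shapes, and `intermediate_layers_indices` values below are illustrative assumptions, not the model's real configuration):

```python
import torch

# Stand-in for `output[1]`: one hidden state per encoder layer (+ embeddings).
num_layers = 8  # assumed small for the sketch; mllama's vision encoder is larger
hidden_states = tuple(torch.randn(2, 4, 16, 32) for _ in range(num_layers + 1))
intermediate_layers_indices = [2, 5, 7]  # hypothetical layer selection

# Before: stack every layer along a new trailing axis, then slice that axis.
# Peak memory here covers all num_layers + 1 copies.
stacked_all = torch.stack(hidden_states, dim=-1)
before = stacked_all[..., intermediate_layers_indices]

# After: pick the needed layers first, then stack only those.
after = torch.stack([hidden_states[i] for i in intermediate_layers_indices], dim=-1)

assert torch.equal(before, after)  # identical result, smaller temporary tensor
```

The payoff is peak memory: with the old ordering, `torch.stack` allocated a tensor spanning every encoder layer before the slice threw most of it away; the new ordering only ever copies the selected layers.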