diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index dd208302b8..2c5b5d1c7c 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -91,6 +91,7 @@ def forward(
         attn_bias: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.ByteTensor] = None,
         is_causal: bool = True,
+        output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
             torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
@@ -100,6 +101,7 @@ def forward(
             attn_bias=attn_bias,
             attention_mask=attention_mask,
             is_causal=is_causal,
+            needs_weights=output_attentions,
         )
         x = x + self.resid_attn_dropout(b)
         m = x
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 3371c67a0d..26d564ff8c 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -434,6 +434,7 @@ def forward(
                 attn_bias=attn_bias,
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
+                output_attentions=bool(output_attentions),
             )
             if past_key_values is not None:
                 past_key_values[b_idx] = past_key_value
diff --git a/tests/test_model.py b/tests/test_model.py
index 501d9bf6e7..6ea530731a 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1341,6 +1341,8 @@ def test_forward_with_output_attentions_and_output_hidden_states(
     if output_attentions:
         assert len(outputs.attentions) == n_layers
+        assert all(
+            attn.shape == (1, 4, 3, 3) for attn in outputs.attentions)
     if output_hidden_states:
         assert len(outputs.hidden_states) == n_layers + 1
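
Usage sketch (not part of the diff): the change threads output_attentions from the MPT forward pass through each MPTBlock into the attention module as needs_weights, so per-layer attention weights come back on outputs.attentions. Below is a minimal, hedged example of exercising that path on a tiny randomly initialized model; the config values (d_model=128, n_heads=4, n_layers=2, sequence length 3) are illustrative assumptions chosen to mirror the (1, 4, 3, 3) shape asserted in the test, and attn_impl='torch' is assumed because the fused attention kernels generally cannot return weights.

import torch

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

# Tiny illustrative config: 4 heads so each attention tensor is
# (batch=1, n_heads=4, seq_len=3, seq_len=3), matching the test assertion.
config = MPTConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    expansion_ratio=2,
    max_seq_len=16,
    vocab_size=100,
    attn_config={'attn_impl': 'torch'},  # assumed: the torch path can return weights
)
model = MPTForCausalLM(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 3))
with torch.no_grad():
    outputs = model(input_ids=input_ids, output_attentions=True)

# One attention tensor per block, each (batch, n_heads, seq_len, seq_len).
assert len(outputs.attentions) == config.n_layers
assert all(attn.shape == (1, 4, 3, 3) for attn in outputs.attentions)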