From 0be2ca8218b62a6e1f79d5372908931fbc21c2fd Mon Sep 17 00:00:00 2001
From: Yanan Xie <108375850+lorabit110@users.noreply.github.com>
Date: Wed, 20 Sep 2023 17:47:24 -0700
Subject: [PATCH] Allow MPT models to return attention weights (#599)

* Allow MPT models to return attention weights

* Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

* Add unit test

* Update tests/test_model.py

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

* Update tests/test_model.py

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
---
 llmfoundry/models/layers/blocks.py    | 2 ++
 llmfoundry/models/mpt/modeling_mpt.py | 1 +
 tests/test_model.py                   | 2 ++
 3 files changed, 5 insertions(+)

diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index dd208302b8..2c5b5d1c7c 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -91,6 +91,7 @@ def forward(
         attn_bias: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.ByteTensor] = None,
         is_causal: bool = True,
+        output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
             torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
@@ -100,6 +101,7 @@ def forward(
             attn_bias=attn_bias,
             attention_mask=attention_mask,
             is_causal=is_causal,
+            needs_weights=output_attentions,
         )
         x = x + self.resid_attn_dropout(b)
         m = x
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 3371c67a0d..26d564ff8c 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -434,6 +434,7 @@ def forward(
                 attn_bias=attn_bias,
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
+                output_attentions=bool(output_attentions),
             )
             if past_key_values is not None:
                 past_key_values[b_idx] = past_key_value
diff --git a/tests/test_model.py b/tests/test_model.py
index 501d9bf6e7..6ea530731a 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1341,6 +1341,8 @@ def test_forward_with_output_attentions_and_output_hidden_states(
 
     if output_attentions:
         assert len(outputs.attentions) == n_layers
+        assert all(
+            attn.shape == (1, 4, 3, 3) for attn in outputs.attentions)
     if output_hidden_states:
         assert len(outputs.hidden_states) == n_layers + 1
 
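
Note (not part of the patch): a minimal usage sketch of the `output_attentions` flag this change threads through the MPT blocks, mirroring the shape check added to the unit test above. It assumes llm-foundry is installed at a version that includes this change; the tiny config values, the random input, and the use of `attn_impl='torch'` (the fused attention paths cannot return per-head weights) are illustrative choices, not part of PR #599.

# Hypothetical usage sketch -- illustrative config values only.
import torch

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

# Tiny model so the sketch runs quickly on CPU; attn_impl='torch' because the
# fused attention implementations do not return attention weights.
config = MPTConfig(
    d_model=32,
    n_heads=4,
    n_layers=2,
    expansion_ratio=2,
    max_seq_len=16,
    vocab_size=100,
    attn_config={'attn_impl': 'torch'},
)
model = MPTForCausalLM(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 3))
with torch.no_grad():
    outputs = model(
        input_ids,
        output_attentions=True,   # forwarded to each block, then to attention as needs_weights=True
        output_hidden_states=True,
    )

# One attention tensor per layer, shaped (batch, n_heads, seq_len, seq_len),
# i.e. (1, 4, 3, 3) for this config, matching the unit test's assertion.
assert len(outputs.attentions) == config.n_layers
assert all(attn.shape == (1, 4, 3, 3) for attn in outputs.attentions)
assert len(outputs.hidden_states) == config.n_layers + 1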