From a4f567906d56051cd6161a20d709da662c0874f6 Mon Sep 17 00:00:00 2001 From: Yanan Xie Date: Fri, 15 Sep 2023 10:30:44 -0700 Subject: [PATCH 1/7] Allow MPT models to return attention weights --- llmfoundry/models/layers/blocks.py | 2 ++ llmfoundry/models/mpt/modeling_mpt.py | 1 + 2 files changed, 3 insertions(+) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index dd208302b8..2c5b5d1c7c 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -91,6 +91,7 @@ def forward( attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, + output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) @@ -100,6 +101,7 @@ def forward( attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal, + needs_weights=output_attentions, ) x = x + self.resid_attn_dropout(b) m = x diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 3371c67a0d..6a184ee6dd 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -434,6 +434,7 @@ def forward( attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, + output_attentions=output_attentions == True, ) if past_key_values is not None: past_key_values[b_idx] = past_key_value From 89f8e83fbc559e9bcface5c0d4fdc1be3d002a65 Mon Sep 17 00:00:00 2001 From: Yanan Xie <108375850+lorabit110@users.noreply.github.com> Date: Fri, 15 Sep 2023 15:01:22 -0700 Subject: [PATCH 2/7] Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 6a184ee6dd..26d564ff8c 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -434,7 +434,7 @@ def forward( attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, - output_attentions=output_attentions == True, + output_attentions=bool(output_attentions), ) if past_key_values is not None: past_key_values[b_idx] = past_key_value From 655cee4f45daa16636abcaaf0a5286d82747c65f Mon Sep 17 00:00:00 2001 From: Yanan Xie Date: Fri, 15 Sep 2023 15:08:32 -0700 Subject: [PATCH 3/7] Add unit test --- tests/test_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_model.py b/tests/test_model.py index 501d9bf6e7..9b29eeacb3 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1341,6 +1341,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( if output_attentions: assert len(outputs.attentions) == n_layers + assert outputs.attentions[0] is not None if output_hidden_states: assert len(outputs.hidden_states) == n_layers + 1 From 2fad7a5151add200f8175db9c0fa80f108241d08 Mon Sep 17 00:00:00 2001 From: Yanan Xie <108375850+lorabit110@users.noreply.github.com> Date: Fri, 15 Sep 2023 16:08:51 -0700 Subject: [PATCH 4/7] Update tests/test_model.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 9b29eeacb3..136e387fbe 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1341,7 +1341,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( if output_attentions: assert len(outputs.attentions) == n_layers - assert outputs.attentions[0] is not None + assert all(attn.shape == (1, 4, 2048, 2048) for attn in outputs.attentions) if output_hidden_states: assert len(outputs.hidden_states) == n_layers + 1 From 2846d29a7cf90ef5ad67f21328967e33ca1bbe65 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Fri, 15 Sep 2023 22:37:38 -0700 Subject: [PATCH 5/7] Update tests/test_model.py --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 136e387fbe..ff5a97efc7 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1341,7 +1341,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( if output_attentions: assert len(outputs.attentions) == n_layers - assert all(attn.shape == (1, 4, 2048, 2048) for attn in outputs.attentions) + assert all(attn.shape == (1, 4, 3, 3) for attn in outputs.attentions) if output_hidden_states: assert len(outputs.hidden_states) == n_layers + 1 From f4e4f49cfde851d3df3aa44fdf85b06c7fb98627 Mon Sep 17 00:00:00 2001 From: Yanan Xie Date: Sat, 16 Sep 2023 12:15:50 -0700 Subject: [PATCH 6/7] retrigger checks From 507b4df1453bf06c0f8425c3c2057ad6bd971245 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Wed, 20 Sep 2023 17:30:35 -0700 Subject: [PATCH 7/7] Update tests/test_model.py --- tests/test_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index ff5a97efc7..6ea530731a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1341,7 +1341,8 @@ def test_forward_with_output_attentions_and_output_hidden_states( if output_attentions: assert len(outputs.attentions) == n_layers - assert all(attn.shape == (1, 4, 3, 3) for attn in outputs.attentions) + assert all( + attn.shape == (1, 4, 3, 3) for attn in outputs.attentions) if output_hidden_states: assert len(outputs.hidden_states) == n_layers + 1