diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index cd162195b6..87f7ba5f00 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -422,6 +422,7 @@ def forward( ) # initialize the past key values cache if it should be used + presents = () if use_cache else None if use_cache and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers) ] # type: ignore @@ -434,7 +435,7 @@ def forward( all_hidden_states = all_hidden_states + (x,) past_key_value = (past_key_values[b_idx] if past_key_values is not None else None) - x, attn_weights, past_key_value = block( + x, attn_weights, present = block( x, past_key_value=past_key_value, attn_bias=attn_bias, @@ -442,8 +443,8 @@ def forward( is_causal=self.is_causal, output_attentions=bool(output_attentions), ) - if past_key_values is not None: - past_key_values[b_idx] = past_key_value + if use_cache: + presents += (present,) if output_attentions: assert all_self_attns is not None # pyright @@ -458,7 +459,7 @@ def forward( return BaseModelOutputWithPast( last_hidden_state=x, - past_key_values=past_key_values, + past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attns, )