
Commit

Use presents, they are a gift. Do not update past_key_values in place.
irenedea committed Oct 7, 2023
1 parent fcc271f commit 655fc75
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -422,6 +422,7 @@ def forward(
         )
 
         # initialize the past key values cache if it should be used
+        presents = () if use_cache else None
         if use_cache and past_key_values is None:
             past_key_values = [() for _ in range(self.config.n_layers)
                               ]  # type: ignore
@@ -434,16 +435,16 @@ def forward(
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = (past_key_values[b_idx]
                               if past_key_values is not None else None)
-            x, attn_weights, past_key_value = block(
+            x, attn_weights, present = block(
                 x,
                 past_key_value=past_key_value,
                 attn_bias=attn_bias,
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
                 output_attentions=bool(output_attentions),
             )
-            if past_key_values is not None:
-                past_key_values[b_idx] = past_key_value
+            if use_cache:
+                presents += (present,)
 
             if output_attentions:
                 assert all_self_attns is not None  # pyright
@@ -458,7 +459,7 @@ def forward(
 
         return BaseModelOutputWithPast(
             last_hidden_state=x,
-            past_key_values=past_key_values,
+            past_key_values=presents,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
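For context, here is a minimal standalone sketch of the pattern this commit adopts: each layer's returned key/value pair is accumulated into a fresh presents tuple that is handed back to the caller, rather than written back into the past_key_values list the caller passed in. This is not the actual MPT code; the toy_block helper, the n_layers argument, and the tensor shapes are hypothetical and for illustration only.

# A minimal sketch of the "build a new presents tuple" pattern, assuming a
# hypothetical toy_block in place of a real transformer block.
from typing import List, Optional, Tuple

import torch

KVPair = Tuple[torch.Tensor, torch.Tensor]


def toy_block(
    x: torch.Tensor,
    past_key_value: Optional[KVPair] = None,
) -> Tuple[torch.Tensor, KVPair]:
    # Stand-in for attention: derive key/value tensors from the input and
    # concatenate them onto whatever was cached for this layer.
    key = value = x
    if past_key_value is not None and len(past_key_value) != 0:
        key = torch.cat([past_key_value[0], key], dim=1)
        value = torch.cat([past_key_value[1], value], dim=1)
    return x, (key, value)


def forward_no_inplace(
    x: torch.Tensor,
    n_layers: int,
    past_key_values: Optional[List] = None,
    use_cache: bool = True,
) -> Tuple[torch.Tensor, Optional[Tuple[KVPair, ...]]]:
    # Accumulate the cache to return in a new tuple instead of writing it
    # back into the list the caller passed in.
    presents = () if use_cache else None
    if use_cache and past_key_values is None:
        past_key_values = [() for _ in range(n_layers)]
    for b_idx in range(n_layers):
        past_key_value = (past_key_values[b_idx]
                          if past_key_values is not None else None)
        x, present = toy_block(x, past_key_value=past_key_value)
        if use_cache:
            presents += (present,)
    return x, presents


if __name__ == '__main__':
    hidden = torch.randn(1, 4, 8)  # (batch, seq_len, d_model)
    out, cache = forward_no_inplace(hidden, n_layers=2)
    # The caller's list is never mutated; each call hands back a new cache.
    out, cache = forward_no_inplace(out, n_layers=2,
                                    past_key_values=list(cache))
    print(len(cache), cache[0][0].shape)

Returning a fresh tuple leaves the caller's cache object untouched, so code that reuses the list it passed in is not surprised by in-place changes.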
