diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 82e8e94f74..c88cf33d1b 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -170,7 +170,9 @@ def forward( extra_kwargs = {} if prev_layer_key_value is not None: extra_kwargs['prev_layer_key_value'] = prev_layer_key_value + if key_value_states is not None: extra_kwargs['key_value_states'] = key_value_states + if self.fuse_norm_attn_norm: x, m, attn_weights, past_key_value = self.norm_attn_norm( x, @@ -336,7 +338,9 @@ def forward( extra_kwargs = {} if prev_layer_key_value is not None: extra_kwargs['prev_layer_key_value'] = prev_layer_key_value + if key_value_states is not None: extra_kwargs['key_value_states'] = key_value_states + b, attn_weights, past_key_value = self.attn( a, past_key_value=past_key_value,