Skip to content

Commit

Permalink
setting absolute absolute value for reuse_kv_layer_idx
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML committed Jun 25, 2024
1 parent 89bc22e commit 0a3f1b4
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ def _construct_blocks_with_overrides(
f'The specified block overrides do not match the number of layers: {len(model_modules_order_expanded)} vs {config.n_layers}.',
)

reuse_kv_layer_idx_dict = {}
for i in range(config.n_layers):
module_name = model_modules_order_expanded[i]
override_config = {}
Expand All @@ -524,6 +525,10 @@ def _construct_blocks_with_overrides(
raise ValueError(
f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.',
)
if reuse_kv_layer_idx in reuse_kv_layer_idx_dict:
reuse_kv_layer_idx = reuse_kv_layer_idx_dict[
reuse_kv_layer_idx]
reuse_kv_layer_idx_dict[i] = reuse_kv_layer_idx
override_attn_config['reuse_kv_layer_idx'
] = reuse_kv_layer_idx
if self.kv_cache_layers is None:
Expand Down

0 comments on commit 0a3f1b4

Please sign in to comment.