fix
dakinggg committed Apr 11, 2024
1 parent 13746a4 commit 3eb723f
Showing 1 changed file with 10 additions and 9 deletions.
llmfoundry/models/layers/blocks.py: 19 changes (10 additions & 9 deletions)
@@ -87,8 +87,6 @@ def __init__(
             )
         else:
             assert isinstance(attn_config['attn_type'], str)
-            attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
-
             # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
             args_to_exclude_in_attn_class = {
                 'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max',
@@ -106,13 +104,16 @@ def __init__(
                 normalized_shape=d_model,
                 device=device,
             )
-            self.attn = attn_class(
-                d_model=d_model,
-                n_heads=n_heads,
-                fc_type=fc_type,
-                device=device,
-                **attn_config_subset_for_attn_class,
-                bias=not no_bias,
+            self.attn = build_attention_layer(
+                name=attn_config['attn_type'],
+                attn_kwargs={
+                    'd_model': d_model,
+                    'n_heads': n_heads,
+                    'fc_type': fc_type,
+                    'device': device,
+                    'bias': not no_bias,
+                    **attn_config_subset_for_attn_class
+                },
             )
             self.norm_2 = None
             if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']],
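In short, the commit replaces the direct ATTN_CLASS_REGISTRY lookup and constructor call in MPTBlock with a call to the build_attention_layer factory, which takes the attention type name and a single attn_kwargs dict. The sketch below is a minimal, hypothetical illustration of how such a registry-backed factory could work; it is not the actual llm-foundry implementation, and the stand-in SimpleAttention class and example hyperparameters are invented for demonstration.

```python
# Minimal sketch of a registry-backed factory in the spirit of
# build_attention_layer. This is NOT the llm-foundry implementation;
# SimpleAttention and the hyperparameters below are illustrative only.
from typing import Any, Dict, Type

import torch.nn as nn


class SimpleAttention(nn.Module):
    """Stand-in attention module; extra kwargs are accepted and ignored."""

    def __init__(self, d_model: int, n_heads: int, **kwargs: Any):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return out


# Registry mapping an 'attn_type' string to an attention module class,
# analogous to ATTN_CLASS_REGISTRY in the diff above.
ATTN_CLASS_REGISTRY: Dict[str, Type[nn.Module]] = {
    'simple_attention': SimpleAttention,
}


def build_attention_layer(name: str, attn_kwargs: Dict[str, Any]) -> nn.Module:
    """Look up the attention class by name and build it from attn_kwargs."""
    if name not in ATTN_CLASS_REGISTRY:
        raise ValueError(f'Unknown attention type: {name!r}')
    return ATTN_CLASS_REGISTRY[name](**attn_kwargs)


# Usage mirroring the new call site in blocks.py:
attn = build_attention_layer(
    name='simple_attention',
    attn_kwargs={'d_model': 512, 'n_heads': 8},
)
```

Routing construction through a factory like this keeps call sites such as MPTBlock free of registry details and gives one place to validate or adjust the keyword arguments before instantiation.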
