From 3eb723fda7be97f4d80ff2c4d93e958e4dfa1024 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Thu, 11 Apr 2024 15:47:43 -0700
Subject: [PATCH] fix

---
 llmfoundry/models/layers/blocks.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 7bf3b513a6..1ad9ec954f 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -87,8 +87,6 @@ def __init__(
             )
         else:
             assert isinstance(attn_config['attn_type'], str)
-            attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
-
             # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
             args_to_exclude_in_attn_class = {
                 'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max',
@@ -106,13 +104,16 @@ def __init__(
                 normalized_shape=d_model,
                 device=device,
             )
-            self.attn = attn_class(
-                d_model=d_model,
-                n_heads=n_heads,
-                fc_type=fc_type,
-                device=device,
-                **attn_config_subset_for_attn_class,
-                bias=not no_bias,
+            self.attn = build_attention_layer(
+                name=attn_config['attn_type'],
+                attn_kwargs={
+                    'd_model': d_model,
+                    'n_heads': n_heads,
+                    'fc_type': fc_type,
+                    'device': device,
+                    'bias': not no_bias,
+                    **attn_config_subset_for_attn_class
+                },
             )
             self.norm_2 = None
             if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']],
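
Note on the change: the hunks above drop the direct ATTN_CLASS_REGISTRY lookup plus attn_class(...) instantiation in favor of a single build_attention_layer(name=..., attn_kwargs=...) call. The patch does not include that builder, so the sketch below is only an assumption about its shape: a registry-backed factory that resolves the attention class by name and forwards attn_kwargs to its constructor. PlaceholderAttention and the registry contents shown here are hypothetical stand-ins, not llmfoundry code.

# Minimal sketch of a registry-backed builder (assumed shape, not the
# actual llmfoundry implementation).
from typing import Any, Dict, Type

import torch.nn as nn


class PlaceholderAttention(nn.Module):
    """Hypothetical attention module standing in for a real registered class."""

    def __init__(self, d_model: int, n_heads: int, bias: bool = True, **kwargs: Any):
        super().__init__()
        self.n_heads = n_heads
        # A single fused QKV projection, just to make the sketch concrete.
        self.Wqkv = nn.Linear(d_model, 3 * d_model, bias=bias)


# Hypothetical registry mapping attn_config['attn_type'] names to classes.
ATTN_CLASS_REGISTRY: Dict[str, Type[nn.Module]] = {
    'placeholder_attention': PlaceholderAttention,
}


def build_attention_layer(name: str, attn_kwargs: Dict[str, Any]) -> nn.Module:
    """Resolve the attention class by name and construct it with attn_kwargs."""
    if name not in ATTN_CLASS_REGISTRY:
        raise ValueError(f'Unknown attn_type {name!r}; '
                         f'expected one of {sorted(ATTN_CLASS_REGISTRY)}')
    return ATTN_CLASS_REGISTRY[name](**attn_kwargs)


# Usage mirroring the new call site in blocks.py:
attn = build_attention_layer(
    name='placeholder_attention',
    attn_kwargs={'d_model': 512, 'n_heads': 8, 'bias': True},
)

Routing construction through a named builder keeps blocks.py free of per-class keyword plumbing: keys such as bias are folded into attn_kwargs alongside the filtered attn_config, and new attention implementations can be registered without editing the block itself.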