diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index cf7a67f10d..d1c22f7f3a 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -72,6 +72,14 @@ def __init__(
         del kwargs  # unused, just to capture any extra args from the config
         super().__init__()
 
+        self.ffn = build_ffn(
+            d_model=d_model,
+            expansion_ratio=expansion_ratio,
+            device=device,
+            bias=not no_bias,
+            **ffn_config,
+        )
+
         if self.fuse_norm_attn_norm:
             self.norm_attn_norm = FusedNormAttentionNorm(
                 d_model=d_model,
@@ -122,13 +130,6 @@ def __init__(
             device=device,
         )
 
-        self.ffn = build_ffn(
-            d_model=d_model,
-            expansion_ratio=expansion_ratio,
-            device=device,
-            bias=not no_bias,
-            **ffn_config,
-        )
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
         self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
         self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
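
For context, the net effect of this diff is purely a reordering in `MPTBlock.__init__`: the FFN module is now built before the (optionally fused) norm/attention modules instead of after them, while the residual dropouts and `use_pad_tok_in_ffn` remain last. Below is a minimal, self-contained sketch of that resulting construction order; `build_ffn_sketch`, `BlockSketch`, and the plain `nn.LayerNorm`/`nn.MultiheadAttention` stand-ins are illustrative assumptions, not llmfoundry's actual `build_ffn` or attention builders.

```python
import torch.nn as nn


def build_ffn_sketch(d_model: int, expansion_ratio: int, bias: bool = True) -> nn.Module:
    # Stand-in for llmfoundry's build_ffn: a two-layer MLP whose hidden width
    # follows the d_model * expansion_ratio pattern implied by the diff's arguments.
    hidden = d_model * expansion_ratio
    return nn.Sequential(
        nn.Linear(d_model, hidden, bias=bias),
        nn.GELU(),
        nn.Linear(hidden, d_model, bias=bias),
    )


class BlockSketch(nn.Module):
    """Hypothetical block mirroring the new construction order, not MPTBlock itself."""

    def __init__(self, d_model: int, n_heads: int, expansion_ratio: int = 4,
                 resid_pdrop: float = 0.0, no_bias: bool = False):
        super().__init__()
        # After this diff, the FFN is constructed first ...
        self.ffn = build_ffn_sketch(d_model, expansion_ratio, bias=not no_bias)
        # ... then the norm/attention modules (fused or unfused in the real code) ...
        self.norm_1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm_2 = nn.LayerNorm(d_model)
        # ... and the residual dropouts last, unchanged by the diff.
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
```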