diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 1eb57e2055..d1ed06b9d6 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -73,6 +73,7 @@ def __init__( super().__init__() self.ffn = build_ffn( + name=ffn_config['ffn_type'], d_model=d_model, expansion_ratio=expansion_ratio, device=device,