diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 5f6fba1803..77f39c2a38 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -74,7 +74,7 @@ def __init__( del kwargs # unused, just to capture any extra args from the config super().__init__() - ffn_type = ffn_config.pop('ffn_type') + ffn_type = ffn_config['ffn_type'] self.ffn = build_ffn( name=ffn_type, diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index ed24de358e..1d32b6baf7 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -35,6 +35,7 @@ def build_ffn( bias: bool, ffn_kwargs: Dict[str, Any], ): + registry_to_use = ffns if name in ffns_with_norm: registry_to_use = ffns_with_norm @@ -47,7 +48,7 @@ def build_ffn( 'expansion_ratio': expansion_ratio, 'device': device, 'bias': bias, - **ffn_kwargs, + **{k:v for k,v in ffn_kwargs.items() if k != 'ffn_type'}, } def _validation_function(maybe_module: Any):