From 7c3d4938a1e31a0a40459a3c71debb13d1437c2a Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 11 Apr 2024 19:58:08 -0700 Subject: [PATCH] debug --- llmfoundry/models/layers/blocks.py | 2 +- llmfoundry/models/layers/layer_builders.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 5f6fba1803..77f39c2a38 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -74,7 +74,7 @@ def __init__( del kwargs # unused, just to capture any extra args from the config super().__init__() - ffn_type = ffn_config.pop('ffn_type') + ffn_type = ffn_config['ffn_type'] self.ffn = build_ffn( name=ffn_type, diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index ed24de358e..1d32b6baf7 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -35,6 +35,7 @@ def build_ffn( bias: bool, ffn_kwargs: Dict[str, Any], ): + registry_to_use = ffns if name in ffns_with_norm: registry_to_use = ffns_with_norm @@ -47,7 +48,7 @@ def build_ffn( 'expansion_ratio': expansion_ratio, 'device': device, 'bias': bias, - **ffn_kwargs, + **{k:v for k,v in ffn_kwargs.items() if k != 'ffn_type'}, } def _validation_function(maybe_module: Any):