diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 0465226820..40f349368f 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -74,7 +74,7 @@ def __init__( super().__init__() ffn_type = ffn_config['ffn_type'] - ffn_has_norm = not ffn_type in ffns_with_norm + ffn_has_norm = ffn_type in ffns_with_norm if self.fuse_norm_attn_norm: self.norm_attn_norm = FusedNormAttentionNorm(