diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index aca2350051..c812d9e6b7 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -325,6 +325,7 @@ def __init__(self, config: MPTConfig): self.emb_drop = nn.Dropout(config.emb_pdrop) self.mb_args = None block_args = config.to_dict() + print(block_args['ffn_config']) if block_args['ffn_config']['ffn_type'] in ffns_with_megablocks: block_args['ffn_config'] = config_moe_args( block_args['ffn_config'], @@ -332,6 +333,7 @@ def __init__(self, config: MPTConfig): config.expansion_ratio, config.n_layers, ) + print(block_args['ffn_config']) self.mb_args = block_args['ffn_config'].get('args') self.blocks = nn.ModuleList([ MPTBlock(