diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index 560e8c31fc..fa3e109bf8 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -117,7 +117,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.down_proj(self.act(self.up_proj(x)))
 
 
-class MPTGeGLU(MPTMLP):
+class MPTGLU(MPTMLP):
 
     def __init__(
         self,
@@ -138,19 +138,19 @@ def __init__(
             device=device,
             bias=bias,
         )
-        self.gate = FC_CLASS_REGISTRY[fc_type](
+        self.gate_proj = FC_CLASS_REGISTRY[fc_type](
             d_model,
             self.up_proj.out_features,
             **self.fc_kwargs,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.down_proj(self.act(self.up_proj(x)) * self.gate(x))
+        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
 
 
 FFN_CLASS_REGISTRY = {
     'mptmlp': MPTMLP,
-    'mptgeglu': MPTGeGLU,
+    'mptglu': MPTGLU,
 }
 
 if te is not None:
@@ -169,10 +169,10 @@ def build_ffn(
     **kwargs: Any,
 ) -> nn.Module:
     ffn_type = kwargs.pop('ffn_type')
-    if ffn_type in ['mptmlp', 'mptgeglu']:
+    if ffn_type in ['mptmlp', 'mptglu']:
         if len(kwargs) > 0:
             raise ValueError(
-                f'MPTMLP (or MPTGeGLU) got an unexpected keyword argument: {kwargs}'
+                f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}'
             )
         return FFN_CLASS_REGISTRY[ffn_type](
             d_model=d_model,
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 6c4c286712..ae4754108c 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -107,7 +107,7 @@ def __init__(
                     factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
                 kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
             ffn_config (Dict): A dictionary used to configure the model's ffn module:
-                ffn_type (str): type of ffn to use. Options: mptmlp, mptgeglu, te_ln_mlp
+                ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp
             init_device (str): The device to use for parameter initialization.
             logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
             no_bias (bool): Whether to use bias in all layers.
@@ -291,7 +291,13 @@ def _validate_config(self) -> None:
                 + 'pip install flash-attn==1.0.6 --no-build-isolation \n'
                 + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
             )
-        if self.ffn_config['ffn_type'] in ['mptmlp', 'mptgeglu']:
+        if self.ffn_config['ffn_type'] == 'mptgeglu':
+            raise ValueError(
+                'API CHANGE: `ffn_type=="mptgeglu"` changed to `ffn_type=="mptglu"`. '
+                +
+                'See [#829](https://github.com/mosaicml/llm-foundry/pull/829) for details.'
+            )
+        elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']:
             self.ffn_config['fc_type'] = self.fc_type
         elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
             self.ffn_config['bias'] = not self.no_bias
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 2419dbfa41..7bccad089d 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -350,7 +350,7 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
     [('torch', torch.float16), ('torch', torch.bfloat16),
      pytest.param('flash', torch.float16, marks=pytest.mark.gpu),
      pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)])
-@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu'])
+@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptglu'])
 @pytest.mark.parametrize('ffn_act_fn', [
     None,
     {
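
Note: a minimal, illustrative sketch of the GLU feed-forward pattern after this rename and the config value it now expects. The class below is hypothetical (plain nn.Linear and nn.GELU stand in for FC_CLASS_REGISTRY[fc_type] and the configurable act_fn); only MPTGLU, gate_proj/up_proj/down_proj, and ffn_type come from the diff above.

import torch
import torch.nn as nn

class GLUSketch(nn.Module):
    """Hypothetical stand-in for MPTGLU, for illustration only."""

    def __init__(self, d_model: int, hidden_size: int):
        super().__init__()
        self.up_proj = nn.Linear(d_model, hidden_size)
        self.gate_proj = nn.Linear(d_model, hidden_size)  # formerly `self.gate`
        self.down_proj = nn.Linear(hidden_size, d_model)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The activation is applied to the gate branch and multiplied with the
        # un-activated up projection, matching the new MPTGLU.forward above.
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))

# Configs that previously set ffn_type to 'mptgeglu' now hit the ValueError added
# in configuration_mpt.py; use 'mptglu' instead, e.g.:
ffn_config = {'ffn_type': 'mptglu'}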