diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 5a86b9424c..3dbdbacb3a 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -32,6 +32,10 @@ def resolve_ffn_act_fn( Args: config (Optional[dict]): The configuration dictionary for the activation function. + The dict config must specify a 'function' or the 'name' of a function. + If 'function' is specified, a Callable function is expected. If 'name' is + specified, the name is expected to be the name of a `torch.nn.functional` + function. All of other key values pairs are bound to the function as a partial. Returns: Callable[[torch.Tensor], torch.Tensor]: The activation function. @@ -196,7 +200,8 @@ def build_ffn( ffn_hidden_size) if ffn_act_fn is not None: raise ValueError( - f'te block does not support custom activation functions.') + f'Transformer Engine block does not support custom activation functions.' + ) return te.LayerNormMLP( hidden_size=d_model, ffn_hidden_size=ffn_hidden_size, diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 388ed148f7..913c39d44f 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -297,7 +297,8 @@ def _validate_config(self) -> None: self.ffn_config['bias'] = not self.no_bias if 'ffn_act_fn' in self.ffn_config.keys(): raise ValueError( - f'te block does not support custom activation functions.') + f'Transformer Engine block does not support custom activation functions.' + ) if not self.use_pad_tok_in_ffn: try: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 90261e4c4a..04fa2aa7f6 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -354,10 +354,10 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2): @pytest.mark.parametrize('ffn_act_fn', [ { 'name': 'gelu', - 'approximate': 'none' + 'approximate': 'none', }, { - 'name': 'silu' + 'name': 'silu', }, { 'name': 'relu', @@ -545,10 +545,14 @@ def test_opt_wrapping(): @pytest.mark.parametrize('ffn_act_fn', [ { 'name': 'gelu', - 'approximate': 'none' + 'approximate': 'none', }, { - 'name': 'silu' + 'name': 'silu', + }, + { + 'name': 'relu', + 'inplace': True, }, pytest.param({'name': 'relu5'}, marks=pytest.mark.xfail(reason='invalid choice.',