From 15e79f38a7d2fb729d52e3fe500fd3cfb0184c52 Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
Date: Fri, 15 Dec 2023 11:48:55 -0800
Subject: [PATCH] Make the ffn activation func configurable (#805)

* config ffn_act_fn v0

* irene pr comments

* enable config to have function

* updt test

* dblalock pr comments

* updt how default ffn_act config is set

---
 llmfoundry/models/layers/ffn.py            | 45 ++++++++++++++++++-
 llmfoundry/models/mpt/configuration_mpt.py |  4 ++
 tests/models/test_model.py                 | 51 +++++++++++++++++++---
 3 files changed, 92 insertions(+), 8 deletions(-)

diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index e18e611ca6..560e8c31fc 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -4,7 +4,9 @@
 """MPT Blocks used for the MPT Model."""
 
 import logging
-from typing import Any, Optional, Union
+from copy import deepcopy
+from functools import partial
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -18,6 +20,36 @@
 
 log = logging.getLogger(__name__)
 
+_FFN_ACT_FN_DEFAULT = {
+    'name': 'gelu',
+    'approximate': 'none',
+}
+
+
+def resolve_ffn_act_fn(
+    config: Optional[dict] = None,) -> Callable[[torch.Tensor], torch.Tensor]:
+    """Resolve the activation function for the feed-forward network.
+
+    Args:
+        config (Optional[dict]): The configuration dictionary for the activation function.
+            The dict config must specify the 'name' of a torch.nn.functional activation
+            function. All other key-value pairs are bound to the function as a partial.
+
+    Returns:
+        Callable[[torch.Tensor], torch.Tensor]: The activation function.
+    """
+    if config is None:
+        config = _FFN_ACT_FN_DEFAULT
+    config = deepcopy(config)
+    name = config.pop('name')
+    if not hasattr(torch.nn.functional, name):
+        raise ValueError(f'Unrecognised activation function name ({name}).')
+    act = getattr(torch.nn.functional, name)
+    return partial(act, **config)
+
+
+_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
+
 
 def resolve_ffn_hidden_size(
     d_model: int,
@@ -55,6 +87,7 @@ def __init__(
         expansion_ratio: Union[int, float],
         fc_type: str = 'torch',
         ffn_hidden_size: Optional[int] = None,
+        act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
         device: Optional[str] = None,
         bias: bool = True,
     ):
@@ -72,7 +105,7 @@ def __init__(
             ffn_hidden_size,
             **self.fc_kwargs,
         )
-        self.act = nn.GELU(approximate='none')
+        self.act = act_fn
         self.down_proj = FC_CLASS_REGISTRY[fc_type](
             ffn_hidden_size,
             d_model,
@@ -92,6 +125,7 @@ def __init__(
         expansion_ratio: Union[int, float],
         fc_type: str = 'torch',
         ffn_hidden_size: Optional[int] = None,
+        act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
         device: Optional[str] = None,
         bias: bool = True,
     ):
@@ -100,6 +134,7 @@ def __init__(
             expansion_ratio=expansion_ratio,
             fc_type=fc_type,
             ffn_hidden_size=ffn_hidden_size,
+            act_fn=act_fn,
             device=device,
             bias=bias,
         )
@@ -128,6 +163,7 @@ def build_ffn(
     expansion_ratio: Union[int, float],
     fc_type: str = 'torch',
     ffn_hidden_size: Optional[int] = None,
+    ffn_act_fn: Optional[dict] = None,
    device: Optional[str] = None,
     bias: bool = True,
     **kwargs: Any,
@@ -142,6 +178,7 @@ def build_ffn(
             d_model=d_model,
             expansion_ratio=expansion_ratio,
             fc_type=fc_type,
+            act_fn=resolve_ffn_act_fn(ffn_act_fn),
             ffn_hidden_size=ffn_hidden_size,
             device=device,
             bias=bias,
@@ -150,6 +187,10 @@ def build_ffn(
         assert te is not None
         ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio,
                                                   ffn_hidden_size)
+        if ffn_act_fn is not None:
+            raise ValueError(
+                f'Transformer Engine block does not support custom activation functions.'
+            )
         return te.LayerNormMLP(
             hidden_size=d_model,
             ffn_hidden_size=ffn_hidden_size,
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 2ecc726aa3..913c39d44f 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -295,6 +295,10 @@ def _validate_config(self) -> None:
             self.ffn_config['fc_type'] = self.fc_type
         elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
             self.ffn_config['bias'] = not self.no_bias
+            if 'ffn_act_fn' in self.ffn_config.keys():
+                raise ValueError(
+                    f'Transformer Engine block does not support custom activation functions.'
+                )
         if not self.use_pad_tok_in_ffn:
             try:
                 from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 6d48d115fd..c61e963e55 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -351,7 +351,25 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
     pytest.param('flash', torch.float16, marks=pytest.mark.gpu),
     pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)])
 @pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu'])
-def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str):
+@pytest.mark.parametrize('ffn_act_fn', [
+    None,
+    {
+        'name': 'gelu',
+        'approximate': 'tanh',
+    },
+    {
+        'name': 'silu',
+    },
+    {
+        'name': 'relu',
+        'inplace': True,
+    },
+    pytest.param({'name': 'relu5'},
+                 marks=pytest.mark.xfail(reason='invalid choice.',
+                                         strict=True)),
+])
+def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str,
+                     ffn_act_fn: dict):
     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     with open(conf_path) as f:
         test_cfg = om.load(f)
@@ -363,6 +381,7 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str):
         test_cfg.model.ffn_config['ffn_type'] = ffn_type
     else:
         test_cfg.model.setdefault('ffn_config', {'ffn_type': ffn_type})
+    test_cfg.model.ffn_config['ffn_act_fn'] = ffn_act_fn
     test_cfg.model.init_device = 'cuda:0'
     test_cfg.device = 'cuda:0'
 
@@ -516,12 +535,34 @@ def test_opt_wrapping():
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
 @pytest.mark.parametrize('expansion_ratio,ffn_hidden_size', [
     (2, None),
-    (1.231, None),
+    pytest.param(1.231,
+                 None,
+                 marks=pytest.mark.xfail(
+                     reason='d_model * expansion_ratio must be an integer.',
+                     strict=True)),
     (2, 128),
     (2, 256),
 ])
+@pytest.mark.parametrize('ffn_act_fn', [
+    None,
+    {
+        'name': 'gelu',
+        'approximate': 'tanh',
+    },
+    {
+        'name': 'silu',
+    },
+    {
+        'name': 'relu',
+        'inplace': True,
+    },
+    pytest.param({'name': 'relu5'},
+                 marks=pytest.mark.xfail(reason='invalid choice.',
+                                         strict=True)),
+])
 def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
-                      expansion_ratio: Union[int, float], ffn_hidden_size: int):
+                      expansion_ratio: Union[int, float], ffn_hidden_size: int,
+                      ffn_act_fn: dict):
     # Test that the config constructs the model as expected.
     hf_config = MPTConfig(
         init_device='cpu',
@@ -541,11 +582,9 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
         ffn_config={
             'ffn_type': 'mptmlp',
             'ffn_hidden_size': ffn_hidden_size,
+            'ffn_act_fn': ffn_act_fn,
         },
     )
-    if hf_config.d_model * hf_config.expansion_ratio != int(
-            hf_config.d_model * hf_config.expansion_ratio):
-        pytest.xfail('d_model * expansion_ratio must be an integer.')
 
     mpt = MPTForCausalLM(hf_config)
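
Usage sketch (editorial note, not part of the patch): with this change applied, an `ffn_act_fn` dict can be set in the model's `ffn_config`. `resolve_ffn_act_fn` looks up `name` in `torch.nn.functional` and binds any remaining keys onto that function via `functools.partial`. The tiny model hyperparameters below are illustrative only, not values taken from this patch.

import torch
from llmfoundry.models.layers.ffn import resolve_ffn_act_fn
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

# Resolve an activation directly: the extra 'approximate' key is partial-bound onto F.gelu.
act = resolve_ffn_act_fn({'name': 'gelu', 'approximate': 'tanh'})
y = act(torch.randn(4))  # equivalent to torch.nn.functional.gelu(x, approximate='tanh')

# Build an MPT model whose FFN uses SiLU instead of the default GELU
# (sizes here are small, purely for illustration).
cfg = MPTConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    expansion_ratio=2,
    max_seq_len=128,
    vocab_size=50368,
    ffn_config={
        'ffn_type': 'mptmlp',
        'ffn_act_fn': {'name': 'silu'},
    },
)
model = MPTForCausalLM(cfg)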