From 0797aa66f87d3a8b41f70f8a8be5529a3337ea68 Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
Date: Tue, 12 Dec 2023 19:28:06 -0800
Subject: [PATCH] Enable GLU FFN type (#796)

* add glu ffn

* add ffn_type to determinism test

* Update llmfoundry/models/layers/ffn.py

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

* pr comments

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
---
 llmfoundry/models/layers/ffn.py            | 88 ++++++++++++++++++----
 llmfoundry/models/mpt/configuration_mpt.py |  8 +-
 tests/models/test_model.py                 | 19 +++--
 3 files changed, 92 insertions(+), 23 deletions(-)

diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index 2f6d05f424..8f37b39306 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -1,9 +1,10 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""GPT Blocks used for the GPT Model."""
+"""MPT Blocks used for the MPT Model."""
 
-from typing import Any, Optional
+import logging
+from typing import Any, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -15,33 +16,57 @@
 except:
     te = None
 
+log = logging.getLogger(__name__)
+
+
+def _resolve_ffn_hidden_and_exp_ratio(
+    d_model: int,
+    expansion_ratio: Union[int, float],
+    ffn_hidden_size: Optional[int] = None,
+) -> tuple[Union[int, float], int]:
+    if ffn_hidden_size is not None:
+        log.info(
+            f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.'
+        )
+    else:
+        ffn_hidden_size = int(d_model * expansion_ratio)
+        if ffn_hidden_size != d_model * expansion_ratio:
+            raise ValueError(
+                f'`d_model * expansion_ratio` ({ffn_hidden_size}) must be an integer.'
+            )
+    return expansion_ratio, ffn_hidden_size
+
 
 class MPTMLP(nn.Module):
 
     def __init__(
         self,
         d_model: int,
-        expansion_ratio: int,
+        expansion_ratio: Union[int, float],
         fc_type: str = 'torch',
+        ffn_hidden_size: Optional[int] = None,
         device: Optional[str] = None,
         bias: bool = True,
     ):
         super().__init__()
-        fc_kwargs: dict[str, Any] = {
+        expansion_ratio, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio(
+            d_model, expansion_ratio, ffn_hidden_size)
+        self.fc_kwargs: dict[str, Any] = {
             'bias': bias,
         }
         if fc_type != 'te':
-            fc_kwargs['device'] = device
+            self.fc_kwargs['device'] = device
+
         self.up_proj = FC_CLASS_REGISTRY[fc_type](
             d_model,
-            expansion_ratio * d_model,
-            **fc_kwargs,
+            ffn_hidden_size,
+            **self.fc_kwargs,
         )
         self.act = nn.GELU(approximate='none')
         self.down_proj = FC_CLASS_REGISTRY[fc_type](
-            expansion_ratio * d_model,
+            ffn_hidden_size,
             d_model,
-            **fc_kwargs,
+            **self.fc_kwargs,
         )
         self.down_proj._is_residual = True
 
@@ -49,8 +74,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.down_proj(self.act(self.up_proj(x)))
 
 
+class MPTGeGLU(MPTMLP):
+
+    def __init__(
+        self,
+        d_model: int,
+        expansion_ratio: Union[int, float],
+        fc_type: str = 'torch',
+        ffn_hidden_size: Optional[int] = None,
+        device: Optional[str] = None,
+        bias: bool = True,
+    ):
+        super().__init__(
+            d_model=d_model,
+            expansion_ratio=expansion_ratio,
+            fc_type=fc_type,
+            ffn_hidden_size=ffn_hidden_size,
+            device=device,
+            bias=bias,
+        )
+        self.gate = FC_CLASS_REGISTRY[fc_type](
+            d_model,
+            self.up_proj.out_features,
+            **self.fc_kwargs,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(self.act(self.up_proj(x)) * self.gate(x))
+
+
 FFN_CLASS_REGISTRY = {
     'mptmlp': MPTMLP,
+    'mptgeglu': MPTGeGLU,
 }
 
 if te is not None:
@@ -60,29 +115,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 def build_ffn(
     d_model: int,
-    expansion_ratio: int,
+    expansion_ratio: Union[int, float],
     fc_type: str = 'torch',
+    ffn_hidden_size: Optional[int] = None,
     device: Optional[str] = None,
     bias: bool = True,
     **kwargs: Any,
 ) -> nn.Module:
     ffn_type = kwargs.pop('ffn_type')
-    if ffn_type == 'mptmlp':
+    if ffn_type in ['mptmlp', 'mptgeglu']:
         if len(kwargs) > 0:
             raise ValueError(
-                f'MPTMLP got an unexpected keyword argument: {kwargs}')
-        return MPTMLP(
+                f'MPTMLP (or MPTGeGLU) got an unexpected keyword argument: {kwargs}'
+            )
+        return FFN_CLASS_REGISTRY[ffn_type](
             d_model=d_model,
             expansion_ratio=expansion_ratio,
             fc_type=fc_type,
+            ffn_hidden_size=ffn_hidden_size,
             device=device,
             bias=bias,
         )
     elif ffn_type == 'te_ln_mlp':
         assert te is not None
+        _, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio(
+            d_model, expansion_ratio, ffn_hidden_size)
         return te.LayerNormMLP(
             hidden_size=d_model,
-            ffn_hidden_size=d_model * expansion_ratio,
+            ffn_hidden_size=ffn_hidden_size,
             bias=bias,
             **kwargs,
         )
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 6013c96d0b..b9b4929ad0 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -43,7 +43,7 @@ def __init__(
         d_model: int = 2048,
         n_heads: int = 16,
         n_layers: int = 24,
-        expansion_ratio: int = 4,
+        expansion_ratio: Union[int, float] = 4,
         max_seq_len: int = 2048,
         vocab_size: int = 50368,
         resid_pdrop: float = 0.0,
@@ -70,7 +70,7 @@ def __init__(
             d_model (int): The size of the embedding dimension of the model.
             n_heads (int): The number of attention heads.
             n_layers (int): The number of layers in the model.
-            expansion_ratio (int): The ratio of the up/down scale in the ffn.
+            expansion_ratio (int, float): The ratio of the up/down scale in the ffn.
             max_seq_len (int): The maximum sequence length of the model.
             vocab_size (int): The size of the vocabulary.
             resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
@@ -107,7 +107,7 @@ def __init__(
                     factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
                 kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
             ffn_config (Dict): A dictionary used to configure the model's ffn module:
-                ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp
+                ffn_type (str): type of ffn to use. Options: mptmlp, mptgeglu, te_ln_mlp
             init_device (str): The device to use for parameter initialization.
             logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
             no_bias (bool): Whether to use bias in all layers.
@@ -291,7 +291,7 @@ def _validate_config(self) -> None:
                 + 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
                 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
             )
-        if self.ffn_config['ffn_type'] == 'mptmlp':
+        if self.ffn_config['ffn_type'] in ['mptmlp', 'mptgeglu']:
             self.ffn_config['fc_type'] = self.fc_type
         elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
             self.ffn_config['bias'] = not self.no_bias
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 2a24cc8732..13fe50d5cb 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -350,7 +350,8 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
     [('torch', torch.float16), ('torch', torch.bfloat16),
      pytest.param('flash', torch.float16, marks=pytest.mark.gpu),
      pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)])
-def test_determinism(attn_impl: str, precision: torch.dtype):
+@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu'])
+def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str):
     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     with open(conf_path) as f:
         test_cfg = om.load(f)
@@ -358,6 +359,10 @@ def test_determinism(attn_impl: str, precision: torch.dtype):
     test_cfg.model.attn_config = {
         'attn_impl': attn_impl,
     }
+    if hasattr(test_cfg.model, 'ffn_config'):
+        test_cfg.model.ffn_config['ffn_type'] = ffn_type
+    else:
+        test_cfg.model.setdefault('ffn_config', {'ffn_type': ffn_type})
 
     test_cfg.model.init_device = 'cuda:0'
     test_cfg.device = 'cuda:0'
@@ -552,11 +557,15 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
         assert block.norm_2 is not None
         assert block.norm_2.weight.shape == torch.Size([d_model])
         assert isinstance(block.ffn.up_proj, nn.Linear)
-        assert block.ffn.up_proj.weight.shape == torch.Size(
-            [hf_config.d_model * hf_config.expansion_ratio, hf_config.d_model])
+        assert block.ffn.up_proj.weight.shape == torch.Size([
+            int(hf_config.d_model * hf_config.expansion_ratio),
+            hf_config.d_model
+        ])
         assert isinstance(block.ffn.down_proj, nn.Linear)
-        assert block.ffn.down_proj.weight.shape == torch.Size(
-            [hf_config.d_model, hf_config.d_model * hf_config.expansion_ratio])
+        assert block.ffn.down_proj.weight.shape == torch.Size([
+            hf_config.d_model,
+            int(hf_config.d_model * hf_config.expansion_ratio)
+        ])
         assert block.resid_attn_dropout.p == 0.2
         assert block.resid_ffn_dropout.p == 0.2
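For readers skimming the patch: the new `mptgeglu` FFN computes `down_proj(gelu(up_proj(x)) * gate(x))`, a GELU-gated linear unit, and the hidden width comes from `ffn_hidden_size` or `d_model * expansion_ratio` (which must resolve to an integer). What follows is a minimal sketch of that computation, assuming plain `nn.Linear` layers in place of the `FC_CLASS_REGISTRY` lookups; the `GeGLUSketch` name and the standalone shape check are illustrative only, not part of the patched library.

import torch
import torch.nn as nn


class GeGLUSketch(nn.Module):
    """Illustrative GeGLU feed-forward block: down_proj(gelu(up_proj(x)) * gate(x))."""

    def __init__(self, d_model: int, expansion_ratio: float = 4.0):
        super().__init__()
        ffn_hidden_size = int(d_model * expansion_ratio)
        if ffn_hidden_size != d_model * expansion_ratio:
            raise ValueError('`d_model * expansion_ratio` must be an integer.')
        # Two parallel input projections: one goes through GELU, the other gates it elementwise.
        self.up_proj = nn.Linear(d_model, ffn_hidden_size)
        self.gate = nn.Linear(d_model, ffn_hidden_size)
        self.down_proj = nn.Linear(ffn_hidden_size, d_model)
        self.act = nn.GELU(approximate='none')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)) * self.gate(x))


# Shape check: (batch, seq, d_model) in -> (batch, seq, d_model) out.
x = torch.randn(2, 8, 64)
print(GeGLUSketch(d_model=64, expansion_ratio=2.0)(x).shape)  # torch.Size([2, 8, 64])

With the patch applied, the equivalent module is selected through the model config, e.g. `ffn_config={'ffn_type': 'mptgeglu'}` on `MPTConfig`, with `expansion_ratio` now allowed to be a float.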