Skip to content

Commit

Permalink
Enable GLU FFN type (#796)
Browse files Browse the repository at this point in the history
* add glu ffn

* add ffn_type to determinism test

* Update llmfoundry/models/layers/ffn.py

Co-authored-by: Daniel King <[email protected]>

* pr comments

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
vchiley and dakinggg authored Dec 13, 2023
1 parent 1c3c909 commit 0797aa6
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 23 deletions.
88 changes: 74 additions & 14 deletions llmfoundry/models/layers/ffn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""GPT Blocks used for the GPT Model."""
"""MPT Blocks used for the MPT Model."""

from typing import Any, Optional
import logging
from typing import Any, Optional, Union

import torch
import torch.nn as nn
Expand All @@ -15,42 +16,96 @@
except:
te = None

log = logging.getLogger(__name__)


def _resolve_ffn_hidden_and_exp_ratio(
d_model: int,
expansion_ratio: Union[int, float],
ffn_hidden_size: Optional[int] = None,
) -> tuple[Union[int, float], int]:
if ffn_hidden_size is not None:
log.info(
f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.'
)
else:
ffn_hidden_size = int(d_model * expansion_ratio)
if ffn_hidden_size != d_model * expansion_ratio:
raise ValueError(
f'`d_model * expansion_ratio` ({ffn_hidden_size}) must be an integer.'
)
return expansion_ratio, ffn_hidden_size


class MPTMLP(nn.Module):

def __init__(
self,
d_model: int,
expansion_ratio: int,
expansion_ratio: Union[int, float],
fc_type: str = 'torch',
ffn_hidden_size: Optional[int] = None,
device: Optional[str] = None,
bias: bool = True,
):
super().__init__()
fc_kwargs: dict[str, Any] = {
expansion_ratio, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio(
d_model, expansion_ratio, ffn_hidden_size)
self.fc_kwargs: dict[str, Any] = {
'bias': bias,
}
if fc_type != 'te':
fc_kwargs['device'] = device
self.fc_kwargs['device'] = device

self.up_proj = FC_CLASS_REGISTRY[fc_type](
d_model,
expansion_ratio * d_model,
**fc_kwargs,
ffn_hidden_size,
**self.fc_kwargs,
)
self.act = nn.GELU(approximate='none')
self.down_proj = FC_CLASS_REGISTRY[fc_type](
expansion_ratio * d_model,
ffn_hidden_size,
d_model,
**fc_kwargs,
**self.fc_kwargs,
)
self.down_proj._is_residual = True

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.up_proj(x)))


class MPTGeGLU(MPTMLP):

def __init__(
self,
d_model: int,
expansion_ratio: Union[int, float],
fc_type: str = 'torch',
ffn_hidden_size: Optional[int] = None,
device: Optional[str] = None,
bias: bool = True,
):
super().__init__(
d_model=d_model,
expansion_ratio=expansion_ratio,
fc_type=fc_type,
ffn_hidden_size=ffn_hidden_size,
device=device,
bias=bias,
)
self.gate = FC_CLASS_REGISTRY[fc_type](
d_model,
self.up_proj.out_features,
**self.fc_kwargs,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.up_proj(x)) * self.gate(x))


FFN_CLASS_REGISTRY = {
'mptmlp': MPTMLP,
'mptgeglu': MPTGeGLU,
}

if te is not None:
Expand All @@ -60,29 +115,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

def build_ffn(
d_model: int,
expansion_ratio: int,
expansion_ratio: Union[int, float],
fc_type: str = 'torch',
ffn_hidden_size: Optional[int] = None,
device: Optional[str] = None,
bias: bool = True,
**kwargs: Any,
) -> nn.Module:
ffn_type = kwargs.pop('ffn_type')
if ffn_type == 'mptmlp':
if ffn_type in ['mptmlp', 'mptgeglu']:
if len(kwargs) > 0:
raise ValueError(
f'MPTMLP got an unexpected keyword argument: {kwargs}')
return MPTMLP(
f'MPTMLP (or MPTGeGLU) got an unexpected keyword argument: {kwargs}'
)
return FFN_CLASS_REGISTRY[ffn_type](
d_model=d_model,
expansion_ratio=expansion_ratio,
fc_type=fc_type,
ffn_hidden_size=ffn_hidden_size,
device=device,
bias=bias,
)
elif ffn_type == 'te_ln_mlp':
assert te is not None
_, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio(
d_model, expansion_ratio, ffn_hidden_size)
return te.LayerNormMLP(
hidden_size=d_model,
ffn_hidden_size=d_model * expansion_ratio,
ffn_hidden_size=ffn_hidden_size,
bias=bias,
**kwargs,
)
Expand Down
8 changes: 4 additions & 4 deletions llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
d_model: int = 2048,
n_heads: int = 16,
n_layers: int = 24,
expansion_ratio: int = 4,
expansion_ratio: Union[int, float] = 4,
max_seq_len: int = 2048,
vocab_size: int = 50368,
resid_pdrop: float = 0.0,
Expand All @@ -70,7 +70,7 @@ def __init__(
d_model (int): The size of the embedding dimension of the model.
n_heads (int): The number of attention heads.
n_layers (int): The number of layers in the model.
expansion_ratio (int): The ratio of the up/down scale in the ffn.
expansion_ratio (int, float): The ratio of the up/down scale in the ffn.
max_seq_len (int): The maximum sequence length of the model.
vocab_size (int): The size of the vocabulary.
resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(
factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
ffn_config (Dict): A dictionary used to configure the model's ffn module:
ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp
ffn_type (str): type of ffn to use. Options: mptmlp, mptgeglu, te_ln_mlp
init_device (str): The device to use for parameter initialization.
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
no_bias (bool): Whether to use bias in all layers.
Expand Down Expand Up @@ -291,7 +291,7 @@ def _validate_config(self) -> None:
+ 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
)
if self.ffn_config['ffn_type'] == 'mptmlp':
if self.ffn_config['ffn_type'] in ['mptmlp', 'mptgeglu']:
self.ffn_config['fc_type'] = self.fc_type
elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
self.ffn_config['bias'] = not self.no_bias
Expand Down
19 changes: 14 additions & 5 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,19 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
[('torch', torch.float16), ('torch', torch.bfloat16),
pytest.param('flash', torch.float16, marks=pytest.mark.gpu),
pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)])
def test_determinism(attn_impl: str, precision: torch.dtype):
@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu'])
def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str):
conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
with open(conf_path) as f:
test_cfg = om.load(f)

test_cfg.model.attn_config = {
'attn_impl': attn_impl,
}
if hasattr(test_cfg.model, 'ffn_config'):
test_cfg.model.ffn_config['ffn_type'] = ffn_type
else:
test_cfg.model.setdefault('ffn_config', {'ffn_type': ffn_type})
test_cfg.model.init_device = 'cuda:0'
test_cfg.device = 'cuda:0'

Expand Down Expand Up @@ -552,11 +557,15 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
assert block.norm_2 is not None
assert block.norm_2.weight.shape == torch.Size([d_model])
assert isinstance(block.ffn.up_proj, nn.Linear)
assert block.ffn.up_proj.weight.shape == torch.Size(
[hf_config.d_model * hf_config.expansion_ratio, hf_config.d_model])
assert block.ffn.up_proj.weight.shape == torch.Size([
int(hf_config.d_model * hf_config.expansion_ratio),
hf_config.d_model
])
assert isinstance(block.ffn.down_proj, nn.Linear)
assert block.ffn.down_proj.weight.shape == torch.Size(
[hf_config.d_model, hf_config.d_model * hf_config.expansion_ratio])
assert block.ffn.down_proj.weight.shape == torch.Size([
hf_config.d_model,
int(hf_config.d_model * hf_config.expansion_ratio)
])
assert block.resid_attn_dropout.p == 0.2
assert block.resid_ffn_dropout.p == 0.2

Expand Down

0 comments on commit 0797aa6

Please sign in to comment.