
Commit

Merge branch 'main' into clean-logs
dakinggg authored Dec 15, 2023
2 parents ba2dd5d + 15e79f3 commit 0f46b61
Showing 3 changed files with 92 additions and 8 deletions.
45 changes: 43 additions & 2 deletions llmfoundry/models/layers/ffn.py
@@ -4,7 +4,9 @@
"""MPT Blocks used for the MPT Model."""

import logging
from typing import Any, Optional, Union
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
@@ -18,6 +20,36 @@

log = logging.getLogger(__name__)

_FFN_ACT_FN_DEFAULT = {
    'name': 'gelu',
    'approximate': 'none',
}


def resolve_ffn_act_fn(
    config: Optional[dict] = None,) -> Callable[[torch.Tensor], torch.Tensor]:
    """Resolve the activation function for the feed-forward network.

    Args:
        config (Optional[dict]): The configuration dictionary for the activation function.
            The dict config must specify the 'name' of a torch.nn.functional activation
            function. All other key-value pairs are bound to the function as a partial.

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: The activation function.
    """
    if config is None:
        config = _FFN_ACT_FN_DEFAULT
    config = deepcopy(config)
    name = config.pop('name')
    if not hasattr(torch.nn.functional, name):
        raise ValueError(f'Unrecognised activation function name ({name}).')
    act = getattr(torch.nn.functional, name)
    return partial(act, **config)


_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)


def resolve_ffn_hidden_size(
    d_model: int,
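For reference, a minimal usage sketch of resolve_ffn_act_fn as defined above (the config dicts mirror the ones exercised in the tests below; tensor sizes are illustrative only):

    import torch

    from llmfoundry.models.layers.ffn import resolve_ffn_act_fn

    # Default config: gelu with approximate='none', matching the previously
    # hard-coded nn.GELU(approximate='none').
    default_act = resolve_ffn_act_fn()

    # Keys other than 'name' are bound to the torch.nn.functional callable as
    # partial keyword arguments.
    silu_act = resolve_ffn_act_fn({'name': 'silu'})
    tanh_gelu_act = resolve_ffn_act_fn({'name': 'gelu', 'approximate': 'tanh'})

    x = torch.randn(2, 8)
    print(default_act(x).shape, silu_act(x).shape, tanh_gelu_act(x).shape)

    # An unknown name raises, which is what the 'relu5' xfail cases below exercise.
    try:
        resolve_ffn_act_fn({'name': 'relu5'})
    except ValueError as err:
        print(err)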
@@ -55,6 +87,7 @@ def __init__(
        expansion_ratio: Union[int, float],
        fc_type: str = 'torch',
        ffn_hidden_size: Optional[int] = None,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
        device: Optional[str] = None,
        bias: bool = True,
    ):
@@ -72,7 +105,7 @@ def __init__(
            ffn_hidden_size,
            **self.fc_kwargs,
        )
        self.act = nn.GELU(approximate='none')
        self.act = act_fn
        self.down_proj = FC_CLASS_REGISTRY[fc_type](
            ffn_hidden_size,
            d_model,
@@ -92,6 +125,7 @@ def __init__(
        expansion_ratio: Union[int, float],
        fc_type: str = 'torch',
        ffn_hidden_size: Optional[int] = None,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
        device: Optional[str] = None,
        bias: bool = True,
    ):
@@ -100,6 +134,7 @@ def __init__(
            expansion_ratio=expansion_ratio,
            fc_type=fc_type,
            ffn_hidden_size=ffn_hidden_size,
            act_fn=act_fn,
            device=device,
            bias=bias,
        )
@@ -128,6 +163,7 @@ def build_ffn(
    expansion_ratio: Union[int, float],
    fc_type: str = 'torch',
    ffn_hidden_size: Optional[int] = None,
    ffn_act_fn: Optional[dict] = None,
    device: Optional[str] = None,
    bias: bool = True,
    **kwargs: Any,
@@ -142,6 +178,7 @@ def build_ffn(
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            fc_type=fc_type,
            act_fn=resolve_ffn_act_fn(ffn_act_fn),
            ffn_hidden_size=ffn_hidden_size,
            device=device,
            bias=bias,
@@ -150,6 +187,10 @@
        assert te is not None
        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio,
                                                  ffn_hidden_size)
        if ffn_act_fn is not None:
            raise ValueError(
                f'Transformer Engine block does not support custom activation functions.'
            )
        return te.LayerNormMLP(
            hidden_size=d_model,
            ffn_hidden_size=ffn_hidden_size,
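With these changes, build_ffn resolves the optional ffn_act_fn dict via resolve_ffn_act_fn and passes the resulting callable into the MLP block as act_fn, while the Transformer Engine ('te_ln_mlp') path rejects it. A minimal sketch of constructing the block directly, assuming MPTMLP accepts d_model as a keyword argument as the build_ffn call above suggests (sizes are illustrative):

    import torch

    from llmfoundry.models.layers.ffn import MPTMLP, resolve_ffn_act_fn

    # Assumes MPTMLP takes d_model as a keyword argument, as in the build_ffn call above.
    ffn = MPTMLP(
        d_model=64,
        expansion_ratio=4,
        act_fn=resolve_ffn_act_fn({'name': 'relu', 'inplace': True}),
    )

    x = torch.randn(2, 16, 64)
    print(ffn(x).shape)  # the block maps back to d_model: torch.Size([2, 16, 64])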
4 changes: 4 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -295,6 +295,10 @@ def _validate_config(self) -> None:
            self.ffn_config['fc_type'] = self.fc_type
        elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
            self.ffn_config['bias'] = not self.no_bias
            if 'ffn_act_fn' in self.ffn_config.keys():
                raise ValueError(
                    f'Transformer Engine block does not support custom activation functions.'
                )
        if not self.use_pad_tok_in_ffn:
            try:
                from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
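At the config level, the new validation means ffn_act_fn is only accepted for the mptmlp/mptgeglu path. A minimal sketch of an MPTConfig using it, mirroring the ffn_config dict in test_mpt_creation below (the other hyperparameters are illustrative):

    from llmfoundry.models.mpt.configuration_mpt import MPTConfig

    hf_config = MPTConfig(
        d_model=128,
        n_heads=4,
        n_layers=2,
        expansion_ratio=2,
        ffn_config={
            'ffn_type': 'mptmlp',
            'ffn_act_fn': {
                'name': 'silu',
            },
        },
    )

    # With ffn_config['ffn_type'] == 'te_ln_mlp', including 'ffn_act_fn' would
    # raise the ValueError added in _validate_config above.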
51 changes: 45 additions & 6 deletions tests/models/test_model.py
@@ -351,7 +351,25 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
    pytest.param('flash', torch.float16, marks=pytest.mark.gpu),
    pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)])
@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu'])
def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str):
@pytest.mark.parametrize('ffn_act_fn', [
    None,
    {
        'name': 'gelu',
        'approximate': 'tanh',
    },
    {
        'name': 'silu',
    },
    {
        'name': 'relu',
        'inplace': True,
    },
    pytest.param({'name': 'relu5'},
                 marks=pytest.mark.xfail(reason='invalid choice.',
                                         strict=True)),
])
def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str,
                     ffn_act_fn: dict):
    conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
    with open(conf_path) as f:
        test_cfg = om.load(f)
@@ -363,6 +381,7 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str):
        test_cfg.model.ffn_config['ffn_type'] = ffn_type
    else:
        test_cfg.model.setdefault('ffn_config', {'ffn_type': ffn_type})
    test_cfg.model.ffn_config['ffn_act_fn'] = ffn_act_fn
    test_cfg.model.init_device = 'cuda:0'
    test_cfg.device = 'cuda:0'

@@ -516,12 +535,34 @@ def test_opt_wrapping():
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
@pytest.mark.parametrize('expansion_ratio,ffn_hidden_size', [
    (2, None),
    (1.231, None),
    pytest.param(1.231,
                 None,
                 marks=pytest.mark.xfail(
                     reason='d_model * expansion_ratio must be an integer.',
                     strict=True)),
    (2, 128),
    (2, 256),
])
@pytest.mark.parametrize('ffn_act_fn', [
    None,
    {
        'name': 'gelu',
        'approximate': 'tanh',
    },
    {
        'name': 'silu',
    },
    {
        'name': 'relu',
        'inplace': True,
    },
    pytest.param({'name': 'relu5'},
                 marks=pytest.mark.xfail(reason='invalid choice.',
                                         strict=True)),
])
def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
                      expansion_ratio: Union[int, float], ffn_hidden_size: int):
                      expansion_ratio: Union[int, float], ffn_hidden_size: int,
                      ffn_act_fn: dict):
    # Test that the config constructs the model as expected.
    hf_config = MPTConfig(
        init_device='cpu',
@@ -541,11 +582,9 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
        ffn_config={
            'ffn_type': 'mptmlp',
            'ffn_hidden_size': ffn_hidden_size,
            'ffn_act_fn': ffn_act_fn,
        },
    )
    if hf_config.d_model * hf_config.expansion_ratio != int(
            hf_config.d_model * hf_config.expansion_ratio):
        pytest.xfail('d_model * expansion_ratio must be an integer.')

    mpt = MPTForCausalLM(hf_config)

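Finally, a sketch of how a run config picks up the new option, mirroring the omegaconf override used in test_determinism above (the YAML path is the test fixture used there):

    from omegaconf import OmegaConf as om

    with open('scripts/train/yamls/pretrain/testing.yaml') as f:
        cfg = om.load(f)

    # Select a custom feed-forward activation, as test_determinism does.
    cfg.model.setdefault('ffn_config', {'ffn_type': 'mptmlp'})
    cfg.model.ffn_config['ffn_act_fn'] = {'name': 'gelu', 'approximate': 'tanh'}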
