dblalock pr comments
vchiley committed Dec 15, 2023
1 parent 70f8716 commit 635ab48
Showing 3 changed files with 16 additions and 6 deletions.
7 changes: 6 additions & 1 deletion llmfoundry/models/layers/ffn.py
@@ -32,6 +32,10 @@ def resolve_ffn_act_fn(
Args:
config (Optional[dict]): The configuration dictionary for the activation function.
The dict config must specify a 'function' or the 'name' of a function.
If 'function' is specified, a Callable function is expected. If 'name' is
specified, the name is expected to be the name of a `torch.nn.functional`
function. All other key-value pairs are bound to the function as a partial.
Returns:
Callable[[torch.Tensor], torch.Tensor]: The activation function.
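
The updated docstring describes a small resolution contract: the config dict provides either a ready-made 'function' (a Callable) or the 'name' of a torch.nn.functional function, and every remaining key-value pair is bound to that function as a partial. A minimal sketch of that contract (an illustration only, not the repository's actual implementation; the helper name and the 'gelu' default are assumptions):

```python
from functools import partial
from typing import Any, Callable, Dict, Optional

import torch


def resolve_act_fn_sketch(
    config: Optional[Dict[str, Any]] = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
    # Sketch of the documented behavior; defaulting to 'gelu' is an assumption.
    cfg = dict(config) if config is not None else {'name': 'gelu'}
    if 'function' in cfg:
        act = cfg.pop('function')  # a Callable supplied directly
    else:
        # 'name' must name a torch.nn.functional function, e.g. 'gelu' or 'silu'.
        act = getattr(torch.nn.functional, cfg.pop('name'))
    # All other key-value pairs are bound to the function as a partial.
    return partial(act, **cfg) if cfg else act


# e.g. {'name': 'gelu', 'approximate': 'tanh'} -> partial(F.gelu, approximate='tanh')
gelu_tanh = resolve_act_fn_sketch({'name': 'gelu', 'approximate': 'tanh'})
```
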
@@ -196,7 +200,8 @@ def build_ffn(
ffn_hidden_size)
if ffn_act_fn is not None:
raise ValueError(
f'te block does not support custom activation functions.')
f'Transformer Engine block does not support custom activation functions.'
)
return te.LayerNormMLP(
hidden_size=d_model,
ffn_hidden_size=ffn_hidden_size,
3 changes: 2 additions & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
@@ -297,7 +297,8 @@ def _validate_config(self) -> None:
self.ffn_config['bias'] = not self.no_bias
if 'ffn_act_fn' in self.ffn_config.keys():
raise ValueError(
f'te block does not support custom activation functions.')
f'Transformer Engine block does not support custom activation functions.'
)
if not self.use_pad_tok_in_ffn:
try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
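
Both error paths above guard the Transformer Engine ffn_type against a user-supplied activation, since the Transformer Engine block manages its own. A hedged illustration of the two kinds of ffn_config involved (the 'ffn_type' values 'mptmlp' and 'te_ln_mlp' are assumptions about the surrounding codebase, not shown in this diff):

```python
# Accepted: the plain torch FFN block with a custom activation function.
ffn_config_ok = {
    'ffn_type': 'mptmlp',  # assumed name of the default FFN block
    'ffn_act_fn': {'name': 'gelu', 'approximate': 'tanh'},
}

# Rejected: supplying 'ffn_act_fn' with the Transformer Engine block would
# trigger the ValueError raised in _validate_config above.
ffn_config_rejected = {
    'ffn_type': 'te_ln_mlp',  # assumed name of the Transformer Engine block
    'ffn_act_fn': {'name': 'silu'},
}
```
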
12 changes: 8 additions & 4 deletions tests/models/test_model.py
@@ -354,10 +354,10 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
@pytest.mark.parametrize('ffn_act_fn', [
{
'name': 'gelu',
'approximate': 'none'
'approximate': 'none',
},
{
'name': 'silu'
'name': 'silu',
},
{
'name': 'relu',
@@ -545,10 +545,14 @@ def test_opt_wrapping():
@pytest.mark.parametrize('ffn_act_fn', [
{
'name': 'gelu',
'approximate': 'none'
'approximate': 'none',
},
{
'name': 'silu'
'name': 'silu',
},
{
'name': 'relu',
'inplace': True,
},
pytest.param({'name': 'relu5'},
marks=pytest.mark.xfail(reason='invalid choice.',
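
The new parametrizations exercise extra keyword arguments ('approximate', 'inplace') that get bound to the resolved activation, while the xfail-marked 'relu5' case covers a name that cannot be resolved. An illustrative check only, assuming name-based lookup against torch.nn.functional as described in the docstring above:

```python
import torch

# Valid names resolve to functions; 'relu5' has no counterpart in
# torch.nn.functional, so the xfail-marked case is expected to fail.
assert hasattr(torch.nn.functional, 'relu')
assert not hasattr(torch.nn.functional, 'relu5')
```
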
