irene pr comments
vchiley committed Dec 14, 2023
1 parent 02459ec commit f7c02e6
Showing 2 changed files with 16 additions and 9 deletions.
22 changes: 13 additions & 9 deletions llmfoundry/models/layers/ffn.py
@@ -20,29 +20,33 @@

 log = logging.getLogger(__name__)

-ffn_act_fn_default = {
+_FFN_ACT_FN_DEFAULT = {
     'name': 'gelu',
     'approximate': 'none',
 }


-def resolve_ffn_act_fn(config: Optional[dict] = None) -> Callable:
+def resolve_ffn_act_fn(
+    config: Optional[dict] = None,) -> Callable[[torch.Tensor], torch.Tensor]:
     """Resolve the activation function for the feed-forward network.

     Args:
         config (Optional[dict]): The configuration dictionary for the activation function.

     Returns:
-        Callable: The activation function.
+        Callable[[torch.Tensor], torch.Tensor]: The activation function.
     """
-    config = deepcopy(config or ffn_act_fn_default)
+    config = deepcopy(config or _FFN_ACT_FN_DEFAULT)
     name = config.pop('name')
     if not hasattr(torch.nn.functional, name):
         raise ValueError(f'Unrecognised activation function name ({name}).')
     act = getattr(torch.nn.functional, name)
     return partial(act, **config)


+_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
+
+
 def resolve_ffn_hidden_size(
     d_model: int,
     expansion_ratio: Union[int, float],
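
For context, not part of this commit: a minimal usage sketch of the renamed helper, assuming only the behavior shown above (every config key other than 'name' is forwarded to the matching torch.nn.functional op).

    import torch

    from llmfoundry.models.layers.ffn import resolve_ffn_act_fn

    # Default config: F.gelu with approximate='none'.
    act = resolve_ffn_act_fn()
    # Custom config: remaining keys become keyword arguments of the op.
    gelu_tanh = resolve_ffn_act_fn({'name': 'gelu', 'approximate': 'tanh'})

    x = torch.randn(2, 8)
    y = gelu_tanh(x)  # same as torch.nn.functional.gelu(x, approximate='tanh')
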
@@ -79,7 +83,7 @@ def __init__(
         expansion_ratio: Union[int, float],
         fc_type: str = 'torch',
         ffn_hidden_size: Optional[int] = None,
-        ffn_act_fn: Optional[dict] = None,
+        act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
         device: Optional[str] = None,
         bias: bool = True,
     ):
@@ -97,7 +101,7 @@ def __init__(
             ffn_hidden_size,
             **self.fc_kwargs,
         )
-        self.act = resolve_ffn_act_fn(ffn_act_fn)
+        self.act = act_fn
         self.down_proj = FC_CLASS_REGISTRY[fc_type](
             ffn_hidden_size,
             d_model,
@@ -117,7 +121,7 @@ def __init__(
         expansion_ratio: Union[int, float],
         fc_type: str = 'torch',
         ffn_hidden_size: Optional[int] = None,
-        ffn_act_fn: Optional[dict] = None,
+        act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
         device: Optional[str] = None,
         bias: bool = True,
     ):
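
For context, not part of this commit: both MPTMLP and MPTGLU now take a ready-made callable instead of an activation config dict. A sketch of constructing a block directly, assuming the default fc_type='torch' path and that ffn_hidden_size is resolved from expansion_ratio when omitted; the sizes are illustrative.

    import torch
    import torch.nn.functional as F

    from llmfoundry.models.layers.ffn import MPTMLP

    mlp = MPTMLP(
        d_model=256,
        expansion_ratio=4,
        act_fn=F.silu,  # any Callable[[torch.Tensor], torch.Tensor]
    )
    out = mlp(torch.randn(1, 4, 256))
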
@@ -126,7 +130,7 @@
             expansion_ratio=expansion_ratio,
             fc_type=fc_type,
             ffn_hidden_size=ffn_hidden_size,
-            ffn_act_fn=ffn_act_fn,
+            act_fn=act_fn,
             device=device,
             bias=bias,
         )
@@ -170,8 +174,8 @@ def build_ffn(
             d_model=d_model,
             expansion_ratio=expansion_ratio,
             fc_type=fc_type,
+            act_fn=resolve_ffn_act_fn(ffn_act_fn),
             ffn_hidden_size=ffn_hidden_size,
-            ffn_act_fn=ffn_act_fn,
             device=device,
             bias=bias,
         )
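
For context, not part of this commit: at the build_ffn/config level the activation is still specified as the ffn_act_fn dict; it is now resolved to a callable exactly once here and handed to the block as act_fn, so the modules themselves never see the dict. A sketch of that wiring with illustrative values:

    from llmfoundry.models.layers.ffn import MPTMLP, resolve_ffn_act_fn

    ffn_act_fn = {'name': 'relu'}            # dict form, e.g. from ffn_config
    act_fn = resolve_ffn_act_fn(ffn_act_fn)  # done once inside build_ffn
    block = MPTMLP(d_model=128, expansion_ratio=2, act_fn=act_fn)
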
3 changes: 3 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -295,6 +295,9 @@ def _validate_config(self) -> None:
             self.ffn_config['fc_type'] = self.fc_type
         elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
             self.ffn_config['bias'] = not self.no_bias
+            if 'ffn_act_fn' in self.ffn_config.keys():
+                raise ValueError(
+                    f'te block does not support custom activation functions.')
         if not self.use_pad_tok_in_ffn:
             try:
                 from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
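
For context, not part of this commit: the added guard rejects a custom activation for the TransformerEngine block at config-validation time. A standalone sketch of the same check, using a hypothetical ffn_config dict rather than a full MPTConfig:

    ffn_config = {'ffn_type': 'te_ln_mlp', 'ffn_act_fn': {'name': 'gelu'}}

    # Mirrors the validation above: this combination now raises.
    if ffn_config['ffn_type'] == 'te_ln_mlp' and 'ffn_act_fn' in ffn_config:
        raise ValueError('te block does not support custom activation functions.')
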
