From c08c0a6574761370a50cb7c22ff1bc6d97c59dd8 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 29 Jul 2024 19:21:08 +0000
Subject: [PATCH] enabling quick_gelu fn

---
 llmfoundry/models/layers/ffn.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index a28725ee0f..5b7ae86f4d 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -53,6 +53,14 @@
 }
 
 
+def quickgelu_activation(input: torch.Tensor) -> torch.Tensor:
+    """
+    Applies GELU approximation that is fast but somewhat inaccurate.
+    See: https://github.com/hendrycks/GELUs
+    """
+    return input * torch.sigmoid(1.702 * input)
+
+
 def resolve_ffn_act_fn(
     config: Optional[dict] = None,
 ) -> Callable[[torch.Tensor], torch.Tensor]:
@@ -70,10 +78,13 @@ def resolve_ffn_act_fn(
         config = _FFN_ACT_FN_DEFAULT
     config = deepcopy(config)
     name = config.pop('name')
-    if not hasattr(torch.nn.functional, name):
-        raise ValueError(f'Unrecognized activation function name ({name}).')
-    act = getattr(torch.nn.functional, name)
-    return partial(act, **config)
+    if name == 'quick_gelu':
+        return quickgelu_activation
+    else:
+        if not hasattr(torch.nn.functional, name):
+            raise ValueError(f'Unrecognized activation function name ({name}).')
+        act = getattr(torch.nn.functional, name)
+        return partial(act, **config)
 
 
 _DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
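
Note (not part of the patch): a minimal, self-contained sketch of what the new quick_gelu option computes, assuming only PyTorch is installed. The quick_gelu helper below simply mirrors quickgelu_activation() from the diff so the approximation can be compared against the exact GELU without llm-foundry; how the option is wired through resolve_ffn_act_fn is only summarized in the comments.

import torch
import torch.nn.functional as F

def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # Same formula as quickgelu_activation() in the patch:
    # sigmoid(1.702 * x) approximates the Gaussian CDF used by exact GELU.
    return x * torch.sigmoid(1.702 * x)

x = torch.linspace(-4.0, 4.0, steps=9)
exact = F.gelu(x)        # exact (erf-based) GELU
approx = quick_gelu(x)   # fast sigmoid-based approximation
print('max abs difference:', (exact - approx).abs().max().item())

# With the patch applied, resolve_ffn_act_fn({'name': 'quick_gelu'}) returns
# quickgelu_activation directly; any other name is still looked up on
# torch.nn.functional and partially applied with the remaining config kwargs.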