enabling quick_gelu fn
gupta-abhay committed Jul 29, 2024
1 parent 6f4aa8c commit c08c0a6
Showing 1 changed file with 15 additions and 4 deletions.
llmfoundry/models/layers/ffn.py: 15 additions & 4 deletions
@@ -53,6 +53,14 @@
 }
 
 
+def quickgelu_activation(input: torch.Tensor) -> torch.Tensor:
+    """
+    Applies GELU approximation that is fast but somewhat inaccurate.
+    See: https://github.com/hendrycks/GELUs
+    """
+    return input * torch.sigmoid(1.702 * input)
+
+
 def resolve_ffn_act_fn(
     config: Optional[dict] = None,
 ) -> Callable[[torch.Tensor], torch.Tensor]:
@@ -70,10 +78,13 @@ def resolve_ffn_act_fn(
         config = _FFN_ACT_FN_DEFAULT
     config = deepcopy(config)
     name = config.pop('name')
-    if not hasattr(torch.nn.functional, name):
-        raise ValueError(f'Unrecognized activation function name ({name}).')
-    act = getattr(torch.nn.functional, name)
-    return partial(act, **config)
+    if name == 'quick_gelu':
+        return quickgelu_activation
+    else:
+        if not hasattr(torch.nn.functional, name):
+            raise ValueError(f'Unrecognized activation function name ({name}).')
+        act = getattr(torch.nn.functional, name)
+        return partial(act, **config)
 
 
 _DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
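
For context, a minimal usage sketch (not part of the commit) of how the new branch might be exercised: the quick variant approximates GELU(x) = x * Phi(x) with x * sigmoid(1.702 * x). It assumes resolve_ffn_act_fn is imported from the llmfoundry/models/layers/ffn.py module shown in the diff; the comparison against torch.nn.functional.gelu is illustrative only.

import torch
import torch.nn.functional as F

from llmfoundry.models.layers.ffn import resolve_ffn_act_fn

# Select the activation by name; the added branch returns
# quickgelu_activation directly instead of looking the name up
# on torch.nn.functional.
act_fn = resolve_ffn_act_fn({'name': 'quick_gelu'})

x = torch.randn(4, 8)
y_quick = act_fn(x)   # computes x * sigmoid(1.702 * x)
y_exact = F.gelu(x)   # exact GELU, for comparison

# The sigmoid form is a fast approximation, so the outputs are close
# but not identical.
print(torch.max(torch.abs(y_quick - y_exact)).item())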
