Skip to content

Commit

Permalink
align glu impl with llama
Browse files: browse the repository at this point in the history
  • Loading branch information
vchiley committed Dec 31, 2023
1 parent 2b1fa79 commit adf22a2
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions llmfoundry/models/layers/ffn.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -117,7 +117,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.up_proj(x)))


class MPTGeGLU(MPTMLP):
class MPTGLU(MPTMLP):

def __init__(
self,
Expand All @@ -138,19 +138,19 @@ def __init__(
device=device,
bias=bias,
)
self.gate = FC_CLASS_REGISTRY[fc_type](
self.gate_proj = FC_CLASS_REGISTRY[fc_type](
d_model,
self.up_proj.out_features,
**self.fc_kwargs,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.up_proj(x)) * self.gate(x))
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


FFN_CLASS_REGISTRY = {
'mptmlp': MPTMLP,
'mptgeglu': MPTGeGLU,
'mptglu': MPTGLU,
}

if te is not None:
Expand All @@ -169,10 +169,10 @@ def build_ffn(
**kwargs: Any,
) -> nn.Module:
ffn_type = kwargs.pop('ffn_type')
if ffn_type in ['mptmlp', 'mptgeglu']:
if ffn_type in ['mptmlp', 'mptglu']:
if len(kwargs) > 0:
raise ValueError(
f'MPTMLP (or MPTGeGLU) got an unexpected keyword argument: {kwargs}'
f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}'
)
return FFN_CLASS_REGISTRY[ffn_type](
d_model=d_model,
Expand Down
4 changes: 2 additions & 2 deletions llmfoundry/models/mpt/configuration_mpt.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
ffn_config (Dict): A dictionary used to configure the model's ffn module:
ffn_type (str): type of ffn to use. Options: mptmlp, mptgeglu, te_ln_mlp
ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp
init_device (str): The device to use for parameter initialization.
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
no_bias (bool): Whether to use bias in all layers.
Expand Down Expand Up @@ -291,7 +291,7 @@ def _validate_config(self) -> None:
+ 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
)
if self.ffn_config['ffn_type'] in ['mptmlp', 'mptgeglu']:
if self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']:
self.ffn_config['fc_type'] = self.fc_type
elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
self.ffn_config['bias'] = not self.no_bias
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -350,7 +350,7 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
[('torch', torch.float16), ('torch', torch.bfloat16),
pytest.param('flash', torch.float16, marks=pytest.mark.gpu),
pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)])
@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu'])
@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptglu'])
@pytest.mark.parametrize('ffn_act_fn', [
None,
{
Expand Down

0 comments on commit adf22a2

Please sign in to comment.