
Commit

remove superfluous return; add doc str
vchiley committed Dec 13, 2023
1 parent 0797aa6 commit ba3a463
Showing 2 changed files with 17 additions and 7 deletions.
22 changes: 16 additions & 6 deletions llmfoundry/models/layers/ffn.py
@@ -19,11 +19,21 @@
 log = logging.getLogger(__name__)
 
 
-def _resolve_ffn_hidden_and_exp_ratio(
+def resolve_ffn_hidden_and_exp_ratio(
     d_model: int,
     expansion_ratio: Union[int, float],
     ffn_hidden_size: Optional[int] = None,
-) -> tuple[Union[int, float], int]:
+) -> int:
+    """Resolve the hidden size of the feed-forward network.
+
+    Args:
+        d_model (int): The dimension of the input and output of the feed-forward network.
+        expansion_ratio (Union[int, float]): The expansion ratio of the feed-forward network.
+        ffn_hidden_size (Optional[int]): The hidden size of the feed-forward network.
+
+    Returns:
+        int: The hidden size of the feed-forward network.
+    """
     if ffn_hidden_size is not None:
         log.info(
             f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.'
@@ -32,9 +42,9 @@ def _resolve_ffn_hidden_and_exp_ratio(
         ffn_hidden_size = int(d_model * expansion_ratio)
         if ffn_hidden_size != d_model * expansion_ratio:
             raise ValueError(
-                f'`d_model * expansion_ratio` ({ffn_hidden_size}) must be an integer.'
+                f'`d_model * expansion_ratio` (={d_model * expansion_ratio}) must be an integer.'
             )
-    return expansion_ratio, ffn_hidden_size
+    return ffn_hidden_size
 
 
 class MPTMLP(nn.Module):
@@ -49,7 +59,7 @@ def __init__(
         bias: bool = True,
     ):
         super().__init__()
-        expansion_ratio, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio(
+        ffn_hidden_size = resolve_ffn_hidden_and_exp_ratio(
             d_model, expansion_ratio, ffn_hidden_size)
         self.fc_kwargs: dict[str, Any] = {
             'bias': bias,
@@ -138,7 +148,7 @@ def build_ffn(
         )
     elif ffn_type == 'te_ln_mlp':
         assert te is not None
-        _, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio(
+        ffn_hidden_size = resolve_ffn_hidden_and_exp_ratio(
             d_model, expansion_ratio, ffn_hidden_size)
         return te.LayerNormMLP(
             hidden_size=d_model,
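For context, the net effect of the ffn.py change is small: the helper now returns only the resolved hidden size, and an explicit `ffn_hidden_size` still overrides `expansion_ratio`. Below is a minimal sketch of that behavior, reconstructed only from the lines visible in the diff above (it omits the module's logging, and the bare print calls are illustrative):

from typing import Optional, Union


def resolve_ffn_hidden_and_exp_ratio(
    d_model: int,
    expansion_ratio: Union[int, float],
    ffn_hidden_size: Optional[int] = None,
) -> int:
    # Sketch: an explicit ffn_hidden_size wins and the ratio is ignored
    # (the real module logs an info message at this point).
    if ffn_hidden_size is not None:
        return ffn_hidden_size
    ffn_hidden_size = int(d_model * expansion_ratio)
    if ffn_hidden_size != d_model * expansion_ratio:
        raise ValueError(
            f'`d_model * expansion_ratio` (={d_model * expansion_ratio}) must be an integer.')
    return ffn_hidden_size


print(resolve_ffn_hidden_and_exp_ratio(768, 4))        # 3072
print(resolve_ffn_hidden_and_exp_ratio(768, 2.5))      # 1920
print(resolve_ffn_hidden_and_exp_ratio(768, 4, 1024))  # 1024; ratio ignored
# resolve_ffn_hidden_and_exp_ratio(768, 1.3) raises ValueError: 768 * 1.3 is not an integer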
2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
@@ -70,7 +70,7 @@ def __init__(
             d_model (int): The size of the embedding dimension of the model.
             n_heads (int): The number of attention heads.
             n_layers (int): The number of layers in the model.
-            expansion_ratio (int, float): The ratio of the up/down scale in the ffn.
+            expansion_ratio (Union[int, float]): The ratio of the up/down scale in the ffn.
             max_seq_len (int): The maximum sequence length of the model.
             vocab_size (int): The size of the vocabulary.
             resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
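The configuration_mpt.py change only widens the documented type of `expansion_ratio` to match the helper above. A hedged usage sketch, assuming `MPTConfig` is importable from `llmfoundry.models.mpt` and that arguments not listed keep their defaults (the values here are illustrative and not taken from this commit):

from llmfoundry.models.mpt import MPTConfig  # import path assumed, not shown in this diff

# An int keeps the conventional 4 * d_model FFN; a float such as 2.5 yields
# ffn_hidden_size = int(768 * 2.5) = 1920, provided the product is an integer.
cfg = MPTConfig(d_model=768, n_heads=12, n_layers=12, expansion_ratio=2.5)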
