From 96cf646c2a17692dfa24462f814cf392aa1dc4d6 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Tue, 12 Dec 2023 11:39:16 -0800 Subject: [PATCH 1/4] Adding a fix for Cross Entropy Loss for long sequence lengths. (#795) * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. --- .../models/layers/cross_entropy_loss.py | 173 ++++++++++++++++++ llmfoundry/models/mpt/modeling_mpt.py | 6 +- tests/models/test_model.py | 16 +- 3 files changed, 189 insertions(+), 6 deletions(-) create mode 100644 llmfoundry/models/layers/cross_entropy_loss.py diff --git a/llmfoundry/models/layers/cross_entropy_loss.py b/llmfoundry/models/layers/cross_entropy_loss.py new file mode 100644 index 0000000000..e3b0931701 --- /dev/null +++ b/llmfoundry/models/layers/cross_entropy_loss.py @@ -0,0 +1,173 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +# Copied from https://github.com/Dao-AILab/flash-attention/blob/f1a73d074002226c42ce65a1df170ecff9f022c0/flash_attn/losses/cross_entropy.py +# type: ignore + +# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py +# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and +# the losses we can get the global loss. There's no need to do it step by step +# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) +# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py +import torch +import torch.nn as nn + +try: # This try...except is needed because hf transformers library requires it + import xentropy_cuda_lib +except Exception as e: + xentropy_cuda_lib = None + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if 'all_gather_into_tensor' not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +class SoftmaxCrossEntropyLossFn(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + logits, + labels, + smoothing=0.0, + ignored_index=-100, + inplace_backward=False, + process_group=None, + ): + """The forward function for softmax cross entropy loss. + + logits: (batch, vocab_size) + labels: (batch,) + If process_group is not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss needs to be aggregated across processes. + """ + batch, vocab_size = logits.shape + assert labels.shape == (batch,) + world_size = 1 if process_group is None else torch.distributed.get_world_size( + process_group) + ctx.total_classes = world_size * vocab_size + if world_size == 1: + losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) + losses.masked_fill_(labels == ignored_index, 0) + labels_local = labels + else: + rank = torch.distributed.get_rank(process_group) + vocab_start_index, vocab_end_index = rank * vocab_size, ( + rank + 1) * vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + labels_mask = (labels < vocab_start_index) | (labels >= + vocab_end_index) + ignored_mask = labels == ignored_index + labels_local = torch.where(ignored_mask, labels, + labels - vocab_start_index) + + # For tensor parallel cross entropy with smoothing, we want to pass in the total number + # of classes so that smoothing can be applied correctly. If total_classes=-1, use the + # last dimension of the input tensor. + losses, lse_local = xentropy_cuda_lib.forward( + logits, labels_local, smoothing, world_size * vocab_size) + assert lse_local.shape == (batch,) + assert losses.shape == (batch,) + losses.masked_fill_(ignored_mask, 0) + # For labels == ignored_index, the loss is always 0. + # If there's no smoothing, if labels are in the vocab of this partition, losses contains + # lse_local - predicted logit, and 0 otherwise. + # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains + # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) + # For labels not in the vocab of this partition, losses contains + # 0.1 * (lse_local - sum logit / total_classes). + + lse_allgather = torch.empty(world_size, + batch, + dtype=lse_local.dtype, + device=lse_local.device) + torch.distributed.all_gather_into_tensor(lse_allgather, + lse_local.contiguous(), + group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, + op=torch.distributed.ReduceOp.SUM, + group=process_group, + async_op=True) + lse = torch.logsumexp(lse_allgather, dim=0) + # If there's no smoothing, the total losses are lse_local - predicted_logit, + # we just have to subtract the lse_local and add the lse (global). + # If there's smoothing=0.1, the total losses are + # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) + # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). + rank_per_sample = torch.div(labels, + vocab_size, + rounding_mode='floor') + lse_local = lse_allgather[ + rank_per_sample, + torch.arange(batch, device=lse_allgather.device)] + + handle_losses.wait() + if smoothing == 0.0: + losses += lse - lse_local + else: + losses += (1 - smoothing) * (lse - lse_local) + smoothing * ( + lse - lse_allgather.sum(dim=0)) + losses.masked_fill_(ignored_mask, 0) + + ctx.save_for_backward(logits, lse, labels_local) + ctx.smoothing = smoothing + ctx.ignored_index = ignored_index + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + def backward(ctx, grad_loss): + logits, lse, labels = ctx.saved_tensors + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels == ctx.ignored_index, 0) + grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels, + ctx.smoothing, + ctx.inplace_backward, + ctx.total_classes) + return grad_logits, None, None, None, None, None, None + + +class CrossEntropyLoss(nn.Module): + + def __init__( + self, + ignore_index=-100, + reduction='mean', + label_smoothing=0.0, + inplace_backward=False, + process_group=None, + ): + super().__init__() + if xentropy_cuda_lib is None: + raise ValueError( + 'xentropy_cuda_lib is None, probably because importing xentropy_cuda_lib failed.' + ) + if reduction not in ['mean', 'none']: + raise NotImplementedError( + "Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + self.process_group = process_group + + def forward(self, input, target): + assert input.is_cuda and target.is_cuda + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = SoftmaxCrossEntropyLossFn.apply( + input, + target, + self.label_smoothing, + self.ignore_index, + self.inplace_backward, + self.process_group, + ) + if self.reduction == 'mean': + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 8c134e2b9f..c587edb723 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -972,7 +972,11 @@ def __init__( loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy') if loss_fn_config == 'fused_crossentropy': try: - from flash_attn.losses.cross_entropy import \ + # NOTE: The following is the original import statement from flash_attn library, which we have currently replaced with a copy pasted code from the same library's version 1.0.9. The reason is that using the CE loss from FA v2.3.2 results in an illegal memory access error at long sequence lengths (github issue: https://github.com/Dao-AILab/flash-attention/issues/714). + # from flash_attn.losses.cross_entropy import \ + # CrossEntropyLoss as FusedCrossEntropyLoss + # TODO: Once the problem with using FA v2's CE loss at longer sequence lengths is resolved (github issue: https://github.com/Dao-AILab/flash-attention/issues/714), revert back to directly importing the CE loss from FA library. + from llmfoundry.models.layers.cross_entropy_loss import \ CrossEntropyLoss as FusedCrossEntropyLoss self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 12d7b3de37..2a24cc8732 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -398,16 +398,22 @@ def test_determinism(attn_impl: str, precision: torch.dtype): @pytest.mark.gpu -def test_loss_fn(): +@pytest.mark.parametrize('ce_loss_implementation', + ['FA_v1_copied', 'FA_imported']) +def test_loss_fn(ce_loss_implementation: str): """Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function. We provide non-zero tolerances to account for small numerics differences between the two loss implementations. """ - try: - from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip - except: - pytest.skip('Fused cross entropy was not installed') + if ce_loss_implementation == 'FA_imported': + try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip + except: + pytest.skip('Fused cross entropy was not installed') + else: + from llmfoundry.models.layers.cross_entropy_loss import \ + CrossEntropyLoss as FusedCrossEntropyLoss # run numerical test in pure fp32 from torch.backends import cuda, cudnn From 1c3c909d0591ac47231d039e4df7f7e3b0b5a2b5 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Tue, 12 Dec 2023 12:20:59 -0800 Subject: [PATCH 2/4] Minor readme updates and bump min python version (#799) * minor readme cleanup * bump min python version --- README.md | 31 ++++++++++++++++++------------- setup.py | 2 +- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index f7b5148cf6..59869ba4bc 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ # LLM Foundry -This repository contains code for training, finetuning, evaluating, and deploying LLMs for inference with [Composer](https://github.com/mosaicml/composer) and the [MosaicML platform](https://forms.mosaicml.com/demo?utm_source=github.com&utm_medium=referral&utm_campaign=llm-foundry). Designed to be easy-to-use, efficient _and_ flexible, this codebase is designed to enable rapid experimentation with the latest techniques. +This repository contains code for training, finetuning, evaluating, and deploying LLMs for inference with [Composer](https://github.com/mosaicml/composer) and the [MosaicML platform](https://forms.mosaicml.com/demo?utm_source=github.com&utm_medium=referral&utm_campaign=llm-foundry). Designed to be easy-to-use, efficient _and_ flexible, this codebase enables rapid experimentation with the latest techniques. You'll find in this repo: * `llmfoundry/` - source code for models, datasets, callbacks, utilities, etc. @@ -45,15 +45,17 @@ You'll find in this repo: Mosaic Pretrained Transformers (MPT) are GPT-style models with some special features -- Flash Attention for efficiency, ALiBi for context length extrapolation, and stability improvements to mitigate loss spikes. As part of MosaicML's Foundation series, we have open-sourced several MPT models: -| Model | Context Length | Download | Demo | Commercial use? | -| ------------------ | -------------- | -------------------------------------------------- | ----------------------------------------------------------- | --------------- | -| MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | | Yes | -| MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | | Yes | -| MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat) | No | -| MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | | Yes | -| MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | | Yes | -| MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat) | No | -| MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | | Yes | +| Model | Context Length | Download | Commercial use? | +| ------------------ | -------------- | -------------------------------------------------- | --------------- | +| MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | Yes | +| MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | Yes | +| MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | No | +| MPT-7b-8k | 8192 | https://huggingface.co/mosaicml/mpt-7b-8k | Yes | +| MPT-7b-8k-Chat | 8192 | https://huggingface.co/mosaicml/mpt-7b-8k-chat | No | +| MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | Yes | +| MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | Yes | +| MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | No | +| MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | Yes | To try out these models locally, [follow the instructions](https://github.com/mosaicml/llm-foundry/tree/main/scripts/inference#interactive-generation-with-modelgenerate) in `scripts/inference/README.md` to prompt HF models using our [hf_generate.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_generate.py) or [hf_chat.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_chat.py) scripts. @@ -75,6 +77,8 @@ Tutorial videos from the community: Something missing? Contribute with a PR! # Latest News +* [Blog: Announcing MPT-7B-8K: 8K Context Length for Document Understanding](https://www.mosaicml.com/blog/long-context-mpt-7b-8k) +* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250) * [Blog: MPT-30B: Raising the bar for open-source foundation models](https://www.mosaicml.com/blog/mpt-30b) * [Blog: Introducing MPT-7B](https://www.mosaicml.com/blog/mpt-7b) * [Blog: Benchmarking LLMs on H100](https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1) @@ -115,9 +119,10 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117 # Installation -This assumes you already have PyTorch and CMake installed. +This assumes you already have PyTorch, CMake, and packaging installed. If not, you can install them with `pip install cmake packaging torch`. To get started, clone the repo and set up your environment. Instructions to do so differ slightly depending on whether you're using Docker. + ### With Docker (recommended) We *strongly* recommend working with LLM Foundry inside a Docker container (see our recommended Docker image above). If you are doing so, follow these steps to clone the repo and install the requirements. @@ -179,7 +184,7 @@ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.o Notes: 1. `attn_impl: triton` does not work. -1. We don't yet have a docker img where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue. +1. We don't yet have a Docker image where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue. # Quickstart @@ -233,7 +238,7 @@ python inference/hf_generate.py \ "Here's a quick recipe for baking chocolate chip cookies: Start by" ``` -Note: the `composer` command used above to train the model refers to [Composer](https://github.com/mosaicml/composer) library's distributed launcher. +Note: the `composer` command used above to train the model refers to the [Composer](https://github.com/mosaicml/composer) library's distributed launcher. If you have a write-enabled [HuggingFace auth token](https://huggingface.co/docs/hub/security-tokens), you can optionally upload your model to the Hub! Just export your token like this: diff --git a/setup.py b/setup.py index a228105a4c..b4ff85f992 100644 --- a/setup.py +++ b/setup.py @@ -143,5 +143,5 @@ classifiers=classifiers, install_requires=install_requires, extras_require=extra_deps, - python_requires='>=3.7', + python_requires='>=3.9', ) From 0797aa66f87d3a8b41f70f8a8be5529a3337ea68 Mon Sep 17 00:00:00 2001 From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Date: Tue, 12 Dec 2023 19:28:06 -0800 Subject: [PATCH 3/4] Enable GLU FFN type (#796) * add glu ffn * add ffn_type to determinism test * Update llmfoundry/models/layers/ffn.py Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * pr comments --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/models/layers/ffn.py | 88 ++++++++++++++++++---- llmfoundry/models/mpt/configuration_mpt.py | 8 +- tests/models/test_model.py | 19 +++-- 3 files changed, 92 insertions(+), 23 deletions(-) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 2f6d05f424..8f37b39306 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -1,9 +1,10 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -"""GPT Blocks used for the GPT Model.""" +"""MPT Blocks used for the MPT Model.""" -from typing import Any, Optional +import logging +from typing import Any, Optional, Union import torch import torch.nn as nn @@ -15,33 +16,57 @@ except: te = None +log = logging.getLogger(__name__) + + +def _resolve_ffn_hidden_and_exp_ratio( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int] = None, +) -> tuple[Union[int, float], int]: + if ffn_hidden_size is not None: + log.info( + f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.' + ) + else: + ffn_hidden_size = int(d_model * expansion_ratio) + if ffn_hidden_size != d_model * expansion_ratio: + raise ValueError( + f'`d_model * expansion_ratio` ({ffn_hidden_size}) must be an integer.' + ) + return expansion_ratio, ffn_hidden_size + class MPTMLP(nn.Module): def __init__( self, d_model: int, - expansion_ratio: int, + expansion_ratio: Union[int, float], fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, device: Optional[str] = None, bias: bool = True, ): super().__init__() - fc_kwargs: dict[str, Any] = { + expansion_ratio, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio( + d_model, expansion_ratio, ffn_hidden_size) + self.fc_kwargs: dict[str, Any] = { 'bias': bias, } if fc_type != 'te': - fc_kwargs['device'] = device + self.fc_kwargs['device'] = device + self.up_proj = FC_CLASS_REGISTRY[fc_type]( d_model, - expansion_ratio * d_model, - **fc_kwargs, + ffn_hidden_size, + **self.fc_kwargs, ) self.act = nn.GELU(approximate='none') self.down_proj = FC_CLASS_REGISTRY[fc_type]( - expansion_ratio * d_model, + ffn_hidden_size, d_model, - **fc_kwargs, + **self.fc_kwargs, ) self.down_proj._is_residual = True @@ -49,8 +74,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act(self.up_proj(x))) +class MPTGeGLU(MPTMLP): + + def __init__( + self, + d_model: int, + expansion_ratio: Union[int, float], + fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, + device: Optional[str] = None, + bias: bool = True, + ): + super().__init__( + d_model=d_model, + expansion_ratio=expansion_ratio, + fc_type=fc_type, + ffn_hidden_size=ffn_hidden_size, + device=device, + bias=bias, + ) + self.gate = FC_CLASS_REGISTRY[fc_type]( + d_model, + self.up_proj.out_features, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x)) * self.gate(x)) + + FFN_CLASS_REGISTRY = { 'mptmlp': MPTMLP, + 'mptgeglu': MPTGeGLU, } if te is not None: @@ -60,29 +115,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def build_ffn( d_model: int, - expansion_ratio: int, + expansion_ratio: Union[int, float], fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, device: Optional[str] = None, bias: bool = True, **kwargs: Any, ) -> nn.Module: ffn_type = kwargs.pop('ffn_type') - if ffn_type == 'mptmlp': + if ffn_type in ['mptmlp', 'mptgeglu']: if len(kwargs) > 0: raise ValueError( - f'MPTMLP got an unexpected keyword argument: {kwargs}') - return MPTMLP( + f'MPTMLP (or MPTGeGLU) got an unexpected keyword argument: {kwargs}' + ) + return FFN_CLASS_REGISTRY[ffn_type]( d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, + ffn_hidden_size=ffn_hidden_size, device=device, bias=bias, ) elif ffn_type == 'te_ln_mlp': assert te is not None + _, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio( + d_model, expansion_ratio, ffn_hidden_size) return te.LayerNormMLP( hidden_size=d_model, - ffn_hidden_size=d_model * expansion_ratio, + ffn_hidden_size=ffn_hidden_size, bias=bias, **kwargs, ) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 6013c96d0b..b9b4929ad0 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -43,7 +43,7 @@ def __init__( d_model: int = 2048, n_heads: int = 16, n_layers: int = 24, - expansion_ratio: int = 4, + expansion_ratio: Union[int, float] = 4, max_seq_len: int = 2048, vocab_size: int = 50368, resid_pdrop: float = 0.0, @@ -70,7 +70,7 @@ def __init__( d_model (int): The size of the embedding dimension of the model. n_heads (int): The number of attention heads. n_layers (int): The number of layers in the model. - expansion_ratio (int): The ratio of the up/down scale in the ffn. + expansion_ratio (int, float): The ratio of the up/down scale in the ffn. max_seq_len (int): The maximum sequence length of the model. vocab_size (int): The size of the vocabulary. resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. @@ -107,7 +107,7 @@ def __init__( factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type. kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: - ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp + ffn_type (str): type of ffn to use. Options: mptmlp, mptgeglu, te_ln_mlp init_device (str): The device to use for parameter initialization. logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. no_bias (bool): Whether to use bias in all layers. @@ -291,7 +291,7 @@ def _validate_config(self) -> None: + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' ) - if self.ffn_config['ffn_type'] == 'mptmlp': + if self.ffn_config['ffn_type'] in ['mptmlp', 'mptgeglu']: self.ffn_config['fc_type'] = self.fc_type elif self.ffn_config['ffn_type'] == 'te_ln_mlp': self.ffn_config['bias'] = not self.no_bias diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 2a24cc8732..13fe50d5cb 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -350,7 +350,8 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2): [('torch', torch.float16), ('torch', torch.bfloat16), pytest.param('flash', torch.float16, marks=pytest.mark.gpu), pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)]) -def test_determinism(attn_impl: str, precision: torch.dtype): +@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu']) +def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str): conf_path = 'scripts/train/yamls/pretrain/testing.yaml' with open(conf_path) as f: test_cfg = om.load(f) @@ -358,6 +359,10 @@ def test_determinism(attn_impl: str, precision: torch.dtype): test_cfg.model.attn_config = { 'attn_impl': attn_impl, } + if hasattr(test_cfg.model, 'ffn_config'): + test_cfg.model.ffn_config['ffn_type'] = ffn_type + else: + test_cfg.model.setdefault('ffn_config', {'ffn_type': ffn_type}) test_cfg.model.init_device = 'cuda:0' test_cfg.device = 'cuda:0' @@ -552,11 +557,15 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool): assert block.norm_2 is not None assert block.norm_2.weight.shape == torch.Size([d_model]) assert isinstance(block.ffn.up_proj, nn.Linear) - assert block.ffn.up_proj.weight.shape == torch.Size( - [hf_config.d_model * hf_config.expansion_ratio, hf_config.d_model]) + assert block.ffn.up_proj.weight.shape == torch.Size([ + int(hf_config.d_model * hf_config.expansion_ratio), + hf_config.d_model + ]) assert isinstance(block.ffn.down_proj, nn.Linear) - assert block.ffn.down_proj.weight.shape == torch.Size( - [hf_config.d_model, hf_config.d_model * hf_config.expansion_ratio]) + assert block.ffn.down_proj.weight.shape == torch.Size([ + hf_config.d_model, + int(hf_config.d_model * hf_config.expansion_ratio) + ]) assert block.resid_attn_dropout.p == 0.2 assert block.resid_ffn_dropout.p == 0.2 From 5fdcc43bb31beccbd65bc0f18c5b12b7940445be Mon Sep 17 00:00:00 2001 From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Date: Wed, 13 Dec 2023 10:16:40 -0800 Subject: [PATCH 4/4] clean up resolve_ffn_hidden_and_exp_ratio (#801) * remove superfulous return; add doc str * pr cmts; add test --- llmfoundry/models/layers/ffn.py | 26 ++++++++++----- llmfoundry/models/mpt/configuration_mpt.py | 2 +- tests/models/test_model.py | 38 +++++++++++++++------- 3 files changed, 46 insertions(+), 20 deletions(-) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 8f37b39306..e18e611ca6 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -19,11 +19,21 @@ log = logging.getLogger(__name__) -def _resolve_ffn_hidden_and_exp_ratio( +def resolve_ffn_hidden_size( d_model: int, expansion_ratio: Union[int, float], ffn_hidden_size: Optional[int] = None, -) -> tuple[Union[int, float], int]: +) -> int: + """Resolve the hidden size of the feed-forward network. + + Args: + d_model (int): The dimension of the input and output of the feed-forward network. + expansion_ratio (Union[int, float]): The expansion ratio of the feed-forward network. + ffn_hidden_size (Optional[int]): The hidden size of the feed-forward network. + + Returns: + int: The hidden size of the feed-forward network. + """ if ffn_hidden_size is not None: log.info( f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.' @@ -32,9 +42,9 @@ def _resolve_ffn_hidden_and_exp_ratio( ffn_hidden_size = int(d_model * expansion_ratio) if ffn_hidden_size != d_model * expansion_ratio: raise ValueError( - f'`d_model * expansion_ratio` ({ffn_hidden_size}) must be an integer.' + f'`d_model * expansion_ratio` must be an integer ({d_model=}; {expansion_ratio=}; {d_model * expansion_ratio=}).' ) - return expansion_ratio, ffn_hidden_size + return ffn_hidden_size class MPTMLP(nn.Module): @@ -49,8 +59,8 @@ def __init__( bias: bool = True, ): super().__init__() - expansion_ratio, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio( - d_model, expansion_ratio, ffn_hidden_size) + ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size) self.fc_kwargs: dict[str, Any] = { 'bias': bias, } @@ -138,8 +148,8 @@ def build_ffn( ) elif ffn_type == 'te_ln_mlp': assert te is not None - _, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio( - d_model, expansion_ratio, ffn_hidden_size) + ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, + ffn_hidden_size) return te.LayerNormMLP( hidden_size=d_model, ffn_hidden_size=ffn_hidden_size, diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index b9b4929ad0..2ecc726aa3 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -70,7 +70,7 @@ def __init__( d_model (int): The size of the embedding dimension of the model. n_heads (int): The number of attention heads. n_layers (int): The number of layers in the model. - expansion_ratio (int, float): The ratio of the up/down scale in the ffn. + expansion_ratio (Union[int, float]): The ratio of the up/down scale in the ffn. max_seq_len (int): The maximum sequence length of the model. vocab_size (int): The size of the vocabulary. resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 13fe50d5cb..6d48d115fd 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -514,14 +514,21 @@ def test_opt_wrapping(): @pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys()) @pytest.mark.parametrize('no_bias', [False, True]) @pytest.mark.parametrize('tie_word_embeddings', [True, False]) -def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool): +@pytest.mark.parametrize('expansion_ratio,ffn_hidden_size', [ + (2, None), + (1.231, None), + (2, 128), + (2, 256), +]) +def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool, + expansion_ratio: Union[int, float], ffn_hidden_size: int): # Test that the config constructs the model as expected. hf_config = MPTConfig( init_device='cpu', d_model=128, n_heads=4, n_layers=2, - expansion_ratio=2, + expansion_ratio=expansion_ratio, max_seq_len=2048, emb_pdrop=0.1, resid_pdrop=0.2, @@ -531,13 +538,24 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool): norm_type=norm_type, no_bias=no_bias, tie_word_embeddings=tie_word_embeddings, + ffn_config={ + 'ffn_type': 'mptmlp', + 'ffn_hidden_size': ffn_hidden_size, + }, ) + if hf_config.d_model * hf_config.expansion_ratio != int( + hf_config.d_model * hf_config.expansion_ratio): + pytest.xfail('d_model * expansion_ratio must be an integer.') + mpt = MPTForCausalLM(hf_config) assert mpt.config.d_model == 128 assert mpt.config.n_heads == 4 assert mpt.config.n_layers == 2 - assert mpt.config.expansion_ratio == 2 + if ffn_hidden_size is None: + assert mpt.config.expansion_ratio == expansion_ratio + else: + assert mpt.config.ffn_config['ffn_hidden_size'] == ffn_hidden_size assert mpt.config.max_seq_len == 2048 assert mpt.transformer.wte.weight.shape == torch.Size( @@ -551,21 +569,19 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool): assert len(mpt.transformer.blocks) == 2 d_model = hf_config.d_model + if ffn_hidden_size is None: + ffn_hidden_size = int(hf_config.d_model * hf_config.expansion_ratio) for block in mpt.transformer.blocks: assert isinstance(block, MPTBlock) assert block.norm_1.weight.shape == torch.Size([d_model]) assert block.norm_2 is not None assert block.norm_2.weight.shape == torch.Size([d_model]) assert isinstance(block.ffn.up_proj, nn.Linear) - assert block.ffn.up_proj.weight.shape == torch.Size([ - int(hf_config.d_model * hf_config.expansion_ratio), - hf_config.d_model - ]) + assert block.ffn.up_proj.weight.shape == torch.Size( + [ffn_hidden_size, hf_config.d_model]) assert isinstance(block.ffn.down_proj, nn.Linear) - assert block.ffn.down_proj.weight.shape == torch.Size([ - hf_config.d_model, - int(hf_config.d_model * hf_config.expansion_ratio) - ]) + assert block.ffn.down_proj.weight.shape == torch.Size( + [hf_config.d_model, ffn_hidden_size]) assert block.resid_attn_dropout.p == 0.2 assert block.resid_ffn_dropout.p == 0.2