diff --git a/README.md b/README.md index f7b5148cf6..59869ba4bc 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ # LLM Foundry -This repository contains code for training, finetuning, evaluating, and deploying LLMs for inference with [Composer](https://github.com/mosaicml/composer) and the [MosaicML platform](https://forms.mosaicml.com/demo?utm_source=github.com&utm_medium=referral&utm_campaign=llm-foundry). Designed to be easy-to-use, efficient _and_ flexible, this codebase is designed to enable rapid experimentation with the latest techniques. +This repository contains code for training, finetuning, evaluating, and deploying LLMs for inference with [Composer](https://github.com/mosaicml/composer) and the [MosaicML platform](https://forms.mosaicml.com/demo?utm_source=github.com&utm_medium=referral&utm_campaign=llm-foundry). Designed to be easy-to-use, efficient _and_ flexible, this codebase enables rapid experimentation with the latest techniques. You'll find in this repo: * `llmfoundry/` - source code for models, datasets, callbacks, utilities, etc. @@ -45,15 +45,17 @@ You'll find in this repo: Mosaic Pretrained Transformers (MPT) are GPT-style models with some special features -- Flash Attention for efficiency, ALiBi for context length extrapolation, and stability improvements to mitigate loss spikes. As part of MosaicML's Foundation series, we have open-sourced several MPT models: -| Model | Context Length | Download | Demo | Commercial use? | -| ------------------ | -------------- | -------------------------------------------------- | ----------------------------------------------------------- | --------------- | -| MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | | Yes | -| MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | | Yes | -| MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat) | No | -| MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | | Yes | -| MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | | Yes | -| MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat) | No | -| MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | | Yes | +| Model | Context Length | Download | Commercial use? | +| ------------------ | -------------- | -------------------------------------------------- | --------------- | +| MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | Yes | +| MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | Yes | +| MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | No | +| MPT-7b-8k | 8192 | https://huggingface.co/mosaicml/mpt-7b-8k | Yes | +| MPT-7b-8k-Chat | 8192 | https://huggingface.co/mosaicml/mpt-7b-8k-chat | No | +| MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | Yes | +| MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | Yes | +| MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | No | +| MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | Yes | To try out these models locally, [follow the instructions](https://github.com/mosaicml/llm-foundry/tree/main/scripts/inference#interactive-generation-with-modelgenerate) in `scripts/inference/README.md` to prompt HF models using our [hf_generate.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_generate.py) or [hf_chat.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_chat.py) scripts. @@ -75,6 +77,8 @@ Tutorial videos from the community: Something missing? Contribute with a PR! # Latest News +* [Blog: Announcing MPT-7B-8K: 8K Context Length for Document Understanding](https://www.mosaicml.com/blog/long-context-mpt-7b-8k) +* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250) * [Blog: MPT-30B: Raising the bar for open-source foundation models](https://www.mosaicml.com/blog/mpt-30b) * [Blog: Introducing MPT-7B](https://www.mosaicml.com/blog/mpt-7b) * [Blog: Benchmarking LLMs on H100](https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1) @@ -115,9 +119,10 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117 # Installation -This assumes you already have PyTorch and CMake installed. +This assumes you already have PyTorch, CMake, and packaging installed. If not, you can install them with `pip install cmake packaging torch`. To get started, clone the repo and set up your environment. Instructions to do so differ slightly depending on whether you're using Docker. + ### With Docker (recommended) We *strongly* recommend working with LLM Foundry inside a Docker container (see our recommended Docker image above). If you are doing so, follow these steps to clone the repo and install the requirements. @@ -179,7 +184,7 @@ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.o Notes: 1. `attn_impl: triton` does not work. -1. We don't yet have a docker img where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue. +1. We don't yet have a Docker image where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue. # Quickstart @@ -233,7 +238,7 @@ python inference/hf_generate.py \ "Here's a quick recipe for baking chocolate chip cookies: Start by" ``` -Note: the `composer` command used above to train the model refers to [Composer](https://github.com/mosaicml/composer) library's distributed launcher. +Note: the `composer` command used above to train the model refers to the [Composer](https://github.com/mosaicml/composer) library's distributed launcher. If you have a write-enabled [HuggingFace auth token](https://huggingface.co/docs/hub/security-tokens), you can optionally upload your model to the Hub! Just export your token like this: diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 6db9ff22ca..c077ccb535 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -12,6 +12,11 @@ from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +try: + from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip +except: + unpad_input, pad_input = None, None + attn_config_defaults: Dict = { 'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, @@ -53,6 +58,7 @@ def __init__( fc_type: str = 'torch', device: Optional[str] = None, no_bias: bool = False, + use_pad_tok_in_ffn: bool = True, **kwargs: Any, ): if attn_config is None: @@ -105,6 +111,8 @@ def __init__( self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_ffn_dropout = nn.Dropout(resid_pdrop) + self.use_pad_tok_in_ffn = use_pad_tok_in_ffn + def forward( self, x: torch.Tensor, @@ -132,6 +140,14 @@ def forward( m = x if self.norm_2 is not None: m = self.norm_2(x) + batch_size, seq_len = m.size()[:2] + indices = None + if not self.use_pad_tok_in_ffn: + assert unpad_input is not None + m, indices, _, _ = unpad_input(m, attention_mask) n = self.ffn(m) + if not self.use_pad_tok_in_ffn: + assert pad_input is not None + n = pad_input(n, indices, batch_size, seq_len) x = x + self.resid_ffn_dropout(n) return x, attn_weights, past_key_value diff --git a/llmfoundry/models/layers/cross_entropy_loss.py b/llmfoundry/models/layers/cross_entropy_loss.py new file mode 100644 index 0000000000..e3b0931701 --- /dev/null +++ b/llmfoundry/models/layers/cross_entropy_loss.py @@ -0,0 +1,173 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +# Copied from https://github.com/Dao-AILab/flash-attention/blob/f1a73d074002226c42ce65a1df170ecff9f022c0/flash_attn/losses/cross_entropy.py +# type: ignore + +# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py +# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and +# the losses we can get the global loss. There's no need to do it step by step +# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) +# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py +import torch +import torch.nn as nn + +try: # This try...except is needed because hf transformers library requires it + import xentropy_cuda_lib +except Exception as e: + xentropy_cuda_lib = None + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if 'all_gather_into_tensor' not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +class SoftmaxCrossEntropyLossFn(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + logits, + labels, + smoothing=0.0, + ignored_index=-100, + inplace_backward=False, + process_group=None, + ): + """The forward function for softmax cross entropy loss. + + logits: (batch, vocab_size) + labels: (batch,) + If process_group is not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss needs to be aggregated across processes. + """ + batch, vocab_size = logits.shape + assert labels.shape == (batch,) + world_size = 1 if process_group is None else torch.distributed.get_world_size( + process_group) + ctx.total_classes = world_size * vocab_size + if world_size == 1: + losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) + losses.masked_fill_(labels == ignored_index, 0) + labels_local = labels + else: + rank = torch.distributed.get_rank(process_group) + vocab_start_index, vocab_end_index = rank * vocab_size, ( + rank + 1) * vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + labels_mask = (labels < vocab_start_index) | (labels >= + vocab_end_index) + ignored_mask = labels == ignored_index + labels_local = torch.where(ignored_mask, labels, + labels - vocab_start_index) + + # For tensor parallel cross entropy with smoothing, we want to pass in the total number + # of classes so that smoothing can be applied correctly. If total_classes=-1, use the + # last dimension of the input tensor. + losses, lse_local = xentropy_cuda_lib.forward( + logits, labels_local, smoothing, world_size * vocab_size) + assert lse_local.shape == (batch,) + assert losses.shape == (batch,) + losses.masked_fill_(ignored_mask, 0) + # For labels == ignored_index, the loss is always 0. + # If there's no smoothing, if labels are in the vocab of this partition, losses contains + # lse_local - predicted logit, and 0 otherwise. + # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains + # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) + # For labels not in the vocab of this partition, losses contains + # 0.1 * (lse_local - sum logit / total_classes). + + lse_allgather = torch.empty(world_size, + batch, + dtype=lse_local.dtype, + device=lse_local.device) + torch.distributed.all_gather_into_tensor(lse_allgather, + lse_local.contiguous(), + group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, + op=torch.distributed.ReduceOp.SUM, + group=process_group, + async_op=True) + lse = torch.logsumexp(lse_allgather, dim=0) + # If there's no smoothing, the total losses are lse_local - predicted_logit, + # we just have to subtract the lse_local and add the lse (global). + # If there's smoothing=0.1, the total losses are + # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) + # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). + rank_per_sample = torch.div(labels, + vocab_size, + rounding_mode='floor') + lse_local = lse_allgather[ + rank_per_sample, + torch.arange(batch, device=lse_allgather.device)] + + handle_losses.wait() + if smoothing == 0.0: + losses += lse - lse_local + else: + losses += (1 - smoothing) * (lse - lse_local) + smoothing * ( + lse - lse_allgather.sum(dim=0)) + losses.masked_fill_(ignored_mask, 0) + + ctx.save_for_backward(logits, lse, labels_local) + ctx.smoothing = smoothing + ctx.ignored_index = ignored_index + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + def backward(ctx, grad_loss): + logits, lse, labels = ctx.saved_tensors + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels == ctx.ignored_index, 0) + grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels, + ctx.smoothing, + ctx.inplace_backward, + ctx.total_classes) + return grad_logits, None, None, None, None, None, None + + +class CrossEntropyLoss(nn.Module): + + def __init__( + self, + ignore_index=-100, + reduction='mean', + label_smoothing=0.0, + inplace_backward=False, + process_group=None, + ): + super().__init__() + if xentropy_cuda_lib is None: + raise ValueError( + 'xentropy_cuda_lib is None, probably because importing xentropy_cuda_lib failed.' + ) + if reduction not in ['mean', 'none']: + raise NotImplementedError( + "Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + self.process_group = process_group + + def forward(self, input, target): + assert input.is_cuda and target.is_cuda + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = SoftmaxCrossEntropyLossFn.apply( + input, + target, + self.label_smoothing, + self.ignore_index, + self.inplace_backward, + self.process_group, + ) + if self.reduction == 'mean': + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 2f6d05f424..8f37b39306 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -1,9 +1,10 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -"""GPT Blocks used for the GPT Model.""" +"""MPT Blocks used for the MPT Model.""" -from typing import Any, Optional +import logging +from typing import Any, Optional, Union import torch import torch.nn as nn @@ -15,33 +16,57 @@ except: te = None +log = logging.getLogger(__name__) + + +def _resolve_ffn_hidden_and_exp_ratio( + d_model: int, + expansion_ratio: Union[int, float], + ffn_hidden_size: Optional[int] = None, +) -> tuple[Union[int, float], int]: + if ffn_hidden_size is not None: + log.info( + f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.' + ) + else: + ffn_hidden_size = int(d_model * expansion_ratio) + if ffn_hidden_size != d_model * expansion_ratio: + raise ValueError( + f'`d_model * expansion_ratio` ({ffn_hidden_size}) must be an integer.' + ) + return expansion_ratio, ffn_hidden_size + class MPTMLP(nn.Module): def __init__( self, d_model: int, - expansion_ratio: int, + expansion_ratio: Union[int, float], fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, device: Optional[str] = None, bias: bool = True, ): super().__init__() - fc_kwargs: dict[str, Any] = { + expansion_ratio, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio( + d_model, expansion_ratio, ffn_hidden_size) + self.fc_kwargs: dict[str, Any] = { 'bias': bias, } if fc_type != 'te': - fc_kwargs['device'] = device + self.fc_kwargs['device'] = device + self.up_proj = FC_CLASS_REGISTRY[fc_type]( d_model, - expansion_ratio * d_model, - **fc_kwargs, + ffn_hidden_size, + **self.fc_kwargs, ) self.act = nn.GELU(approximate='none') self.down_proj = FC_CLASS_REGISTRY[fc_type]( - expansion_ratio * d_model, + ffn_hidden_size, d_model, - **fc_kwargs, + **self.fc_kwargs, ) self.down_proj._is_residual = True @@ -49,8 +74,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act(self.up_proj(x))) +class MPTGeGLU(MPTMLP): + + def __init__( + self, + d_model: int, + expansion_ratio: Union[int, float], + fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, + device: Optional[str] = None, + bias: bool = True, + ): + super().__init__( + d_model=d_model, + expansion_ratio=expansion_ratio, + fc_type=fc_type, + ffn_hidden_size=ffn_hidden_size, + device=device, + bias=bias, + ) + self.gate = FC_CLASS_REGISTRY[fc_type]( + d_model, + self.up_proj.out_features, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x)) * self.gate(x)) + + FFN_CLASS_REGISTRY = { 'mptmlp': MPTMLP, + 'mptgeglu': MPTGeGLU, } if te is not None: @@ -60,29 +115,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def build_ffn( d_model: int, - expansion_ratio: int, + expansion_ratio: Union[int, float], fc_type: str = 'torch', + ffn_hidden_size: Optional[int] = None, device: Optional[str] = None, bias: bool = True, **kwargs: Any, ) -> nn.Module: ffn_type = kwargs.pop('ffn_type') - if ffn_type == 'mptmlp': + if ffn_type in ['mptmlp', 'mptgeglu']: if len(kwargs) > 0: raise ValueError( - f'MPTMLP got an unexpected keyword argument: {kwargs}') - return MPTMLP( + f'MPTMLP (or MPTGeGLU) got an unexpected keyword argument: {kwargs}' + ) + return FFN_CLASS_REGISTRY[ffn_type]( d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, + ffn_hidden_size=ffn_hidden_size, device=device, bias=bias, ) elif ffn_type == 'te_ln_mlp': assert te is not None + _, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio( + d_model, expansion_ratio, ffn_hidden_size) return te.LayerNormMLP( hidden_size=d_model, - ffn_hidden_size=d_model * expansion_ratio, + ffn_hidden_size=ffn_hidden_size, bias=bias, **kwargs, ) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 47fd5ac9e5..b9b4929ad0 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -43,7 +43,7 @@ def __init__( d_model: int = 2048, n_heads: int = 16, n_layers: int = 24, - expansion_ratio: int = 4, + expansion_ratio: Union[int, float] = 4, max_seq_len: int = 2048, vocab_size: int = 50368, resid_pdrop: float = 0.0, @@ -60,6 +60,7 @@ def __init__( init_config: Dict = init_config_defaults, fc_type: str = 'torch', tie_word_embeddings: bool = True, + use_pad_tok_in_ffn: bool = True, verbose: Optional[int] = None, **kwargs: Any, ): @@ -69,7 +70,7 @@ def __init__( d_model (int): The size of the embedding dimension of the model. n_heads (int): The number of attention heads. n_layers (int): The number of layers in the model. - expansion_ratio (int): The ratio of the up/down scale in the ffn. + expansion_ratio (int, float): The ratio of the up/down scale in the ffn. max_seq_len (int): The maximum sequence length of the model. vocab_size (int): The size of the vocabulary. resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. @@ -106,7 +107,7 @@ def __init__( factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type. kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: - ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp + ffn_type (str): type of ffn to use. Options: mptmlp, mptgeglu, te_ln_mlp init_device (str): The device to use for parameter initialization. logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. no_bias (bool): Whether to use bias in all layers. @@ -131,6 +132,7 @@ def __init__( See llmfoundry.models.utils.param_init_fns.py for info on other param init config options fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs. tie_word_embeddings (bool): Whether to tie the input embedding and output layers. + use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks. """ self.d_model = d_model self.n_heads = n_heads @@ -151,6 +153,7 @@ def __init__( self.use_cache = use_cache self.init_config = init_config self.fc_type = fc_type + self.use_pad_tok_in_ffn = use_pad_tok_in_ffn if verbose is not None: warnings.warn( DeprecationWarning( @@ -288,7 +291,14 @@ def _validate_config(self) -> None: + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' ) - if self.ffn_config['ffn_type'] == 'mptmlp': + if self.ffn_config['ffn_type'] in ['mptmlp', 'mptgeglu']: self.ffn_config['fc_type'] = self.fc_type elif self.ffn_config['ffn_type'] == 'te_ln_mlp': self.ffn_config['bias'] = not self.no_bias + if not self.use_pad_tok_in_ffn: + try: + from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip + except: + raise ImportError( + 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.2' + ) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 61fedaebed..6074d7ca32 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -419,7 +419,7 @@ def _attn_bias( attn_bias = attn_bias.masked_fill( ~attention_mask.view(-1, 1, 1, s_k), min_val) - return attn_bias, None + return attn_bias, attention_mask def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor: @@ -973,7 +973,11 @@ def __init__( loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy') if loss_fn_config == 'fused_crossentropy': try: - from flash_attn.losses.cross_entropy import \ + # NOTE: The following is the original import statement from flash_attn library, which we have currently replaced with a copy pasted code from the same library's version 1.0.9. The reason is that using the CE loss from FA v2.3.2 results in an illegal memory access error at long sequence lengths (github issue: https://github.com/Dao-AILab/flash-attention/issues/714). + # from flash_attn.losses.cross_entropy import \ + # CrossEntropyLoss as FusedCrossEntropyLoss + # TODO: Once the problem with using FA v2's CE loss at longer sequence lengths is resolved (github issue: https://github.com/Dao-AILab/flash-attention/issues/714), revert back to directly importing the CE loss from FA library. + from llmfoundry.models.layers.cross_entropy_loss import \ CrossEntropyLoss as FusedCrossEntropyLoss self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100) diff --git a/setup.py b/setup.py index 1bf4372fce..9853aa17bf 100644 --- a/setup.py +++ b/setup.py @@ -143,5 +143,5 @@ classifiers=classifiers, install_requires=install_requires, extras_require=extra_deps, - python_requires='>=3.7', + python_requires='>=3.9', )