Merge branch 'main' into update_openai_wrapper
bmosaicml authored Dec 13, 2023
2 parents 0a7b8df + 5fdcc43 commit 11cf032
Showing 7 changed files with 327 additions and 44 deletions.
31 changes: 18 additions & 13 deletions README.md
@@ -26,7 +26,7 @@

# LLM Foundry

This repository contains code for training, finetuning, evaluating, and deploying LLMs for inference with [Composer](https://github.com/mosaicml/composer) and the [MosaicML platform](https://forms.mosaicml.com/demo?utm_source=github.com&utm_medium=referral&utm_campaign=llm-foundry). Designed to be easy-to-use, efficient _and_ flexible, this codebase is designed to enable rapid experimentation with the latest techniques.
This repository contains code for training, finetuning, evaluating, and deploying LLMs for inference with [Composer](https://github.com/mosaicml/composer) and the [MosaicML platform](https://forms.mosaicml.com/demo?utm_source=github.com&utm_medium=referral&utm_campaign=llm-foundry). Designed to be easy-to-use, efficient _and_ flexible, this codebase enables rapid experimentation with the latest techniques.

You'll find in this repo:
* `llmfoundry/` - source code for models, datasets, callbacks, utilities, etc.
@@ -45,15 +45,17 @@ You'll find in this repo:
Mosaic Pretrained Transformers (MPT) are GPT-style models with some special features -- Flash Attention for efficiency, ALiBi for context length extrapolation, and stability improvements to mitigate loss spikes. As part of MosaicML's Foundation series, we have open-sourced several MPT models:


| Model | Context Length | Download | Demo | Commercial use? |
| ------------------ | -------------- | -------------------------------------------------- | ----------------------------------------------------------- | --------------- |
| MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | | Yes |
| MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | | Yes |
| MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat) | No |
| MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | | Yes |
| MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | | Yes |
| MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat) | No |
| MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | | Yes |
| Model | Context Length | Download | Commercial use? |
| ------------------ | -------------- | -------------------------------------------------- | --------------- |
| MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | Yes |
| MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | Yes |
| MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | No |
| MPT-7B-8k          | 8192           | https://huggingface.co/mosaicml/mpt-7b-8k          | Yes             |
| MPT-7B-8k-Chat     | 8192           | https://huggingface.co/mosaicml/mpt-7b-8k-chat     | No              |
| MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | Yes |
| MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | Yes |
| MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | No |
| MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | Yes |

To try out these models locally, [follow the instructions](https://github.com/mosaicml/llm-foundry/tree/main/scripts/inference#interactive-generation-with-modelgenerate) in `scripts/inference/README.md` to prompt HF models using our [hf_generate.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_generate.py) or [hf_chat.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_chat.py) scripts.
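As a minimal sketch of what those scripts do under the hood (assuming `torch` and `transformers` are installed and a CUDA GPU is available; the prompt string is just an example), generation with an MPT checkpoint looks roughly like this:

```python
import torch
import transformers

# MPT ships custom modeling code on the Hugging Face Hub,
# so loading it requires trust_remote_code=True.
name = 'mosaicml/mpt-7b'
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    name,
    torch_dtype=torch.bfloat16,  # assumes a GPU with bfloat16 support
    trust_remote_code=True,
).to('cuda')

inputs = tokenizer('Here is a recipe for vegan banana bread:\n',
                   return_tensors='pt').to('cuda')
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```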

Expand All @@ -75,6 +77,8 @@ Tutorial videos from the community:
Something missing? Contribute with a PR!

# Latest News
* [Blog: Announcing MPT-7B-8K: 8K Context Length for Document Understanding](https://www.mosaicml.com/blog/long-context-mpt-7b-8k)
* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250)
* [Blog: MPT-30B: Raising the bar for open-source foundation models](https://www.mosaicml.com/blog/mpt-30b)
* [Blog: Introducing MPT-7B](https://www.mosaicml.com/blog/mpt-7b)
* [Blog: Benchmarking LLMs on H100](https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1)
@@ -115,9 +119,10 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117

# Installation

This assumes you already have PyTorch and CMake installed.
This assumes you already have PyTorch, CMake, and packaging installed. If not, you can install them with `pip install cmake packaging torch`.

To get started, clone the repo and set up your environment. Instructions to do so differ slightly depending on whether you're using Docker.

### With Docker (recommended)

We *strongly* recommend working with LLM Foundry inside a Docker container (see our recommended Docker image above). If you are doing so, follow these steps to clone the repo and install the requirements.
@@ -179,7 +184,7 @@ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.o

Notes:
1. `attn_impl: triton` does not work.
1. We don't yet have a docker img where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue.
1. We don't yet have a Docker image where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue.

# Quickstart

@@ -233,7 +238,7 @@ python inference/hf_generate.py \
"Here's a quick recipe for baking chocolate chip cookies: Start by"
```

Note: the `composer` command used above to train the model refers to [Composer](https://github.com/mosaicml/composer) library's distributed launcher.
Note: the `composer` command used above to train the model refers to the [Composer](https://github.com/mosaicml/composer) library's distributed launcher.

If you have a write-enabled [HuggingFace auth token](https://huggingface.co/docs/hub/security-tokens), you can optionally upload your model to the Hub! Just export your token like this:

173 changes: 173 additions & 0 deletions llmfoundry/models/layers/cross_entropy_loss.py
@@ -0,0 +1,173 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copied from https://github.com/Dao-AILab/flash-attention/blob/f1a73d074002226c42ce65a1df170ecff9f022c0/flash_attn/losses/cross_entropy.py
# type: ignore

# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
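# Concretely, if each rank r computes a per-sample log-sum-exp lse_r over its
# own vocab shard, the global normalizer is
#     lse_global = log(sum_r exp(lse_r)),
# i.e. torch.logsumexp over the gathered lse_r values, so a single all-gather
# of one scalar per sample is enough to turn each rank's local loss into the
# global loss.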
import torch
import torch.nn as nn

try:  # This try/except is needed because the HF transformers library requires it
    import xentropy_cuda_lib
except Exception:
    xentropy_cuda_lib = None

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if 'all_gather_into_tensor' not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


class SoftmaxCrossEntropyLossFn(torch.autograd.Function):

@staticmethod
def forward(
ctx,
logits,
labels,
smoothing=0.0,
ignored_index=-100,
inplace_backward=False,
process_group=None,
):
"""The forward function for softmax cross entropy loss.
logits: (batch, vocab_size)
labels: (batch,)
If process_group is not None, we're doing Tensor Parallel: each process is responsible for
one part of the vocab. The loss needs to be aggregated across processes.
"""
batch, vocab_size = logits.shape
assert labels.shape == (batch,)
world_size = 1 if process_group is None else torch.distributed.get_world_size(
process_group)
ctx.total_classes = world_size * vocab_size
if world_size == 1:
losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
losses.masked_fill_(labels == ignored_index, 0)
labels_local = labels
else:
rank = torch.distributed.get_rank(process_group)
vocab_start_index, vocab_end_index = rank * vocab_size, (
rank + 1) * vocab_size

# Create a mask of valid vocab ids (1 means it needs to be masked).
labels_mask = (labels < vocab_start_index) | (labels >=
vocab_end_index)
ignored_mask = labels == ignored_index
labels_local = torch.where(ignored_mask, labels,
labels - vocab_start_index)

# For tensor parallel cross entropy with smoothing, we want to pass in the total number
# of classes so that smoothing can be applied correctly. If total_classes=-1, use the
# last dimension of the input tensor.
losses, lse_local = xentropy_cuda_lib.forward(
logits, labels_local, smoothing, world_size * vocab_size)
assert lse_local.shape == (batch,)
assert losses.shape == (batch,)
losses.masked_fill_(ignored_mask, 0)
# For labels == ignored_index, the loss is always 0.
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
# lse_local - predicted logit, and 0 otherwise.
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
# 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
# For labels not in the vocab of this partition, losses contains
# 0.1 * (lse_local - sum logit / total_classes).

lse_allgather = torch.empty(world_size,
batch,
dtype=lse_local.dtype,
device=lse_local.device)
torch.distributed.all_gather_into_tensor(lse_allgather,
lse_local.contiguous(),
group=process_group)
handle_losses = torch.distributed.all_reduce(
losses,
op=torch.distributed.ReduceOp.SUM,
group=process_group,
async_op=True)
lse = torch.logsumexp(lse_allgather, dim=0)
# If there's no smoothing, the total losses are lse_local - predicted_logit,
# we just have to subtract the lse_local and add the lse (global).
# If there's smoothing=0.1, the total losses are
# 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
# We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
rank_per_sample = torch.div(labels,
vocab_size,
rounding_mode='floor')
lse_local = lse_allgather[
rank_per_sample,
torch.arange(batch, device=lse_allgather.device)]

handle_losses.wait()
if smoothing == 0.0:
losses += lse - lse_local
else:
losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
lse - lse_allgather.sum(dim=0))
losses.masked_fill_(ignored_mask, 0)

ctx.save_for_backward(logits, lse, labels_local)
ctx.smoothing = smoothing
ctx.ignored_index = ignored_index
ctx.inplace_backward = inplace_backward
return losses

@staticmethod
def backward(ctx, grad_loss):
logits, lse, labels = ctx.saved_tensors
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
ctx.smoothing,
ctx.inplace_backward,
ctx.total_classes)
return grad_logits, None, None, None, None, None, None


class CrossEntropyLoss(nn.Module):

def __init__(
self,
ignore_index=-100,
reduction='mean',
label_smoothing=0.0,
inplace_backward=False,
process_group=None,
):
super().__init__()
if xentropy_cuda_lib is None:
raise ValueError(
'xentropy_cuda_lib is None, probably because importing xentropy_cuda_lib failed.'
)
if reduction not in ['mean', 'none']:
raise NotImplementedError(
"Only support reduction = 'mean' or 'none'")
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
self.inplace_backward = inplace_backward
self.process_group = process_group

def forward(self, input, target):
assert input.is_cuda and target.is_cuda
# SoftmaxCrossEntropyLoss implicitly casts to float
loss = SoftmaxCrossEntropyLossFn.apply(
input,
target,
self.label_smoothing,
self.ignore_index,
self.inplace_backward,
self.process_group,
)
if self.reduction == 'mean':
return loss.sum() / (target != self.ignore_index).sum()
else:
return loss
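
A quick usage sketch (hypothetical shapes, and assuming the `xentropy_cuda_lib` extension is built and a CUDA device is available): the module is a drop-in for `torch.nn.CrossEntropyLoss` on flattened `(tokens, vocab)` logits.

```python
import torch

from llmfoundry.models.layers.cross_entropy_loss import CrossEntropyLoss

# Hypothetical shapes for illustration: 8 tokens over a 50k-entry vocab.
logits = torch.randn(8, 50000, device='cuda', requires_grad=True)
labels = torch.randint(0, 50000, (8,), device='cuda')
labels[0] = -100  # ignored positions contribute 0 loss and are excluded from the mean

loss_fn = CrossEntropyLoss(ignore_index=-100, reduction='mean', label_smoothing=0.0)
loss = loss_fn(logits, labels)
loss.backward()
```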