
Updating the Flash Attention version to fix cross entropy loss #812

Merged
173 changes: 0 additions & 173 deletions llmfoundry/models/layers/cross_entropy_loss.py

This file was deleted.

6 changes: 1 addition & 5 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -972,11 +972,7 @@ def __init__(
        loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy')
        if loss_fn_config == 'fused_crossentropy':
            try:
-                # NOTE: The following is the original import statement from flash_attn library, which we have currently replaced with a copy pasted code from the same library's version 1.0.9. The reason is that using the CE loss from FA v2.3.2 results in an illegal memory access error at long sequence lengths (github issue: https://github.com/Dao-AILab/flash-attention/issues/714).
-                # from flash_attn.losses.cross_entropy import \
-                #     CrossEntropyLoss as FusedCrossEntropyLoss
-                # TODO: Once the problem with using FA v2's CE loss at longer sequence lengths is resolved (github issue: https://github.com/Dao-AILab/flash-attention/issues/714), revert back to directly importing the CE loss from FA library.
-                from llmfoundry.models.layers.cross_entropy_loss import \
+                from flash_attn.losses.cross_entropy import \
                    CrossEntropyLoss as FusedCrossEntropyLoss

                self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
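For context, a minimal sketch of what the loss-function selection in `modeling_mpt.py` looks like after this change. This is illustrative only: the fallback to `torch.nn.CrossEntropyLoss` and the `build_loss_fn` helper are assumptions, not the exact llm-foundry code.

```python
# Minimal sketch, not the exact llm-foundry code: use the fused CE loss
# from flash-attn when it is installed, otherwise fall back to torch's CE.
import torch


def build_loss_fn(loss_fn_config: str = 'fused_crossentropy') -> torch.nn.Module:
    if loss_fn_config == 'fused_crossentropy':
        try:
            # flash-attn >= 2.3.6 fixes the illegal memory access seen at
            # long sequence lengths with the v2.3.2 CE loss (issue #714).
            from flash_attn.losses.cross_entropy import \
                CrossEntropyLoss as FusedCrossEntropyLoss
            return FusedCrossEntropyLoss(ignore_index=-100)
        except ImportError:
            # Illustrative fallback; the real module may warn or raise instead.
            pass
    return torch.nn.CrossEntropyLoss(ignore_index=-100)
```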
4 changes: 1 addition & 3 deletions setup.py
@@ -98,10 +98,8 @@
    'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
]
extra_deps['gpu-flash2'] = [
-    'flash-attn==2.3.2',
+    'flash-attn==2.3.6',
    'mosaicml-turbo==0.0.4',
-    # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
-    'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
]

extra_deps['peft'] = [
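Since the fix relies on the pinned `flash-attn==2.3.6`, a runtime sanity check of the installed version can be useful before depending on its CE loss. A small sketch (not part of this PR), assuming standard `importlib.metadata`:

```python
# Sketch (not part of this PR): verify that the installed flash-attn meets
# the 2.3.6 minimum pinned in setup.py before using its fused CE loss.
from importlib.metadata import PackageNotFoundError, version


def flash_attn_is_new_enough(min_version: str = '2.3.6') -> bool:
    try:
        # Distribution name; depending on the install it may also be
        # registered as 'flash_attn'.
        installed = version('flash-attn')
    except PackageNotFoundError:
        return False

    def as_tuple(v: str) -> tuple:
        # Naive numeric comparison; assumes plain 'X.Y.Z'-style versions.
        return tuple(int(p) for p in v.split('.')[:3])

    return as_tuple(installed) >= as_tuple(min_version)


print(flash_attn_is_new_enough())
```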
49 changes: 49 additions & 0 deletions tests/models/layers/test_loss.py
@@ -0,0 +1,49 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copied and modified from https://github.com/Dao-AILab/flash-attention/blob/713bd3aa9ad518ecdb5fd41078550c25ebd58e1f/tests/losses/test_cross_entropy.py

import pytest
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss


@pytest.mark.gpu
@pytest.mark.parametrize('vocab_size', [50432, 100352])
@pytest.mark.parametrize('seqlen', [4096, 65536])
@pytest.mark.parametrize('batch_size', [1, 8])
def test_cross_entropy_loss(vocab_size: int, seqlen: int, batch_size: int):
    if batch_size > 1 and seqlen == 65536:
        pytest.skip(f'Skipping since this will OOM because of data size.')
    dtype = torch.bfloat16
    device = 'cuda'
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    x_pt = torch.randn(batch_size * seqlen,
                       vocab_size,
                       device=device,
                       dtype=dtype,
                       requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    y = torch.randint(0,
                      vocab_size, (batch_size * seqlen,),
                      dtype=torch.long,
                      device=device)
    if batch_size * seqlen > 10:
        y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss()
    model = CrossEntropyLoss()
    out = model(x, y)
    x_pt_scaled = x_pt.float()
    out_pt = model_pt(x_pt_scaled, y)
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(
        x.grad,  # type: ignore
        x_pt.grad,  # type: ignore
        rtol=rtol,
        atol=atol)
16 changes: 5 additions & 11 deletions tests/models/test_model.py
@@ -422,22 +422,16 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str,


@pytest.mark.gpu
-@pytest.mark.parametrize('ce_loss_implementation',
-                         ['FA_v1_copied', 'FA_imported'])
-def test_loss_fn(ce_loss_implementation: str):
+def test_loss_fn():
    """Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function.

    We provide non-zero tolerances to account for small numerics differences
    between the two loss implementations.
    """
-    if ce_loss_implementation == 'FA_imported':
-        try:
-            from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
-        except:
-            pytest.skip('Fused cross entropy was not installed')
-    else:
-        from llmfoundry.models.layers.cross_entropy_loss import \
-            CrossEntropyLoss as FusedCrossEntropyLoss
+    try:
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
+    except:
+        pytest.skip('Fused cross entropy was not installed')

    # run numerical test in pure fp32
    from torch.backends import cuda, cudnn
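The rest of `test_loss_fn` is truncated above. The comparison the docstring describes is roughly the following sketch (illustrative shapes and tolerance values, not the actual test body); it requires a GPU and an installed flash-attn:

```python
# Sketch of the fused-vs-torch comparison described in the docstring.
# Tensor shapes and tolerances here are illustrative assumptions.
import torch

torch.manual_seed(0)
logits = torch.randn(8, 50432, device='cuda', requires_grad=True)
targets = torch.randint(0, 50432, (8,), device='cuda')

try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCE
    fused = FusedCE(ignore_index=-100)(logits, targets)
    ref = torch.nn.CrossEntropyLoss(ignore_index=-100)(logits.float(), targets)
    # Non-zero tolerances, since the fused kernel's numerics differ slightly.
    assert torch.allclose(fused, ref, rtol=1e-3, atol=1e-4)
except ImportError:
    pass  # flash-attn not installed; nothing to compare
```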