
Commit

..
ShashankMosaicML committed Dec 9, 2023
1 parent 668d66b commit 49b316b
Showing 2 changed files with 13 additions and 8 deletions.
5 changes: 2 additions & 3 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -972,13 +972,12 @@ def __init__(
loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy')
if loss_fn_config == 'fused_crossentropy':
try:
from llmfoundry.models.layers.cross_entropy_loss import \
CrossEntropyLoss as FusedCrossEntropyLoss

# NOTE: The following is the original import statement from the flash_attn library, which we have currently replaced with code copied from version 1.0.9 of the same library. The reason is that flash_attn version 2.3.2 has a bug in its Cross Entropy Loss that throws an illegal memory access error at long sequence lengths.
# from flash_attn.losses.cross_entropy import \
#     CrossEntropyLoss as FusedCrossEntropyLoss
# TODO: Once the flash_attn library is updated to fix the bug in its Cross Entropy Loss, we can revert to the original import statement.
from llmfoundry.models.layers.cross_entropy_loss import \
CrossEntropyLoss as FusedCrossEntropyLoss

self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
except:
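For context, the hunk above only reorders the import relative to the explanatory comments; the model still builds its loss from the cross-entropy kernel copied from flash_attn 1.0.9. Below is a minimal, self-contained sketch of the selection-and-fallback pattern this code implements. The helper name build_loss_fn and the torch fallback in the except branch are assumptions for illustration only; the real except body is not shown in this hunk.

    import torch.nn as nn

    def build_loss_fn(loss_fn_config: str = 'fused_crossentropy') -> nn.Module:
        # Hypothetical helper mirroring the logic in the __init__ shown above.
        if loss_fn_config == 'fused_crossentropy':
            try:
                # The commit keeps the kernel copied from flash_attn 1.0.9 rather than
                # importing flash_attn 2.3.2's CrossEntropyLoss (see the NOTE above).
                from llmfoundry.models.layers.cross_entropy_loss import \
                    CrossEntropyLoss as FusedCrossEntropyLoss
                return FusedCrossEntropyLoss(ignore_index=-100)
            except ImportError:
                # Assumed fallback: the real except body is truncated above;
                # falling back to torch keeps this sketch runnable.
                return nn.CrossEntropyLoss(ignore_index=-100)
        return nn.CrossEntropyLoss(ignore_index=-100)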
16 changes: 11 additions & 5 deletions tests/models/test_model.py
@@ -398,16 +398,22 @@ def test_determinism(attn_impl: str, precision: torch.dtype):


@pytest.mark.gpu
def test_loss_fn():
@pytest.mark.parametrize('ce_loss_implementation',
['FA_v1_copied', 'FA_imported'])
def test_loss_fn(ce_loss_implementation: str):
"""Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function.
We provide non-zero tolerances to account for small numerical differences
between the two loss implementations.
"""
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip
except:
pytest.skip('Fused cross entropy was not installed')
if ce_loss_implementation == 'FA_imported':
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip
except:
pytest.skip('Fused cross entropy was not installed')
else:
from llmfoundry.models.layers.cross_entropy_loss import \
CrossEntropyLoss as FusedCrossEntropyLoss

# run numerical test in pure fp32
from torch.backends import cuda, cudnn
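The body of the test is truncated above; below is a hedged sketch of the comparison it presumably performs, per its docstring. The tensor shapes, seed, and tolerances are illustrative assumptions, not the values used in test_model.py.

    import pytest
    import torch

    @pytest.mark.gpu
    @pytest.mark.parametrize('ce_loss_implementation',
                             ['FA_v1_copied', 'FA_imported'])
    def test_loss_fn_sketch(ce_loss_implementation: str):
        # Select the fused implementation exactly as the diff above does.
        if ce_loss_implementation == 'FA_imported':
            try:
                from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore
            except ImportError:
                pytest.skip('Fused cross entropy was not installed')
        else:
            from llmfoundry.models.layers.cross_entropy_loss import \
                CrossEntropyLoss as FusedCrossEntropyLoss

        torch.manual_seed(0)  # illustrative; the real test configures determinism separately
        logits = torch.randn(8, 128, device='cuda')          # (batch, vocab), fp32
        targets = torch.randint(0, 128, (8,), device='cuda')

        fused = FusedCrossEntropyLoss(ignore_index=-100)(logits, targets)
        reference = torch.nn.CrossEntropyLoss(ignore_index=-100)(logits, targets)

        # Non-zero tolerances, as the docstring notes, to absorb small numerical differences.
        assert torch.allclose(fused, reference, rtol=1e-4, atol=1e-5)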
