
Commit

ShashankMosaicML committed Dec 9, 2023
1 parent 322a47b commit c22374f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -972,10 +972,10 @@ def __init__(
         loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy')
         if loss_fn_config == 'fused_crossentropy':
             try:
-                # NOTE: The following is the original import statement from flash_attn library, which we have currently replaced with a copy pasted code from the same library's version 1.0.9. The reason is that flash_attn's version 2.3.2 has a bug in their Cross Entropy Loss that throws an illegal memory access error at long sequence lengths.
+                # NOTE: The following is the original import statement from flash_attn library, which we have currently replaced with a copy pasted code from the same library's version 1.0.9. The reason is that using the CE loss from FA v2.3.2 results in an illegal memory access error at long sequence lengths.
                 # from flash_attn.losses.cross_entropy import \
                 #     CrossEntropyLoss as FusedCrossEntropyLoss
-                # TODO: Once the flash_attn library is updated to fix the bug in their Cross Entropy Loss, we can revert back to the original import statement.
+                # TODO: Once the problem with using FA v2's CE loss at longer sequence lengths is resolved, revert back to directly importing the CE loss from FA library.
                 from llmfoundry.models.layers.cross_entropy_loss import \
                     CrossEntropyLoss as FusedCrossEntropyLoss

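For context, here is a minimal sketch of the pattern this hunk sits inside: config-driven loss selection that tries to import a fused cross-entropy implementation and otherwise falls back to plain PyTorch. This is an illustrative assumption, not the repo's actual surrounding code; the function name build_loss_fn, the ImportError fallback, and the constructor arguments are hypothetical.

    # Illustrative sketch only (not the actual modeling_mpt.py code).
    import warnings

    import torch


    def build_loss_fn(loss_fn_config: str = 'fused_crossentropy') -> torch.nn.Module:
        if loss_fn_config == 'fused_crossentropy':
            try:
                # Vendored copy of the CE loss (per this commit), used instead of
                # importing it from flash_attn v2.3.2, whose CE loss can hit an
                # illegal memory access at long sequence lengths.
                from llmfoundry.models.layers.cross_entropy_loss import \
                    CrossEntropyLoss as FusedCrossEntropyLoss
                return FusedCrossEntropyLoss(ignore_index=-100)
            except ImportError:
                # Assumed fallback behavior for this sketch.
                warnings.warn('Fused cross entropy unavailable; '
                              'falling back to torch.nn.CrossEntropyLoss.')
        return torch.nn.CrossEntropyLoss(ignore_index=-100)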
