Skip to content

Commit

Permalink
yo
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Jun 11, 2024
1 parent 832f17d commit e2678b9
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion composer/metrics/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,14 @@ def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100):
super().__init__(dist_sync_on_step=dist_sync_on_step)

self.ignore_index = ignore_index
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss
self.loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
except ImportError:
log.debug(
'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' +
'to compute LanguageCrossEntropy metric, which will be slower.',
)
self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum')

Expand Down

0 comments on commit e2678b9

Please sign in to comment.