From ced63ee33198fff2610b9bb232ee0897ec0f966d Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Tue, 23 Jul 2024 14:54:11 -0700 Subject: [PATCH] Refactor loss function for ComposerMPTCausalLM (#1387) --- llmfoundry/models/mpt/modeling_mpt.py | 64 +++++++++++++++++---------- 1 file changed, 41 insertions(+), 23 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 3b2744f867..40b3aaa6ee 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -1285,6 +1285,40 @@ def _reorder_cache( return reordered_past +def get_targets(labels: torch.Tensor) -> torch.Tensor: + targets = torch.roll(labels, shifts=-1) + targets[:, -1] = -100 + return targets + + +def compute_loss_from_logits( + outputs: CausalLMOutputWithPast, + shift_labels: bool, + labels: torch.Tensor, + loss_fn: nn.Module, + sample_weighing_factor: Optional[torch.Tensor] = None, +) -> torch.Tensor: + targets = get_targets(labels) if shift_labels else labels + + losses = loss_fn( + outputs.logits.view(-1, outputs.logits.size(-1)), + targets.view(-1), + ) + + if torch.all(targets == loss_fn.ignore_index): + loss = losses.sum() + else: + loss = losses.sum() / (targets != loss_fn.ignore_index).sum() + if sample_weighing_factor is not None: + if sample_weighing_factor.shape[0] > 1: + raise ValueError( + 'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.', + ) + loss = loss * sample_weighing_factor[0].item() + + return loss + + class ComposerMPTCausalLM(HuggingFaceModel): def __init__( @@ -1362,9 +1396,7 @@ def config_class(self) -> Type[MPTConfig]: return MPTConfig def get_targets(self, batch: Mapping) -> torch.Tensor: - targets = torch.roll(batch['labels'], shifts=-1) - targets[:, -1] = -100 - return targets + return get_targets(batch['labels']) def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: @@ -1385,27 +1417,14 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> Union[dict, torch.Tensor]: - if self.shift_labels: - targets = self.get_targets(batch) - else: - targets = batch['labels'] - - losses = self.loss_fn( - outputs.logits.view(-1, outputs.logits.size(-1)), - targets.view(-1), + loss = compute_loss_from_logits( + outputs, + self.shift_labels, + batch['labels'], + self.loss_fn, + batch.get('sample_weighing_factor', None), ) - if torch.all(targets == self.loss_fn.ignore_index): - loss = losses.sum() - else: - loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum() - if 'sample_weighing_factor' in batch: - if batch['sample_weighing_factor'].shape[0] > 1: - raise ValueError( - 'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.', - ) - loss = loss * batch['sample_weighing_factor'][0].item() - if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # MegaBlocks MoE load balancing loss try: # Add try/catch to avoid transformers complaining and raising errors @@ -1420,7 +1439,6 @@ def loss(self, outputs: CausalLMOutputWithPast, 'loss': loss, 'lbl': lbl, } - return loss @cached_property