Skip to content

Commit

Permalink
Refactor loss function for ComposerMPTCausalLM (#1387)
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Jul 23, 2024
1 parent 7b160fc commit ced63ee
Showing 1 changed file with 41 additions and 23 deletions.
64 changes: 41 additions & 23 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,40 @@ def _reorder_cache(
return reordered_past


def get_targets(
    labels: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Shift ``labels`` left by one position to build next-token targets.

    For causal LM training, the target at position ``t`` is the label at
    position ``t + 1``. The final position of each sequence has no next
    token, so it is filled with ``ignore_index`` so the loss function
    skips it.

    Args:
        labels: Label tensor of shape ``(batch, seq_len)``.
        ignore_index: Sentinel written into the last position of every
            sequence; must match the loss function's ``ignore_index``.
            Defaults to ``-100`` (the PyTorch cross-entropy default).

    Returns:
        A new tensor of the same shape as ``labels``; the input is not
        modified.
    """
    # torch.roll without `dims` flattens, rolls, and reshapes, so the
    # first label of row i+1 wraps into the last slot of row i — but that
    # slot is immediately overwritten with ignore_index below, making the
    # result identical to a per-row shift.
    targets = torch.roll(labels, shifts=-1)
    targets[:, -1] = ignore_index
    return targets


def compute_loss_from_logits(
    outputs: CausalLMOutputWithPast,
    shift_labels: bool,
    labels: torch.Tensor,
    loss_fn: nn.Module,
    sample_weighing_factor: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute the mean per-token loss from model outputs.

    Args:
        outputs: Model outputs carrying ``logits`` of shape
            ``(batch, seq_len, vocab)``.
        shift_labels: If ``True``, shift ``labels`` left by one position
            (via ``get_targets``) before computing the loss; otherwise
            ``labels`` are used as-is.
        labels: Label tensor of shape ``(batch, seq_len)``.
        loss_fn: Per-token loss module (e.g. cross-entropy with
            ``reduction='none'``); its ``ignore_index`` marks positions
            excluded from the average.
        sample_weighing_factor: Optional scalar weight; only a leading
            dimension of size 1 is supported.

    Returns:
        A scalar loss tensor.

    Raises:
        ValueError: If ``sample_weighing_factor`` has more than one entry
            along its first dimension.
    """
    if shift_labels:
        targets = get_targets(labels)
    else:
        targets = labels

    vocab_size = outputs.logits.size(-1)
    per_token_losses = loss_fn(
        outputs.logits.view(-1, vocab_size),
        targets.view(-1),
    )

    num_valid_tokens = (targets != loss_fn.ignore_index).sum()
    if num_valid_tokens == 0:
        # Every position is masked: the masked sum is 0, and summing
        # avoids the 0/0 division the mean would produce.
        loss = per_token_losses.sum()
    else:
        loss = per_token_losses.sum() / num_valid_tokens

    if sample_weighing_factor is not None:
        if sample_weighing_factor.shape[0] > 1:
            raise ValueError(
                'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.',
            )
        loss = loss * sample_weighing_factor[0].item()

    return loss


class ComposerMPTCausalLM(HuggingFaceModel):

def __init__(
Expand Down Expand Up @@ -1362,9 +1396,7 @@ def config_class(self) -> Type[MPTConfig]:
return MPTConfig

def get_targets(self, batch: Mapping) -> torch.Tensor:
targets = torch.roll(batch['labels'], shifts=-1)
targets[:, -1] = -100
return targets
return get_targets(batch['labels'])

def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
if self.config.ffn_config['ffn_type'] in ffns_with_megablocks:
Expand All @@ -1385,27 +1417,14 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:

def loss(self, outputs: CausalLMOutputWithPast,
batch: Mapping) -> Union[dict, torch.Tensor]:
if self.shift_labels:
targets = self.get_targets(batch)
else:
targets = batch['labels']

losses = self.loss_fn(
outputs.logits.view(-1, outputs.logits.size(-1)),
targets.view(-1),
loss = compute_loss_from_logits(
outputs,
self.shift_labels,
batch['labels'],
self.loss_fn,
batch.get('sample_weighing_factor', None),
)

if torch.all(targets == self.loss_fn.ignore_index):
loss = losses.sum()
else:
loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum()
if 'sample_weighing_factor' in batch:
if batch['sample_weighing_factor'].shape[0] > 1:
raise ValueError(
'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.',
)
loss = loss * batch['sample_weighing_factor'][0].item()

if self.config.ffn_config['ffn_type'] in ffns_with_megablocks:
# MegaBlocks MoE load balancing loss
try: # Add try/catch to avoid transformers complaining and raising errors
Expand All @@ -1420,7 +1439,6 @@ def loss(self, outputs: CausalLMOutputWithPast,
'loss': loss,
'lbl': lbl,
}

return loss

@cached_property
Expand Down

0 comments on commit ced63ee

Please sign in to comment.