diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 46e28c545e..60bdb891f3 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -1007,7 +1007,7 @@ def get_targets(self, batch: Mapping) -> torch.Tensor: if self.tokenizer is not None and hasattr(self.tokenizer, 'eos_token_id'): targets = torch.where( - batch['input_ids'] == self.tokenizer.eos_token_id, -100, + batch['labels'] == self.tokenizer.eos_token_id, -100, targets) return targets