diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 956b0d0624..40b3aaa6ee 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -1396,9 +1396,7 @@ def config_class(self) -> Type[MPTConfig]: return MPTConfig def get_targets(self, batch: Mapping) -> torch.Tensor: - targets = torch.roll(batch['labels'], shifts=-1) - targets[:, -1] = -100 - return targets + return get_targets(batch['labels']) def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: