Skip to content

Commit

Permalink
Replace self.get_targets
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Jul 23, 2024
1 parent 118e96c commit 6f4f58b
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,9 +1396,7 @@ def config_class(self) -> Type[MPTConfig]:
return MPTConfig

def get_targets(self, batch: Mapping) -> torch.Tensor:
targets = torch.roll(batch['labels'], shifts=-1)
targets[:, -1] = -100
return targets
return get_targets(batch['labels'])

def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
if self.config.ffn_config['ffn_type'] in ffns_with_megablocks:
Expand Down

0 comments on commit 6f4f58b

Please sign in to comment.