diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index cdec1f4d45d..3bd62e8702e 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -232,6 +232,19 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0) batch["labels"][batch["position_ids"] == 0] = self.ignore_index + flattened_position_ids = batch["position_ids"].flatten() + indices_q = torch.arange(flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32) + batch["cu_seq_lens_q"] = torch.cat( + ( + indices_q[flattened_position_ids == 0], + torch.tensor(flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32), + ) + ) + batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"] + + batch["max_length_k"] = flattened_position_ids.max() + 1 + batch["max_length_q"] = batch["max_length_k"] + return batch