diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 3bd62e8702..02a845faf3 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -233,11 +233,15 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D batch["labels"][batch["position_ids"] == 0] = self.ignore_index flattened_position_ids = batch["position_ids"].flatten() - indices_q = torch.arange(flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32) + indices_q = torch.arange( + flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32 + ) batch["cu_seq_lens_q"] = torch.cat( ( indices_q[flattened_position_ids == 0], - torch.tensor(flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32), + torch.tensor( + flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32 + ), ) ) batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"]