Skip to content

Commit

Permalink
fix: formatting
Browse files Browse the repository at this point in the history
Signed-off-by: Abhishek <[email protected]>
  • Loading branch information
Abhishek-TAMU committed Oct 8, 2024
1 parent a821ce0 commit fb669b6
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,15 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
batch["labels"][batch["position_ids"] == 0] = self.ignore_index

flattened_position_ids = batch["position_ids"].flatten()
indices_q = torch.arange(flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32)
indices_q = torch.arange(
flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32
)
batch["cu_seq_lens_q"] = torch.cat(
(
indices_q[flattened_position_ids == 0],
torch.tensor(flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32),
torch.tensor(
flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32
),
)
)
batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"]
Expand Down

0 comments on commit fb669b6

Please sign in to comment.