Skip to content

Commit

Permalink
feat: Add info to batch in DataCollatorForCompletionOnlyLM
Browse files: browse the repository at this point in the history
Signed-off-by: Abhishek <[email protected]>
  • Loading branch information
Abhishek-TAMU committed Oct 3, 2024
1 parent 1be4d86 commit 56e32cb
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions trl/trainer/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -232,6 +232,19 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
batch["labels"][batch["position_ids"] == 0] = self.ignore_index

flattened_position_ids = batch["position_ids"].flatten()
indices_q = torch.arange(flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32)
batch["cu_seq_lens_q"] = torch.cat(
(
indices_q[flattened_position_ids == 0],
torch.tensor(flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32),
)
)
batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"]

batch["max_length_k"] = flattened_position_ids.max() + 1
batch["max_length_q"] = batch["max_length_k"]

return batch


Expand Down

0 comments on commit 56e32cb

Please sign in to comment.