Skip to content

Commit

Permalink
correct collate function
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jun 13, 2024
1 parent 37ba704 commit bc209b2
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ def load_data(filename):

def collate_fn(batch):
"""Custom collate function to pad data batches."""
if not batch:
return torch.tensor([]), torch.tensor([])
inputs, targets = zip(*batch)
padded_inputs = pad_sequence(
[torch.tensor(seq) for seq in inputs], batch_first=True, padding_value=0
)
padded_targets = pad_sequence(
[torch.tensor(tgt) for tgt in targets], batch_first=True, padding_value=-1
[torch.tensor(tgt) for seq in targets], batch_first=True, padding_value=-1
)
return padded_inputs, padded_targets

Expand Down

0 comments on commit bc209b2

Please sign in to comment.