-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathloss.py
28 lines (24 loc) · 866 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn as nn
class PGLoss(nn.Module):
    """
    Pseudo-loss whose gradient (on calling .backward()) equals the REINFORCE
    policy gradient, used for adversarial training of the Generator.

    The "loss" value itself is not meaningful as a training metric; only its
    gradient matters: -sum(log_prob(chosen token) * reward).
    """

    def __init__(self):
        super(PGLoss, self).__init__()

    def forward(self, pred, target, reward):
        """
        Inputs: pred, target, reward
        - pred   : (N, vocab_size), per-step log-probabilities over the vocabulary
                   (N is typically batch_size * seq_len after flattening)
        - target : (N,) indices of the tokens actually sampled at each step
        - reward : reward for each step; flattened to (N,) before weighting
                   (per the original docstring, the per-sentence reward is
                   broadcast/expanded to each step by the caller)

        Returns a scalar tensor: -sum(log_prob[i, target[i]] * reward[i]).
        """
        # Boolean one-hot mask marking the chosen token in each row.
        # torch.bool is required: uint8 masks for masked_select are deprecated
        # and rejected by modern PyTorch. Creating the mask directly on
        # pred.device also handles non-default CUDA devices correctly.
        one_hot = torch.zeros(pred.size(), dtype=torch.bool, device=pred.device)
        one_hot.scatter_(1, target.view(-1, 1), True)
        # Pick out the log-probability of each sampled token, weight it by
        # the corresponding reward, and negate so that gradient DESCENT
        # performs gradient ASCENT on expected reward.
        loss = torch.masked_select(pred, one_hot)
        loss = loss * reward.contiguous().view(-1)
        loss = -torch.sum(loss)
        return loss