I'm pretty sure the log-likelihood in the solution to the second exercise is off. The nll is defined as:
def nll(self, x, cond=None):
    loc, log_scale, weight_logits = torch.chunk(self.forward(x), 3, dim=1)
    weights = F.softmax(weight_logits, dim=1) #.repeat(1, 1, self.n_components, 1, 1)
    log_det_jacobian = Normal(loc, log_scale.exp()).log_prob(x.unsqueeze(1).repeat(1,1,self.n_components,1,1))
    return -log_det_jacobian.mean()
As you can see, the weights are never used. I believe this should be:
def nll(self, x, cond=None):
    loc, log_scale, weight_logits = torch.chunk(self.forward(x), 3, dim=1)
    weights = F.softmax(weight_logits, dim=1) #.repeat(1, 1, self.n_components, 1, 1)
    # per-component densities (log-probs exponentiated back to probabilities)
    log_det_jacobian = Normal(loc, log_scale.exp()).log_prob(x.unsqueeze(1).repeat(1,1,self.n_components,1,1)).exp()
    # weighted mixture density, then take the log
    return -torch.log((log_det_jacobian * weights).sum(dim=2)).mean()
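
For what it's worth, the same quantity can be computed more stably with torch.logsumexp rather than exponentiating the log-probs and then taking the log of the weighted sum, since log p(x) = log sum_k w_k N(x | mu_k, sigma_k) = logsumexp_k(log w_k + log N(x | mu_k, sigma_k)). Below is a minimal standalone sketch of that pattern; the function name, the component_dim argument, and the toy shapes are illustrative only and may not match the exercise's actual tensor layout.

    import torch
    import torch.nn.functional as F
    from torch.distributions import Normal

    def mixture_nll(loc, log_scale, weight_logits, x, component_dim=1):
        # Per-component log densities; x broadcasts against the component axis.
        log_probs = Normal(loc, log_scale.exp()).log_prob(x)
        # Log mixture weights along the component dimension.
        log_weights = F.log_softmax(weight_logits, dim=component_dim)
        # log p(x) = logsumexp_k(log w_k + log N(x | mu_k, sigma_k))
        log_mixture = torch.logsumexp(log_weights + log_probs, dim=component_dim)
        return -log_mixture.mean()

    # Toy usage: a batch of 8 scalars under a 4-component mixture.
    x = torch.randn(8, 1)
    loc = torch.randn(8, 4)
    log_scale = torch.zeros(8, 4)
    weight_logits = torch.zeros(8, 4)
    print(mixture_nll(loc, log_scale, weight_logits, x))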