diff --git a/sparsecoding/priors/spike_slab.py b/sparsecoding/priors/spike_slab.py index e6c6ccb..6336cf0 100644 --- a/sparsecoding/priors/spike_slab.py +++ b/sparsecoding/priors/spike_slab.py @@ -91,7 +91,7 @@ def log_prob( log_prob[slab_mask] = ( torch.log(torch.tensor(1. - self.p_spike)) - torch.log(torch.tensor(self.scale)) - - sample[slab_mask] / self.scale + - torch.abs(sample[slab_mask]) / self.scale ) if self.positive_only: log_prob[sample < 0.] = -torch.inf