From 88aad940e340c2f9b66c6de0e9274d74c494a636 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Wed, 1 Jun 2022 11:59:34 -0700 Subject: [PATCH] bugfix spike_slab log_prob --- sparsecoding/priors/spike_slab.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sparsecoding/priors/spike_slab.py b/sparsecoding/priors/spike_slab.py index e6c6ccb..6336cf0 100644 --- a/sparsecoding/priors/spike_slab.py +++ b/sparsecoding/priors/spike_slab.py @@ -91,7 +91,7 @@ def log_prob( log_prob[slab_mask] = ( torch.log(torch.tensor(1. - self.p_spike)) - torch.log(torch.tensor(self.scale)) - - sample[slab_mask] / self.scale + - torch.abs(sample[slab_mask]) / self.scale ) if self.positive_only: log_prob[sample < 0.] = -torch.inf