From 0e6ff383a4424d0075fe338542ac16b954e59ce8 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Wed, 1 Jun 2022 15:17:18 -0700 Subject: [PATCH] add TODO comments --- sparsecoding/priors/l0.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sparsecoding/priors/l0.py b/sparsecoding/priors/l0.py index dde7503..18bf698 100644 --- a/sparsecoding/priors/l0.py +++ b/sparsecoding/priors/l0.py @@ -24,7 +24,7 @@ def __init__( raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.") if prob_distr.dtype != torch.float32: raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.") - if not torch.allclose(torch.sum(prob_distr), torch.ones_like(prob_distr)): + if not torch.allclose(torch.sum(prob_distr), torch.ones(1, dtype=torch.float32)): raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.") self.prob_distr = prob_distr @@ -35,7 +35,7 @@ def D(self): def sample( self, - num_samples: int + num_samples: int, ): N = num_samples @@ -74,3 +74,7 @@ def log_prob( log_prob = torch.log(self.prob_distr[l0_norm - 1]) log_prob[l0_norm == 0] = -torch.inf return log_prob + +# TODO: Add L0ExpPrior, where the number of active units is distributed exponentially. + +# TODO: Add L0IidPrior, where the magnitude of an active unit is distributed according to an i.i.d. Prior.