Skip to content

Commit

Permalink
add TODO comments
Browse files Browse the repository at this point in the history
  • Loading branch information
alvinzz committed Jun 2, 2022
1 parent cbc1cb1 commit 0e6ff38
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sparsecoding/priors/l0.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.")
if prob_distr.dtype != torch.float32:
raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.")
if not torch.allclose(torch.sum(prob_distr), torch.ones_like(prob_distr)):
if not torch.allclose(torch.sum(prob_distr), torch.ones(1, dtype=torch.float32)):
raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.")

self.prob_distr = prob_distr
Expand All @@ -35,7 +35,7 @@ def D(self):

def sample(
self,
num_samples: int
num_samples: int,
):
N = num_samples

Expand Down Expand Up @@ -74,3 +74,7 @@ def log_prob(
log_prob = torch.log(self.prob_distr[l0_norm - 1])
log_prob[l0_norm == 0] = -torch.inf
return log_prob

# TODO: Add L0ExpPrior, where the number of active units is distributed exponentially.

# TODO: Add L0IidPrior, where the magnitude of an active unit is distributed according to an i.i.d. Prior.

0 comments on commit 0e6ff38

Please sign in to comment.