Skip to content

Commit

Permalink
add log_prob() for priors
Browse files Browse the repository at this point in the history
  • Loading branch information
alvinzz committed Jun 2, 2022
1 parent db25837 commit cbc1cb1
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 3 deletions.
36 changes: 35 additions & 1 deletion sparsecoding/priors/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod

import torch


class Prior(ABC):
"""A distribution over weights.
Expand Down Expand Up @@ -29,6 +31,38 @@ def sample(
Returns
-------
samples : Tensor, shape [num_samples, self.D]
samples : Tensor, shape [num_samples, self.D()]
Sampled weights.
"""

@abstractmethod
def log_prob(
self,
sample: torch.Tensor,
):
"""Get the log-probability of the sample under this distribution.
Parameters
----------
sample : Tensor, shape [num_samples, self.D()]
Sample to get the log-probability for.
Returns
-------
log_prob : Tensor, shape [num_samples]
Log-probability of `sample`.
"""

def check_sample_input(
self,
sample: torch.Tensor,
):
"""Check the shape and dtype of the sample.
Used in:
self.log_prob().
"""
if sample.dtype != torch.float32:
raise ValueError(f"`sample` dtype should be float32, got {sample.dtype}.")
if sample.dim() != 2 or sample.shape[1] != self.D:
raise ValueError(f"`sample` should have shape [N, {self.D}], got {sample.shape}.")
11 changes: 11 additions & 0 deletions sparsecoding/priors/l0.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,14 @@ def sample(
weights[active_weight_idxs] += 1.

return weights

def log_prob(
self,
sample: torch.Tensor,
):
super().check_sample_input(sample)

l0_norm = torch.sum(sample != 0., dim=1).type(torch.long) # [num_samples]
log_prob = torch.log(self.prob_distr[l0_norm - 1])
log_prob[l0_norm == 0] = -torch.inf
return log_prob
35 changes: 35 additions & 0 deletions sparsecoding/priors/spike_slab.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,38 @@ def sample(self, num_samples: int):
)

return weights

def log_prob(
self,
sample: torch.Tensor,
):
super().check_sample_input(sample)

N = sample.shape[0]

log_prob = torch.zeros((N, self.D), dtype=torch.float32)

spike_mask = sample == 0.
slab_mask = sample != 0.

# Add log-probability for spike.
log_prob[spike_mask] = torch.log(torch.tensor(self.p_spike))

# Add log-probability for slab.
if self.positive_only:
log_prob[slab_mask] = (
torch.log(torch.tensor(1. - self.p_spike))
- torch.log(torch.tensor(self.scale))
- sample[slab_mask] / self.scale
)
log_prob[sample < 0.] = -torch.inf
else:
log_prob[slab_mask] = (
torch.log(torch.tensor(1. - self.p_spike))
- torch.log(torch.tensor(2. * self.scale))
- torch.abs(sample[slab_mask]) / self.scale
)

log_prob = torch.sum(log_prob, dim=1) # [N]

return log_prob
28 changes: 27 additions & 1 deletion tests/priors/test_l0.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from itertools import product

import torch
import unittest

from sparsecoding.priors.l0 import L0Prior


class TestL0Prior(unittest.TestCase):
def test_l0_prior(self):
def test_sample(self):
N = 10000
prob_distr = torch.tensor([0.5, 0.25, 0, 0.25])

Expand Down Expand Up @@ -36,6 +38,30 @@ def test_l0_prior(self):
atol=1e-2,
)

def test_log_prob(self):
prob_distr = torch.tensor([0.75, 0.25, 0.])

l0_prior = L0Prior(prob_distr)

samples = list(product([0, 1], repeat=3)) # [2**D, D]
samples = torch.tensor(samples, dtype=torch.float32) # [2**D, D]

log_probs = l0_prior.log_prob(samples)

# The l0-norm at index `i`
# is the number of ones
# in the binary representation of `i`.
assert log_probs[0] == -torch.inf
assert torch.allclose(
log_probs[[1, 2, 4]],
torch.log(torch.tensor(0.75)),
)
assert torch.allclose(
log_probs[[3, 5, 6]],
torch.log(torch.tensor(0.25)),
)
assert log_probs[7] == -torch.inf


if __name__ == "__main__":
unittest.main()
39 changes: 38 additions & 1 deletion tests/priors/test_spike_slab.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class TestSpikeSlabPrior(unittest.TestCase):
def test_spike_slab_prior(self):
def test_sample(self):
N = 10000
D = 4
p_spike = 0.5
Expand Down Expand Up @@ -54,6 +54,43 @@ def test_spike_slab_prior(self):
atol=1e-2,
)

def test_log_prob(self):
D = 3
p_spike = 0.5
scale = 1.

for positive_only in [True, False]:
spike_slab_prior = SpikeSlabPrior(
D,
p_spike,
scale,
positive_only,
)

samples = torch.Tensor([[-1., 0., 1.]])

if positive_only:
assert spike_slab_prior.log_prob(samples)[0] == -torch.inf

samples = torch.abs(samples)
assert torch.allclose(
spike_slab_prior.log_prob(samples)[0],
(
-1. + torch.log(torch.tensor(1. - p_spike))
+ torch.log(torch.tensor(p_spike))
- 1. + torch.log(torch.tensor(1. - p_spike))
)
)
else:
assert torch.allclose(
spike_slab_prior.log_prob(samples)[0],
(
-1. + torch.log(torch.tensor(1. - p_spike)) - torch.log(torch.tensor(2.))
+ torch.log(torch.tensor(p_spike))
- 1. + torch.log(torch.tensor(1. - p_spike)) - torch.log(torch.tensor(2.))
)
)


if __name__ == "__main__":
unittest.main()

0 comments on commit cbc1cb1

Please sign in to comment.