Commit: formatting
dpaiton committed Dec 31, 2024
1 parent 7547e7e commit 6678a5e
Showing 26 changed files with 245 additions and 137 deletions.
25 changes: 20 additions & 5 deletions conftest.py
@@ -1,9 +1,24 @@
from sparsecoding.test_utils import (bars_datas_fixture, bars_datasets_fixture,
bars_dictionary_fixture,
dataset_size_fixture, patch_size_fixture,
priors_fixture)
import torch

from sparsecoding.test_utils import (
bars_datas_fixture,
bars_datasets_fixture,
bars_dictionary_fixture,
dataset_size_fixture,
patch_size_fixture,
priors_fixture,
)

torch.manual_seed(1997)

# We import and define all fixtures in this file.
# This allows users to avoid any dependency fixtures.
# NOTE: This means pytest should only be run from this directory.
__all__ = ['dataset_size_fixture', 'patch_size_fixture', 'bars_datas_fixture', 'bars_datasets_fixture', 'bars_dictionary_fixture', 'priors_fixture']
__all__ = [
"dataset_size_fixture",
"patch_size_fixture",
"bars_datas_fixture",
"bars_datasets_fixture",
"bars_dictionary_fixture",
"priors_fixture",
]
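
For context, pytest resolves these fixtures by matching test argument names, so test modules in this directory never import them directly. A minimal sketch of a consuming test (the test itself is hypothetical; only the fixture names come from the list above, and the batch-dim assumption is inferred from the shape tests further down):

import torch


def test_batch_size(dataset_size_fixture: int, bars_datas_fixture: list[torch.Tensor]):
    # pytest injects each fixture by name; no fixture imports are needed here.
    # Assumption: each dataset tensor is batched along dim 0.
    for data in bars_datas_fixture:
        assert data.shape[0] == dataset_size_fixture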
20 changes: 10 additions & 10 deletions sparsecoding/inference/__init__.py
@@ -9,13 +9,13 @@
from .vanilla import Vanilla

__all__ = [
'IHT',
'InferenceMethod',
'ISTA',
'LCA',
'LSM',
'MP',
'OMP',
'PyTorchOptimizer',
'Vanilla'
]
"IHT",
"InferenceMethod",
"ISTA",
"LCA",
"LSM",
"MP",
"OMP",
"PyTorchOptimizer",
"Vanilla",
]
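
For reference, a hypothetical usage snippet for the API exported here (shapes assumed to be [batch, n_features] for data and [n_features, n_basis] for the dictionary, consistent with the methods in this package):

import torch

from sparsecoding.inference import ISTA

# Toy problem: 4 samples of 16-dimensional data, 8 dictionary elements.
data = torch.randn(4, 16)
dictionary = torch.randn(16, 8)

# n_iter and sparsity_penalty are the defaults from ISTA.__init__ below.
ista = ISTA(n_iter=100, sparsity_penalty=1e-2)
coefficients = ista.infer(data, dictionary)  # expected shape: [4, 8]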
14 changes: 6 additions & 8 deletions sparsecoding/inference/iht.py
@@ -12,7 +12,7 @@ class IHT(InferenceMethod):
"""

def __init__(self, sparsity, n_iter=10, solver=None, return_all_coefficients=False):
'''
"""
Parameters
----------
@@ -27,7 +27,7 @@ def __init__(self, sparsity, n_iter=10, solver=None, return_all_coefficients=False):
can result in large memory usage/potential exhaustion. This option is typically used for
debugging.
solver : default=None
'''
"""
super().__init__(solver)
self.n_iter = n_iter
self.sparsity = sparsity
@@ -54,11 +54,10 @@ def infer(self, data, dictionary):
device = dictionary.device

# Define signal sparsity
K = np.ceil(self.sparsity*n_basis).astype(int)
K = np.ceil(self.sparsity * n_basis).astype(int)

# Initialize coefficients for the whole batch
coefficients = torch.zeros(
batch_size, n_basis, requires_grad=False, device=device)
coefficients = torch.zeros(batch_size, n_basis, requires_grad=False, device=device)

for _ in range(self.n_iter):
# Compute the prediction given the current coefficients
@@ -75,9 +74,8 @@ def infer(self, data, dictionary):
topK_values, indices = torch.topk(torch.abs(coefficients), K, dim=1)

# Reconstruct coefficients using the output of torch.topk
coefficients = (
torch.sign(coefficients)
* torch.zeros(batch_size, n_basis, device=device).scatter_(1, indices, topK_values)
coefficients = torch.sign(coefficients) * torch.zeros(batch_size, n_basis, device=device).scatter_(
1, indices, topK_values
)

return coefficients.detach()
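
The reformatted expression above is IHT's hard-thresholding projection: keep the K largest-magnitude coefficients, restore their signs, and zero the rest. A standalone sketch of just that step (shapes assumed to be [batch_size, n_basis]):

import torch


def project_top_k(coefficients: torch.Tensor, k: int) -> torch.Tensor:
    # Find the k largest magnitudes per row and where they live.
    top_k_values, indices = torch.topk(torch.abs(coefficients), k, dim=1)
    # Scatter those magnitudes into a zero tensor, then restore the signs.
    support = torch.zeros_like(coefficients).scatter_(1, indices, top_k_values)
    return torch.sign(coefficients) * support


x = torch.tensor([[0.5, -2.0, 0.1, 1.5]])
print(project_top_k(x, 2))  # tensor([[ 0.0000, -2.0000,  0.0000,  1.5000]])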
2 changes: 1 addition & 1 deletion sparsecoding/inference/inference_method.py
@@ -70,4 +70,4 @@ def checknan(data=torch.tensor(0), name="data"):
If a NaN is found in data.
"""
if torch.isnan(data).any():
raise ValueError("InferenceMethod error: nan in %s." % (name))
raise ValueError("InferenceMethod error: nan in %s." % (name))
29 changes: 18 additions & 11 deletions sparsecoding/inference/ista.py
@@ -4,8 +4,15 @@


class ISTA(InferenceMethod):
def __init__(self, n_iter=100, sparsity_penalty=1e-2, stop_early=False,
epsilon=1e-2, solver=None, return_all_coefficients=False):
def __init__(
self,
n_iter=100,
sparsity_penalty=1e-2,
stop_early=False,
epsilon=1e-2,
solver=None,
return_all_coefficients=False,
):
"""Iterative shrinkage-thresholding algorithm for solving LASSO problems.
Parameters
@@ -52,8 +59,8 @@ def threshold_nonlinearity(self, u):
a : array-like, shape [batch_size, n_basis]
activations
"""
a = (torch.abs(u) - self.threshold).clamp(min=0.)
a = torch.sign(u)*a
a = (torch.abs(u) - self.threshold).clamp(min=0.0)
a = torch.sign(u) * a
return a

def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
@@ -85,9 +92,8 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False):

# Calculate stepsize based on largest eigenvalue of
# dictionary.T @ dictionary.
lipschitz_constant = torch.linalg.eigvalsh(
torch.mm(dictionary.T, dictionary))[-1]
stepsize = 1. / lipschitz_constant
lipschitz_constant = torch.linalg.eigvalsh(torch.mm(dictionary.T, dictionary))[-1]
stepsize = 1.0 / lipschitz_constant
self.threshold = stepsize * self.sparsity_penalty

# Initialize coefficients.
@@ -104,16 +110,17 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
old_u = u.clone().detach()

if self.return_all_coefficients:
coefficients = torch.concat([coefficients,
self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1)
coefficients = torch.concat(
[coefficients, self.threshold_nonlinearity(u).clone().unsqueeze(1)],
dim=1,
)

u -= stepsize * torch.mm(residual, dictionary)
self.coefficients = self.threshold_nonlinearity(u)

if self.stop_early:
# Stopping condition is function of change of the coefficients.
a_change = torch.mean(
torch.abs(old_u - u) / stepsize)
a_change = torch.mean(torch.abs(old_u - u) / stepsize)
if a_change < self.epsilon:
break

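Taken together, the ISTA hunks above implement the standard loop: a gradient step on the reconstruction residual followed by soft thresholding, with step size 1/L where L is the largest eigenvalue of dictionary.T @ dictionary. A condensed sketch (the residual computation is assumed from context; it is not shown in this diff):

import torch


def soft_threshold(u: torch.Tensor, threshold: float) -> torch.Tensor:
    # Shrink magnitudes by `threshold`, clip at zero, and restore signs.
    return torch.sign(u) * (torch.abs(u) - threshold).clamp(min=0.0)


def ista(data, dictionary, n_iter=100, sparsity_penalty=1e-2):
    # Step size from the Lipschitz constant of the data-fidelity gradient.
    lipschitz_constant = torch.linalg.eigvalsh(dictionary.T @ dictionary)[-1]
    stepsize = 1.0 / lipschitz_constant
    threshold = stepsize * sparsity_penalty

    u = torch.zeros(data.shape[0], dictionary.shape[1])
    for _ in range(n_iter):
        a = soft_threshold(u, threshold)
        residual = a @ dictionary.T - data  # assumed reconstruction error
        u = u - stepsize * (residual @ dictionary)
    return soft_threshold(u, threshold)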
15 changes: 13 additions & 2 deletions sparsecoding/inference/ista_test.py
@@ -5,7 +5,13 @@
from sparsecoding.test_utils import assert_allclose, assert_shape_equal


def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):
def test_shape(
patch_size_fixture: int,
dataset_size_fixture: int,
bars_dictionary_fixture: torch.Tensor,
bars_datas_fixture: list[torch.Tensor],
bars_datasets_fixture: list[BarsDataset],
):
"""Test that ISTA inference returns expected shapes."""
N_ITER = 10

@@ -18,7 +24,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):
a = inference_method.infer(data, bars_dictionary_fixture)
assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture)

def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):

def test_inference(
bars_dictionary_fixture: torch.Tensor,
bars_datas_fixture: list[torch.Tensor],
bars_datasets_fixture: list[BarsDataset],
):
"""Test that ISTA inference recovers the correct weights."""
N_ITER = 5000
for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture):
45 changes: 29 additions & 16 deletions sparsecoding/inference/lca.py
@@ -4,9 +4,17 @@


class LCA(InferenceMethod):
def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=0.1,
stop_early=False, epsilon=1e-2, solver=None,
return_all_coefficients="none", nonnegative=False):
def __init__(
self,
n_iter=100,
coeff_lr=1e-3,
threshold=0.1,
stop_early=False,
epsilon=1e-2,
solver=None,
return_all_coefficients="none",
nonnegative=False,
):
"""Method implemented according locally competative algorithm (LCA)
with the ideal soft thresholding function.
@@ -48,8 +56,9 @@ def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=0.1,
self.n_iter = n_iter
self.nonnegative = nonnegative
if return_all_coefficients not in ["none", "membrane", "active"]:
raise ValueError("Invalid input for return_all_coefficients. Valid"
"inputs are: \"none\", \"membrane\", \"active\".")
raise ValueError(
'Invalid input for return_all_coefficients. Valid inputs are: "none", "membrane", "active".'
)
self.return_all_coefficients = return_all_coefficients

def threshold_nonlinearity(self, u):
@@ -66,10 +75,10 @@ def threshold_nonlinearity(self, u):
Activations
"""
if self.nonnegative:
a = (u - self.threshold).clamp(min=0.)
a = (u - self.threshold).clamp(min=0.0)
else:
a = (torch.abs(u) - self.threshold).clamp(min=0.)
a = torch.sign(u)*a
a = (torch.abs(u) - self.threshold).clamp(min=0.0)
a = torch.sign(u) * a
return a

def grad(self, b, G, u, a):
@@ -89,7 +98,7 @@ def grad(self, b, G, u, a):
du : array-like, shape [batch_size, n_coefficients]
Gradient of membrane potentials
"""
du = b-u-(G@a.t()).t()
du = b - u - (G @ a.t()).t()
return du

def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
@@ -127,8 +136,8 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False):

coefficients = torch.zeros((batch_size, 0, n_basis)).to(device)

b = (dictionary.t()@data.t()).t()
G = dictionary.t()@dictionary-torch.eye(n_basis).to(device)
b = (dictionary.t() @ data.t()).t()
G = dictionary.t() @ dictionary - torch.eye(n_basis).to(device)
for i in range(self.n_iter):
# store old membrane potentials to evaluate the early-stopping condition
if self.stop_early:
@@ -138,19 +147,23 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
if self.return_all_coefficients != "none":
if self.return_all_coefficients == "active":
coefficients = torch.concat(
[coefficients, self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1)
[
coefficients,
self.threshold_nonlinearity(u).clone().unsqueeze(1),
],
dim=1,
)
else:
coefficients = torch.concat(
[coefficients, u.clone().unsqueeze(1)], dim=1)
coefficients = torch.concat([coefficients, u.clone().unsqueeze(1)], dim=1)

# compute new activations and update the membrane potentials
a = self.threshold_nonlinearity(u)
du = self.grad(b, G, u, a)
u = u + self.coeff_lr*du
u = u + self.coeff_lr * du

# check stopping condition
if self.stop_early:
relative_change_in_coeff = torch.linalg.norm(old_u - u)/torch.linalg.norm(old_u)
relative_change_in_coeff = torch.linalg.norm(old_u - u) / torch.linalg.norm(old_u)
if relative_change_in_coeff < self.epsilon:
break

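For reference, the reformatted lines above are the whole of the LCA dynamics: membrane potentials u are driven toward the feedforward input b = dictionary.T @ data while thresholded (active) units inhibit each other through G = dictionary.T @ dictionary - I. A condensed sketch with the same update rule:

import torch


def lca_infer(data, dictionary, n_iter=100, coeff_lr=1e-3, threshold=0.1):
    batch_size = data.shape[0]
    n_basis = dictionary.shape[1]
    u = torch.zeros(batch_size, n_basis)  # membrane potentials

    b = (dictionary.t() @ data.t()).t()  # feedforward drive
    G = dictionary.t() @ dictionary - torch.eye(n_basis)  # lateral inhibition
    for _ in range(n_iter):
        # Soft-threshold the membrane potentials to get activations.
        a = torch.sign(u) * (torch.abs(u) - threshold).clamp(min=0.0)
        du = b - u - (G @ a.t()).t()  # relax toward b minus competition
        u = u + coeff_lr * du
    return torch.sign(u) * (torch.abs(u) - threshold).clamp(min=0.0)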
17 changes: 14 additions & 3 deletions sparsecoding/inference/lca_test.py
@@ -5,7 +5,13 @@
from sparsecoding.test_utils import assert_allclose, assert_shape_equal


def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):
def test_shape(
patch_size_fixture: int,
dataset_size_fixture: int,
bars_dictionary_fixture: torch.Tensor,
bars_datas_fixture: list[torch.Tensor],
bars_datasets_fixture: list[BarsDataset],
):
"""
Test that LCA inference returns expected shapes.
"""
@@ -21,7 +27,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):
a = inference_method.infer(data, bars_dictionary_fixture)
assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture)

def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):

def test_inference(
bars_dictionary_fixture: torch.Tensor,
bars_datas_fixture: list[torch.Tensor],
bars_datasets_fixture: list[BarsDataset],
):
"""
Test that LCA inference recovers the correct weights.
"""
@@ -38,4 +49,4 @@ def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):

a = inference_method.infer(data, bars_dictionary_fixture)

assert_allclose(a, dataset.weights, atol=5e-2)
assert_allclose(a, dataset.weights, atol=5e-2)
29 changes: 15 additions & 14 deletions sparsecoding/inference/lsm.py
@@ -4,9 +4,17 @@


class LSM(InferenceMethod):
def __init__(self, n_iter=100, n_iter_LSM=6, beta=0.01, alpha=80.0,
sigma=0.005, sparse_threshold=10**-2, solver=None,
return_all_coefficients=False):
def __init__(
self,
n_iter=100,
n_iter_LSM=6,
beta=0.01,
alpha=80.0,
sigma=0.005,
sparse_threshold=10**-2,
solver=None,
return_all_coefficients=False,
):
"""Infer latent coefficients generating data given dictionary.
Method implemented according to "Group Sparse Coding with a Laplacian
Scale Mixture Prior" (P. J. Garrigues & B. A. Olshausen, 2010)
@@ -73,7 +81,7 @@ def lsm_Loss(self, data, dictionary, coefficients, lambdas, sigma):

# Compute loss
preds = torch.mm(dictionary, coefficients.t()).t()
mse_loss = (1/(2*(sigma**2))) * torch.sum(torch.square(data - preds), dim=1, keepdim=True)
mse_loss = (1 / (2 * (sigma**2))) * torch.sum(torch.square(data - preds), dim=1, keepdim=True)
sparse_loss = torch.sum(lambdas * torch.abs(coefficients), dim=1, keepdim=True)
loss = mse_loss + sparse_loss
return loss
@@ -107,10 +115,7 @@ def infer(self, data, dictionary):
# Outer loop, set sparsity penalties (lambdas).
for i in range(self.n_iter_LSM):
# Compute the initial values of lambdas
lambdas = (
(self.alpha + 1)
/ (self.beta + torch.abs(coefficients.detach()))
)
lambdas = (self.alpha + 1) / (self.beta + torch.abs(coefficients.detach()))

# Inner loop, optimize coefficients w/ current sparsity penalties.
# Exits early if converged before `n_iter`s.
@@ -132,16 +137,12 @@ def infer(self, data, dictionary):
optimizer.step()

# Break if coefficients have converged.
if (
last_loss is not None
and loss > 1.05 * last_loss
):
if last_loss is not None and loss > 1.05 * last_loss:
break

last_loss = loss

# Sparsify the final solution by discarding the small coefficients
coefficients.data[torch.abs(coefficients.data)
< self.sparse_threshold] = 0
coefficients.data[torch.abs(coefficients.data) < self.sparse_threshold] = 0

return coefficients.detach()
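
The structure above is iteratively reweighted L1 minimization: the outer loop recomputes per-coefficient penalties lambda = (alpha + 1) / (beta + |a|) from the current solution, so strong coefficients are penalized less in the next inner optimization. A small sketch of just the reweighting step (alpha and beta defaults as in __init__ above):

import torch


def update_lambdas(coefficients: torch.Tensor, alpha=80.0, beta=0.01) -> torch.Tensor:
    # Larger |coefficient| -> smaller lambda -> weaker L1 penalty next round.
    return (alpha + 1) / (beta + torch.abs(coefficients.detach()))


a = torch.tensor([[0.0, 0.5, 2.0]])
print(update_lambdas(a))  # tensor([[8100.0000,  158.8235,   40.2985]])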