From 6678a5e2131673b377e6af108499004a8625f0e9 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:27:54 -0800 Subject: [PATCH] formatting --- conftest.py | 25 ++++++++--- sparsecoding/inference/__init__.py | 20 ++++----- sparsecoding/inference/iht.py | 14 +++--- sparsecoding/inference/inference_method.py | 2 +- sparsecoding/inference/ista.py | 29 +++++++----- sparsecoding/inference/ista_test.py | 15 ++++++- sparsecoding/inference/lca.py | 45 ++++++++++++------- sparsecoding/inference/lca_test.py | 17 +++++-- sparsecoding/inference/lsm.py | 29 ++++++------ sparsecoding/inference/lsm_test.py | 15 ++++++- sparsecoding/inference/mp.py | 9 ++-- sparsecoding/inference/omp.py | 9 ++-- .../inference/pytorch_optimizer_test.py | 26 ++++++++--- sparsecoding/inference/vanilla.py | 26 ++++++----- sparsecoding/inference/vanilla_test.py | 15 ++++++- sparsecoding/priors/__init__.py | 8 ++-- sparsecoding/priors/l0_prior.py | 12 ++--- sparsecoding/priors/prior.py | 1 + sparsecoding/priors/spike_slab_prior_test.py | 20 ++++----- sparsecoding/test_utils/__init__.py | 11 ++++- sparsecoding/test_utils/asserts.py | 4 +- sparsecoding/test_utils/asserts_test.py | 4 +- sparsecoding/test_utils/constant_fixtures.py | 6 ++- sparsecoding/test_utils/dataset_fixtures.py | 12 +++-- sparsecoding/test_utils/model_fixtures.py | 6 +-- sparsecoding/test_utils/prior_fixtures.py | 2 +- 26 files changed, 245 insertions(+), 137 deletions(-) diff --git a/conftest.py b/conftest.py index 18ac5fd..31b4f34 100644 --- a/conftest.py +++ b/conftest.py @@ -1,9 +1,24 @@ -from sparsecoding.test_utils import (bars_datas_fixture, bars_datasets_fixture, - bars_dictionary_fixture, - dataset_size_fixture, patch_size_fixture, - priors_fixture) +import torch + +from sparsecoding.test_utils import ( + bars_datas_fixture, + bars_datasets_fixture, + bars_dictionary_fixture, + dataset_size_fixture, + patch_size_fixture, + priors_fixture, +) + +torch.manual_seed(1997) # We import and define all fixtures in this file. # This allows users to avoid any dependency fixtures. # NOTE: This means pytest should only be run from this directory. -__all__ = ['dataset_size_fixture', 'patch_size_fixture', 'bars_datas_fixture', 'bars_datasets_fixture', 'bars_dictionary_fixture', 'priors_fixture'] +__all__ = [ + "dataset_size_fixture", + "patch_size_fixture", + "bars_datas_fixture", + "bars_datasets_fixture", + "bars_dictionary_fixture", + "priors_fixture", +] diff --git a/sparsecoding/inference/__init__.py b/sparsecoding/inference/__init__.py index 0cfa46d..8ac1dec 100644 --- a/sparsecoding/inference/__init__.py +++ b/sparsecoding/inference/__init__.py @@ -9,13 +9,13 @@ from .vanilla import Vanilla __all__ = [ - 'IHT', - 'InferenceMethod', - 'ISTA', - 'LCA', - 'LSM', - 'MP', - 'OMP', - 'PyTorchOptimizer', - 'Vanilla' -] \ No newline at end of file + "IHT", + "InferenceMethod", + "ISTA", + "LCA", + "LSM", + "MP", + "OMP", + "PyTorchOptimizer", + "Vanilla", +] diff --git a/sparsecoding/inference/iht.py b/sparsecoding/inference/iht.py index e62dacd..467cfc5 100644 --- a/sparsecoding/inference/iht.py +++ b/sparsecoding/inference/iht.py @@ -12,7 +12,7 @@ class IHT(InferenceMethod): """ def __init__(self, sparsity, n_iter=10, solver=None, return_all_coefficients=False): - ''' + """ Parameters ---------- @@ -27,7 +27,7 @@ def __init__(self, sparsity, n_iter=10, solver=None, return_all_coefficients=Fal can result in large memory usage/potential exhaustion. 
            This function is typically used for debugging
        solver : default=None
-        '''
+        """
         super().__init__(solver)
         self.n_iter = n_iter
         self.sparsity = sparsity
@@ -54,11 +54,10 @@ def infer(self, data, dictionary):
         device = dictionary.device
 
         # Define signal sparsity
-        K = np.ceil(self.sparsity*n_basis).astype(int)
+        K = np.ceil(self.sparsity * n_basis).astype(int)
 
         # Initialize coefficients for the whole batch
-        coefficients = torch.zeros(
-            batch_size, n_basis, requires_grad=False, device=device)
+        coefficients = torch.zeros(batch_size, n_basis, requires_grad=False, device=device)
 
         for _ in range(self.n_iter):
             # Compute the prediction given the current coefficients
@@ -75,9 +74,8 @@ def infer(self, data, dictionary):
             topK_values, indices = torch.topk(torch.abs(coefficients), K, dim=1)
 
             # Reconstruct coefficients using the output of torch.topk
-            coefficients = (
-                torch.sign(coefficients)
-                * torch.zeros(batch_size, n_basis, device=device).scatter_(1, indices, topK_values)
+            coefficients = torch.sign(coefficients) * torch.zeros(batch_size, n_basis, device=device).scatter_(
+                1, indices, topK_values
             )
 
         return coefficients.detach()
diff --git a/sparsecoding/inference/inference_method.py b/sparsecoding/inference/inference_method.py
index 98ff0b8..62c97ec 100644
--- a/sparsecoding/inference/inference_method.py
+++ b/sparsecoding/inference/inference_method.py
@@ -70,4 +70,4 @@ def checknan(data=torch.tensor(0), name="data"):
         If nan is found in data
     """
     if torch.isnan(data).any():
-        raise ValueError("InferenceMethod error: nan in %s." % (name))
\ No newline at end of file
+        raise ValueError("InferenceMethod error: nan in %s." % (name))
diff --git a/sparsecoding/inference/ista.py b/sparsecoding/inference/ista.py
index 5c1abfc..42bd21b 100644
--- a/sparsecoding/inference/ista.py
+++ b/sparsecoding/inference/ista.py
@@ -4,8 +4,15 @@
 
 
 class ISTA(InferenceMethod):
-    def __init__(self, n_iter=100, sparsity_penalty=1e-2, stop_early=False,
-                 epsilon=1e-2, solver=None, return_all_coefficients=False):
+    def __init__(
+        self,
+        n_iter=100,
+        sparsity_penalty=1e-2,
+        stop_early=False,
+        epsilon=1e-2,
+        solver=None,
+        return_all_coefficients=False,
+    ):
         """Iterative shrinkage-thresholding algorithm for solving LASSO problems.
 
         Parameters
@@ -52,8 +59,8 @@ def threshold_nonlinearity(self, u):
         a : array-like, shape [batch_size, n_basis]
             activations
         """
-        a = (torch.abs(u) - self.threshold).clamp(min=0.)
-        a = torch.sign(u)*a
+        a = (torch.abs(u) - self.threshold).clamp(min=0.0)
+        a = torch.sign(u) * a
         return a
 
     def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
@@ -85,9 +92,8 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
 
         # Calculate stepsize based on largest eigenvalue of
         # dictionary.T @ dictionary.
-        lipschitz_constant = torch.linalg.eigvalsh(
-            torch.mm(dictionary.T, dictionary))[-1]
-        stepsize = 1. / lipschitz_constant
+        lipschitz_constant = torch.linalg.eigvalsh(torch.mm(dictionary.T, dictionary))[-1]
+        stepsize = 1.0 / lipschitz_constant
         self.threshold = stepsize * self.sparsity_penalty
 
         # Initialize coefficients. 
@@ -104,16 +110,17 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
             old_u = u.clone().detach()
 
             if self.return_all_coefficients:
-                coefficients = torch.concat([coefficients,
-                                             self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1)
+                coefficients = torch.concat(
+                    [coefficients, self.threshold_nonlinearity(u).clone().unsqueeze(1)],
+                    dim=1,
+                )
 
             u -= stepsize * torch.mm(residual, dictionary)
             self.coefficients = self.threshold_nonlinearity(u)
 
             if self.stop_early:
                 # Stopping condition is function of change of the coefficients.
-                a_change = torch.mean(
-                    torch.abs(old_u - u) / stepsize)
+                a_change = torch.mean(torch.abs(old_u - u) / stepsize)
                 if a_change < self.epsilon:
                     break
 
diff --git a/sparsecoding/inference/ista_test.py b/sparsecoding/inference/ista_test.py
index 6a99f5a..397c555 100644
--- a/sparsecoding/inference/ista_test.py
+++ b/sparsecoding/inference/ista_test.py
@@ -5,7 +5,13 @@
 from sparsecoding.test_utils import assert_allclose, assert_shape_equal
 
 
-def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):
+def test_shape(
+    patch_size_fixture: int,
+    dataset_size_fixture: int,
+    bars_dictionary_fixture: torch.Tensor,
+    bars_datas_fixture: list[torch.Tensor],
+    bars_datasets_fixture: list[BarsDataset],
+):
     """Test that ISTA inference returns expected shapes."""
     N_ITER = 10
 
@@ -18,7 +24,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictiona
     a = inference_method.infer(data, bars_dictionary_fixture)
     assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture)
 
-def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):
+
+def test_inference(
+    bars_dictionary_fixture: torch.Tensor,
+    bars_datas_fixture: list[torch.Tensor],
+    bars_datasets_fixture: list[BarsDataset],
+):
     """Test that ISTA inference recovers the correct weights."""
     N_ITER = 5000
     for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture):
diff --git a/sparsecoding/inference/lca.py b/sparsecoding/inference/lca.py
index 49ef1a5..ad5023f 100644
--- a/sparsecoding/inference/lca.py
+++ b/sparsecoding/inference/lca.py
@@ -4,9 +4,17 @@
 
 
 class LCA(InferenceMethod):
-    def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=0.1,
-                 stop_early=False, epsilon=1e-2, solver=None,
-                 return_all_coefficients="none", nonnegative=False):
+    def __init__(
+        self,
+        n_iter=100,
+        coeff_lr=1e-3,
+        threshold=0.1,
+        stop_early=False,
+        epsilon=1e-2,
+        solver=None,
+        return_all_coefficients="none",
+        nonnegative=False,
+    ):
         """Method implemented according to the locally competitive algorithm
         (LCA) with the ideal soft thresholding function.
 
@@ -48,8 +56,9 @@ def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=0.1,
         self.n_iter = n_iter
         self.nonnegative = nonnegative
         if return_all_coefficients not in ["none", "membrane", "active"]:
-            raise ValueError("Invalid input for return_all_coefficients. Valid"
-                             "inputs are: \"none\", \"membrane\", \"active\".")
+            raise ValueError(
+                "Invalid input for return_all_coefficients. Valid " 'inputs are: "none", "membrane", "active".'
+            )
         self.return_all_coefficients = return_all_coefficients
 
     def threshold_nonlinearity(self, u):
@@ -66,10 +75,10 @@ def threshold_nonlinearity(self, u):
             Activations
         """
         if self.nonnegative:
-            a = (u - self.threshold).clamp(min=0.)
+            a = (u - self.threshold).clamp(min=0.0)
         else:
-            a = (torch.abs(u) - self.threshold).clamp(min=0.)
-            a = torch.sign(u)*a
+            a = (torch.abs(u) - self.threshold).clamp(min=0.0)
+            a = torch.sign(u) * a
         return a
 
     def grad(self, b, G, u, a):
@@ -89,7 +98,7 @@ def grad(self, b, G, u, a):
         du : array-like, shape [batch_size, n_coefficients]
             Gradient of membrane potentials
         """
-        du = b-u-(G@a.t()).t()
+        du = b - u - (G @ a.t()).t()
         return du
 
     def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
@@ -127,8 +136,8 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
 
         coefficients = torch.zeros((batch_size, 0, n_basis)).to(device)
 
-        b = (dictionary.t()@data.t()).t()
-        G = dictionary.t()@dictionary-torch.eye(n_basis).to(device)
+        b = (dictionary.t() @ data.t()).t()
+        G = dictionary.t() @ dictionary - torch.eye(n_basis).to(device)
         for i in range(self.n_iter):
             # store old membrane potentials to evaluate stop early condition
             if self.stop_early:
@@ -138,19 +147,23 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
             if self.return_all_coefficients != "none":
                 if self.return_all_coefficients == "active":
                     coefficients = torch.concat(
-                        [coefficients, self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1)
+                        [
+                            coefficients,
+                            self.threshold_nonlinearity(u).clone().unsqueeze(1),
+                        ],
+                        dim=1,
+                    )
                 else:
-                    coefficients = torch.concat(
-                        [coefficients, u.clone().unsqueeze(1)], dim=1)
+                    coefficients = torch.concat([coefficients, u.clone().unsqueeze(1)], dim=1)
 
             # compute new
             a = self.threshold_nonlinearity(u)
             du = self.grad(b, G, u, a)
-            u = u + self.coeff_lr*du
+            u = u + self.coeff_lr * du
 
             # check stopping condition
             if self.stop_early:
-                relative_change_in_coeff = torch.linalg.norm(old_u - u)/torch.linalg.norm(old_u)
+                relative_change_in_coeff = torch.linalg.norm(old_u - u) / torch.linalg.norm(old_u)
                 if relative_change_in_coeff < self.epsilon:
                     break
 
diff --git a/sparsecoding/inference/lca_test.py b/sparsecoding/inference/lca_test.py
index 2fe955f..0834ddd 100644
--- a/sparsecoding/inference/lca_test.py
+++ b/sparsecoding/inference/lca_test.py
@@ -5,7 +5,13 @@
 from sparsecoding.test_utils import assert_allclose, assert_shape_equal
 
 
-def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):
+def test_shape(
+    patch_size_fixture: int,
+    dataset_size_fixture: int,
+    bars_dictionary_fixture: torch.Tensor,
+    bars_datas_fixture: list[torch.Tensor],
+    bars_datasets_fixture: list[BarsDataset],
+):
     """
     Test that LCA inference returns expected shapes.
     """
@@ -21,7 +27,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictiona
     a = inference_method.infer(data, bars_dictionary_fixture)
     assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture)
 
-def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):
+
+def test_inference(
+    bars_dictionary_fixture: torch.Tensor,
+    bars_datas_fixture: list[torch.Tensor],
+    bars_datasets_fixture: list[BarsDataset],
+):
     """
     Test that LCA inference recovers the correct weights.
""" @@ -38,4 +49,4 @@ def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: li a = inference_method.infer(data, bars_dictionary_fixture) - assert_allclose(a, dataset.weights, atol=5e-2) \ No newline at end of file + assert_allclose(a, dataset.weights, atol=5e-2) diff --git a/sparsecoding/inference/lsm.py b/sparsecoding/inference/lsm.py index c3faaf6..83e5ed5 100644 --- a/sparsecoding/inference/lsm.py +++ b/sparsecoding/inference/lsm.py @@ -4,9 +4,17 @@ class LSM(InferenceMethod): - def __init__(self, n_iter=100, n_iter_LSM=6, beta=0.01, alpha=80.0, - sigma=0.005, sparse_threshold=10**-2, solver=None, - return_all_coefficients=False): + def __init__( + self, + n_iter=100, + n_iter_LSM=6, + beta=0.01, + alpha=80.0, + sigma=0.005, + sparse_threshold=10**-2, + solver=None, + return_all_coefficients=False, + ): """Infer latent coefficients generating data given dictionary. Method implemented according to "Group Sparse Coding with a Laplacian Scale Mixture Prior" (P. J. Garrigues & B. A. Olshausen, 2010) @@ -73,7 +81,7 @@ def lsm_Loss(self, data, dictionary, coefficients, lambdas, sigma): # Compute loss preds = torch.mm(dictionary, coefficients.t()).t() - mse_loss = (1/(2*(sigma**2))) * torch.sum(torch.square(data - preds), dim=1, keepdim=True) + mse_loss = (1 / (2 * (sigma**2))) * torch.sum(torch.square(data - preds), dim=1, keepdim=True) sparse_loss = torch.sum(lambdas * torch.abs(coefficients), dim=1, keepdim=True) loss = mse_loss + sparse_loss return loss @@ -107,10 +115,7 @@ def infer(self, data, dictionary): # Outer loop, set sparsity penalties (lambdas). for i in range(self.n_iter_LSM): # Compute the initial values of lambdas - lambdas = ( - (self.alpha + 1) - / (self.beta + torch.abs(coefficients.detach())) - ) + lambdas = (self.alpha + 1) / (self.beta + torch.abs(coefficients.detach())) # Inner loop, optimize coefficients w/ current sparsity penalties. # Exits early if converged before `n_iter`s. @@ -132,16 +137,12 @@ def infer(self, data, dictionary): optimizer.step() # Break if coefficients have converged. - if ( - last_loss is not None - and loss > 1.05 * last_loss - ): + if last_loss is not None and loss > 1.05 * last_loss: break last_loss = loss # Sparsify the final solution by discarding the small coefficients - coefficients.data[torch.abs(coefficients.data) - < self.sparse_threshold] = 0 + coefficients.data[torch.abs(coefficients.data) < self.sparse_threshold] = 0 return coefficients.detach() diff --git a/sparsecoding/inference/lsm_test.py b/sparsecoding/inference/lsm_test.py index b7e4b05..6ac2155 100644 --- a/sparsecoding/inference/lsm_test.py +++ b/sparsecoding/inference/lsm_test.py @@ -5,7 +5,13 @@ from sparsecoding.test_utils import assert_allclose, assert_shape_equal -def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that LSM inference returns expected shapes. 
""" @@ -16,7 +22,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictiona a = inference_method.infer(data, bars_dictionary_fixture) assert_shape_equal(a, dataset.weights) -def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that LSM inference recovers the correct weights. """ diff --git a/sparsecoding/inference/mp.py b/sparsecoding/inference/mp.py index 2d413fc..4305976 100644 --- a/sparsecoding/inference/mp.py +++ b/sparsecoding/inference/mp.py @@ -12,7 +12,7 @@ class MP(InferenceMethod): """ def __init__(self, sparsity, solver=None, return_all_coefficients=False): - ''' + """ Parameters ---------- @@ -24,7 +24,7 @@ def __init__(self, sparsity, solver=None, return_all_coefficients=False): can result in large memory usage/potential exhaustion. This function typically used for debugging solver : default=None - ''' + """ super().__init__(solver) self.sparsity = sparsity self.return_all_coefficients = return_all_coefficients @@ -50,14 +50,13 @@ def infer(self, data, dictionary): device = dictionary.device # Define signal sparsity - K = np.ceil(self.sparsity*n_basis).astype(int) + K = np.ceil(self.sparsity * n_basis).astype(int) # Get dictionary norms in case atoms are not normalized dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) # Initialize coefficients for the whole batch - coefficients = torch.zeros( - batch_size, n_basis, requires_grad=False, device=device) + coefficients = torch.zeros(batch_size, n_basis, requires_grad=False, device=device) residual = data.clone() # [batch_size, n_features] diff --git a/sparsecoding/inference/omp.py b/sparsecoding/inference/omp.py index 33c49d7..db99427 100644 --- a/sparsecoding/inference/omp.py +++ b/sparsecoding/inference/omp.py @@ -13,7 +13,7 @@ class OMP(InferenceMethod): """ def __init__(self, sparsity, solver=None, return_all_coefficients=False): - ''' + """ Parameters ---------- @@ -25,7 +25,7 @@ def __init__(self, sparsity, solver=None, return_all_coefficients=False): can result in large memory usage/potential exhaustion. This function typically used for debugging solver : default=None - ''' + """ super().__init__(solver) self.sparsity = sparsity self.return_all_coefficients = return_all_coefficients @@ -51,14 +51,13 @@ def infer(self, data, dictionary): device = dictionary.device # Define signal sparsity - K = np.ceil(self.sparsity*n_basis).astype(int) + K = np.ceil(self.sparsity * n_basis).astype(int) # Get dictionary norms in case atoms are not normalized dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) # Initialize coefficients for the whole batch - coefficients = torch.zeros( - batch_size, n_basis, requires_grad=False, device=device) + coefficients = torch.zeros(batch_size, n_basis, requires_grad=False, device=device) residual = data.clone() # [batch_size, n_features] diff --git a/sparsecoding/inference/pytorch_optimizer_test.py b/sparsecoding/inference/pytorch_optimizer_test.py index f90fd66..e7322bc 100644 --- a/sparsecoding/inference/pytorch_optimizer_test.py +++ b/sparsecoding/inference/pytorch_optimizer_test.py @@ -10,22 +10,24 @@ def lasso_loss(data, dictionary, coefficients, sparsity_penalty): Generic MSE + l1-norm loss. 
""" batch_size = data.shape[0] - datahat = (dictionary@coefficients.t()).t() + datahat = (dictionary @ coefficients.t()).t() - mse_loss = torch.linalg.vector_norm(datahat-data, dim=1).square() + mse_loss = torch.linalg.vector_norm(datahat - data, dim=1).square() sparse_loss = torch.sum(torch.abs(coefficients), axis=1) - total_loss = (mse_loss + sparsity_penalty*sparse_loss)/batch_size + total_loss = (mse_loss + sparsity_penalty * sparse_loss) / batch_size return total_loss + def loss_fn(data, dictionary, coefficients): return lasso_loss( data, dictionary, coefficients, - sparsity_penalty=1., + sparsity_penalty=1.0, ) + def optimizer_fn(coefficients): return torch.optim.Adam( coefficients, @@ -35,7 +37,14 @@ def optimizer_fn(coefficients): weight_decay=0, ) -def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that PyTorchOptimizer inference returns expected shapes. """ @@ -48,7 +57,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictiona a = inference_method.infer(data, bars_dictionary_fixture) assert_shape_equal(a, dataset.weights) -def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that PyTorchOptimizer inference recovers the correct weights. """ diff --git a/sparsecoding/inference/vanilla.py b/sparsecoding/inference/vanilla.py index 473fcb8..11bbfa4 100644 --- a/sparsecoding/inference/vanilla.py +++ b/sparsecoding/inference/vanilla.py @@ -4,9 +4,16 @@ class Vanilla(InferenceMethod): - def __init__(self, n_iter=100, coeff_lr=1e-3, sparsity_penalty=0.2, - stop_early=False, epsilon=1e-2, solver=None, - return_all_coefficients=False): + def __init__( + self, + n_iter=100, + coeff_lr=1e-3, + sparsity_penalty=0.2, + stop_early=False, + epsilon=1e-2, + solver=None, + return_all_coefficients=False, + ): """Gradient descent with Euler's method on model in Olshausen & Field (1997) with laplace prior over coefficients (corresponding to l-1 norm penalty). 
@@ -61,8 +68,7 @@ def grad(self, residual, dictionary, a):
         da : array-like, shape [batch_size, n_coefficients]
             Gradient of the coefficients
         """
-        da = (dictionary.t()@residual.t()).t() - \
-            self.sparsity_penalty*torch.sign(a)
+        da = (dictionary.t() @ residual.t()).t() - self.sparsity_penalty * torch.sign(a)
         return da
 
     def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
@@ -96,11 +102,11 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
         if coeff_0 is not None:
             a = coeff_0.to(device)
         else:
-            a = torch.rand((batch_size, n_basis)).to(device)-0.5
+            a = torch.rand((batch_size, n_basis)).to(device) - 0.5
 
         coefficients = torch.zeros((batch_size, 0, n_basis)).to(device)
 
-        residual = data - (dictionary@a.t()).t()
+        residual = data - (dictionary @ a.t()).t()
 
         for i in range(self.n_iter):
             if self.return_all_coefficients:
@@ -110,13 +116,13 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
             old_a = a.clone().detach()
 
             da = self.grad(residual, dictionary, a)
-            a = a + self.coeff_lr*da
+            a = a + self.coeff_lr * da
 
             if self.stop_early:
-                if torch.linalg.norm(old_a - a)/torch.linalg.norm(old_a) < self.epsilon:
+                if torch.linalg.norm(old_a - a) / torch.linalg.norm(old_a) < self.epsilon:
                     break
 
-            residual = data - (dictionary@a.t()).t()
+            residual = data - (dictionary @ a.t()).t()
 
         if use_checknan:
             self.checknan(a, "coefficients")
diff --git a/sparsecoding/inference/vanilla_test.py b/sparsecoding/inference/vanilla_test.py
index 57c2244..9c556e5 100644
--- a/sparsecoding/inference/vanilla_test.py
+++ b/sparsecoding/inference/vanilla_test.py
@@ -5,7 +5,13 @@
 from sparsecoding.test_utils import assert_allclose, assert_shape_equal
 
 
-def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):
+def test_shape(
+    patch_size_fixture: int,
+    dataset_size_fixture: int,
+    bars_dictionary_fixture: torch.Tensor,
+    bars_datas_fixture: list[torch.Tensor],
+    bars_datasets_fixture: list[BarsDataset],
+):
     """
     Test that Vanilla inference returns expected shapes.
     """
@@ -20,7 +26,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictiona
     a = inference_method.infer(data, bars_dictionary_fixture)
     assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture)
 
-def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]):
+
+def test_inference(
+    bars_dictionary_fixture: torch.Tensor,
+    bars_datas_fixture: list[torch.Tensor],
+    bars_datasets_fixture: list[BarsDataset],
+):
     """
     Test that Vanilla inference recovers the correct weights.
""" diff --git a/sparsecoding/priors/__init__.py b/sparsecoding/priors/__init__.py index 94e9994..75314d7 100644 --- a/sparsecoding/priors/__init__.py +++ b/sparsecoding/priors/__init__.py @@ -3,7 +3,7 @@ from .spike_slab_prior import SpikeSlabPrior __all__ = [ - "Prior", - "L0Prior", - "SpikeSlabPrior", -] \ No newline at end of file + "Prior", + "L0Prior", + "SpikeSlabPrior", +] diff --git a/sparsecoding/priors/l0_prior.py b/sparsecoding/priors/l0_prior.py index 1e41853..590bcad 100644 --- a/sparsecoding/priors/l0_prior.py +++ b/sparsecoding/priors/l0_prior.py @@ -33,10 +33,7 @@ def __init__( def D(self): return self.prob_distr.shape[0] - def sample( - self, - num_samples: int - ): + def sample(self, num_samples: int): N = num_samples num_active_weights = 1 + torch.multinomial( @@ -46,10 +43,7 @@ def sample( ) # [N] d_idxs = torch.arange(self.D) - active_idx_mask = ( - d_idxs.reshape(1, self.D) - < num_active_weights.reshape(N, 1) - ) # [N, self.D] + active_idx_mask = d_idxs.reshape(1, self.D) < num_active_weights.reshape(N, 1) # [N, self.D] n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D) # [N, D] # Need to shuffle here so that it's not always the first weights that are active. @@ -60,6 +54,6 @@ def sample( active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask] weights = torch.zeros((N, self.D), dtype=torch.float32) - weights[active_weight_idxs] += 1. + weights[active_weight_idxs] += 1.0 return weights diff --git a/sparsecoding/priors/prior.py b/sparsecoding/priors/prior.py index 4940661..061c62f 100644 --- a/sparsecoding/priors/prior.py +++ b/sparsecoding/priors/prior.py @@ -9,6 +9,7 @@ class Prior(ABC): weights_dim : int Number of weights for each sample. """ + @abstractmethod def D(self): """ diff --git a/sparsecoding/priors/spike_slab_prior_test.py b/sparsecoding/priors/spike_slab_prior_test.py index f1023ac..ed601bc 100644 --- a/sparsecoding/priors/spike_slab_prior_test.py +++ b/sparsecoding/priors/spike_slab_prior_test.py @@ -9,11 +9,11 @@ def test_spike_slab_prior(positive_only: bool): N = 10000 D = 4 p_spike = 0.5 - scale = 1. + scale = 1.0 torch.manual_seed(1997) - p_slab = 1. - p_spike + p_slab = 1.0 - p_spike spike_slab_prior = SpikeSlabPrior( D, @@ -27,7 +27,7 @@ def test_spike_slab_prior(positive_only: bool): # Check spike probability. assert torch.allclose( - torch.sum(weights == 0.) / (N * D), + torch.sum(weights == 0.0) / (N * D), torch.tensor(p_spike), atol=1e-2, ) @@ -35,20 +35,20 @@ def test_spike_slab_prior(positive_only: bool): # Check Laplacian distribution. N_slab = p_slab * N * D if positive_only: - assert torch.sum(weights < 0.) == 0 + assert torch.sum(weights < 0.0) == 0 else: assert torch.allclose( - torch.sum(weights < 0.) / N_slab, - torch.sum(weights > 0.) / N_slab, + torch.sum(weights < 0.0) / N_slab, + torch.sum(weights > 0.0) / N_slab, atol=2e-2, ) weights = torch.abs(weights) - laplace_weights = weights[weights > 0.] - for quantile in torch.arange(5) / 5.: - cutoff = -torch.log(1. 
- quantile) + laplace_weights = weights[weights > 0.0] + for quantile in torch.arange(5) / 5.0: + cutoff = -torch.log(1.0 - quantile) assert torch.allclose( torch.sum(laplace_weights < cutoff) / N_slab, quantile, atol=1e-2, - ) \ No newline at end of file + ) diff --git a/sparsecoding/test_utils/__init__.py b/sparsecoding/test_utils/__init__.py index 7b8371b..8c18de1 100644 --- a/sparsecoding/test_utils/__init__.py +++ b/sparsecoding/test_utils/__init__.py @@ -4,4 +4,13 @@ from .model_fixtures import bars_dictionary_fixture from .prior_fixtures import priors_fixture -__all__ = ['assert_allclose', 'assert_shape_equal', 'dataset_size_fixture', 'patch_size_fixture', 'bars_datas_fixture', 'bars_datasets_fixture', 'bars_dictionary_fixture', 'priors_fixture'] +__all__ = [ + "assert_allclose", + "assert_shape_equal", + "dataset_size_fixture", + "patch_size_fixture", + "bars_datas_fixture", + "bars_datasets_fixture", + "bars_dictionary_fixture", + "priors_fixture", +] diff --git a/sparsecoding/test_utils/asserts.py b/sparsecoding/test_utils/asserts.py index bc7056a..64db496 100644 --- a/sparsecoding/test_utils/asserts.py +++ b/sparsecoding/test_utils/asserts.py @@ -4,8 +4,10 @@ DEFAULT_ATOL = 1e-6 DEFAULT_RTOL = 1e-5 + def assert_allclose(a: np.ndarray, b: np.ndarray, rtol: float = DEFAULT_RTOL, atol: float = DEFAULT_ATOL) -> None: return np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + def assert_shape_equal(a: np.ndarray, b: np.ndarray) -> None: - assert a.shape == b.shape \ No newline at end of file + assert a.shape == b.shape diff --git a/sparsecoding/test_utils/asserts_test.py b/sparsecoding/test_utils/asserts_test.py index 912b0ad..1af9705 100644 --- a/sparsecoding/test_utils/asserts_test.py +++ b/sparsecoding/test_utils/asserts_test.py @@ -1,4 +1,3 @@ - import numpy as np import torch @@ -10,16 +9,19 @@ def test_pytorch_all_close(): expected = torch.ones([10, 10]) assert_allclose(result, expected) + def test_np_all_close(): result = np.ones([100, 100]) + 1e-10 expected = np.ones([100, 100]) assert_allclose(result, expected) + def test_assert_pytorch_shape_equal(): a = torch.zeros([10, 10]) b = torch.ones([10, 10]) assert_shape_equal(a, b) + def test_assert_np_shape_equal(): a = np.zeros([100, 100]) b = np.ones([100, 100]) diff --git a/sparsecoding/test_utils/constant_fixtures.py b/sparsecoding/test_utils/constant_fixtures.py index eac6253..1c04698 100644 --- a/sparsecoding/test_utils/constant_fixtures.py +++ b/sparsecoding/test_utils/constant_fixtures.py @@ -3,10 +3,12 @@ PATCH_SIZE = 8 DATASET_SIZE = 1000 + @pytest.fixture() def patch_size_fixture() -> int: - return PATCH_SIZE + return PATCH_SIZE + @pytest.fixture() def dataset_size_fixture() -> int: - return DATASET_SIZE \ No newline at end of file + return DATASET_SIZE diff --git a/sparsecoding/test_utils/dataset_fixtures.py b/sparsecoding/test_utils/dataset_fixtures.py index 9e3dcc9..c3f20d4 100644 --- a/sparsecoding/test_utils/dataset_fixtures.py +++ b/sparsecoding/test_utils/dataset_fixtures.py @@ -1,4 +1,3 @@ - import pytest import torch @@ -7,7 +6,9 @@ @pytest.fixture() -def bars_datasets_fixture(patch_size_fixture: int, dataset_size_fixture: int, priors_fixture: list[Prior]) -> list[BarsDataset]: +def bars_datasets_fixture( + patch_size_fixture: int, dataset_size_fixture: int, priors_fixture: list[Prior] +) -> list[BarsDataset]: return [ BarsDataset( patch_size=patch_size_fixture, @@ -17,9 +18,12 @@ def bars_datasets_fixture(patch_size_fixture: int, dataset_size_fixture: int, pr for prior in priors_fixture ] + 
@pytest.fixture() -def bars_datas_fixture(patch_size_fixture: int, dataset_size_fixture: int, bars_datasets_fixture: list[BarsDataset]) -> list[torch.Tensor]: +def bars_datas_fixture( + patch_size_fixture: int, dataset_size_fixture: int, bars_datasets_fixture: list[BarsDataset] +) -> list[torch.Tensor]: return [ dataset.data.reshape((dataset_size_fixture, patch_size_fixture * patch_size_fixture)) for dataset in bars_datasets_fixture - ] \ No newline at end of file + ] diff --git a/sparsecoding/test_utils/model_fixtures.py b/sparsecoding/test_utils/model_fixtures.py index 95573f6..1a0d895 100644 --- a/sparsecoding/test_utils/model_fixtures.py +++ b/sparsecoding/test_utils/model_fixtures.py @@ -1,12 +1,10 @@ - import pytest import torch from sparsecoding.datasets import BarsDataset -torch.manual_seed(1997) @pytest.fixture() def bars_dictionary_fixture(patch_size_fixture: int, bars_datasets_fixture: list[BarsDataset]) -> torch.Tensor: - """Return a bars dataset basis reshaped to represent a dictionary.""" - return bars_datasets_fixture[0].basis.reshape((2 * patch_size_fixture, patch_size_fixture * patch_size_fixture)).T \ No newline at end of file + """Return a bars dataset basis reshaped to represent a dictionary.""" + return bars_datasets_fixture[0].basis.reshape((2 * patch_size_fixture, patch_size_fixture * patch_size_fixture)).T diff --git a/sparsecoding/test_utils/prior_fixtures.py b/sparsecoding/test_utils/prior_fixtures.py index ffa2ef1..08a3125 100644 --- a/sparsecoding/test_utils/prior_fixtures.py +++ b/sparsecoding/test_utils/prior_fixtures.py @@ -21,4 +21,4 @@ def priors_fixture(patch_size_fixture: int) -> list[Prior]: ).type(torch.float32) ), ), - ] \ No newline at end of file + ]
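
Usage sketch (an editor's illustration, not part of the patch): the reformatted modules above compose as below, mirroring ista_test.py and the test_utils fixtures. The prior dimensionality of 2 * patch_size (one weight per horizontal and per vertical bar) and the SpikeSlabPrior argument values are assumptions inferred from the fixtures and spike_slab_prior_test.py, not values this patch pins down.

import torch

from sparsecoding.datasets import BarsDataset
from sparsecoding.inference import ISTA
from sparsecoding.priors import SpikeSlabPrior

torch.manual_seed(1997)

patch_size = 8  # matches PATCH_SIZE in constant_fixtures.py
dataset_size = 1000  # matches DATASET_SIZE in constant_fixtures.py

# Assumed prior configuration: D = 2 * patch_size (one weight per bar);
# argument order (D, p_spike, scale, positive_only) follows spike_slab_prior_test.py.
prior = SpikeSlabPrior(2 * patch_size, 0.5, 1.0, False)

dataset = BarsDataset(
    patch_size=patch_size,
    dataset_size=dataset_size,
    prior=prior,
)

# Flatten the images and transpose the basis into a [n_features, n_basis]
# dictionary, exactly as dataset_fixtures.py and model_fixtures.py do.
data = dataset.data.reshape((dataset_size, patch_size * patch_size))
dictionary = dataset.basis.reshape((2 * patch_size, patch_size * patch_size)).T

# Recover sparse coefficients; the shape matches dataset.weights,
# as test_inference in ista_test.py asserts.
coefficients = ISTA(n_iter=5000).infer(data, dictionary)
assert coefficients.shape == dataset.weights.shape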