Skip to content

Commit

Permalink
updates to whitening, notebook edits
Browse files Browse the repository at this point in the history
  • Loading branch information
HarrisonSantiago committed Nov 27, 2024
1 parent 36a40b7 commit c4996a3
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 117 deletions.
125 changes: 55 additions & 70 deletions examples/Data_Whitening.ipynb

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions sparsecoding/transforms/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ def whiten_images(images: torch.Tensor,
must be one of ['frequency', 'pca', 'zca', 'cholesky']")


def compute_image_whitening_stats(images: torch.Tensor,
n_components=None) -> Dict:
def compute_image_whitening_stats(images: torch.Tensor) -> Dict:
"""
Wrapper for computing whitening stats of an image dataset
Expand All @@ -80,7 +79,7 @@ def compute_image_whitening_stats(images: torch.Tensor,
"""
check_images(images)
flattened_images = images.flatten(start_dim=1)
return compute_whitening_stats(flattened_images, n_components)
return compute_whitening_stats(flattened_images)


def create_frequency_filter(image_size: int, f0_factor: float = 0.4) -> torch.Tensor:
Expand Down
80 changes: 36 additions & 44 deletions sparsecoding/transforms/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,30 @@
from typing import Dict


def compute_whitening_stats(X: torch.Tensor,
n_components=None):
def compute_whitening_stats(X: torch.Tensor):

"""
Given a tensor of data, compute statistics for whitening transform.
Parameters
----------
X: Input data of size [N, D]
n_components: Number of principal components to keep. If None, keep all components.
If int, keep that many components. If float between 0 and 1,
keep components that explain that fraction of variance.
Returns
----------
Dictionary containing whitening statistics (eigenvalues, eigenvectors, mean)
"""

# Step 1: Center Data
mean = torch.mean(X, dim=0)
X_centered = X - mean
Sigma = torch.cov(X_centered.T)

# Step 2: Compute eigenvalues/eigenvectors
eigenvalues, eigenvectors = torch.linalg.eigh(Sigma)

# Since eigh returns values in ascending order, reverse them to get descending order
eigenvalues = torch.flip(eigenvalues, dims=[0])
eigenvectors = torch.flip(eigenvectors, dims=[1])

# Step 3: We provide the option of returning a certain
# num of principal components. 0 <= n_components < 1 indicates you want to keep
# a certain percentage of explained variance. n_components > 1 indicates a
# you wish to keep that many. n_components = None means you want to keep all
if n_components is not None:
if isinstance(n_components, float):
if not 0 < n_components <= 1:
raise ValueError("If n_components is float, it must be between 0 and 1")
explained_variance_ratio = eigenvalues / torch.sum(eigenvalues)
cumulative_variance_ratio = torch.cumsum(explained_variance_ratio, dim=0)
n_components = torch.sum(cumulative_variance_ratio <= n_components) + 1
elif isinstance(n_components, int):
if not 0 < n_components <= len(eigenvalues):
raise ValueError(f"n_components must be between 1 and {len(eigenvalues)}")
else:
raise ValueError("n_components must be int or float")

# Instead of truncating, zero out unwanted components
mask = torch.zeros_like(eigenvalues)
mask[:n_components] = 1.0
eigenvalues = eigenvalues * mask
# For eigenvectors, we zero out the columns corresponding to zeroed eigenvalues
eigenvectors = eigenvectors * mask.unsqueeze(0)

return {
'mean': mean,
'eigenvalues': eigenvalues,
Expand All @@ -79,7 +49,9 @@ def whiten(X: torch.Tensor,
X: Input data of shape [N, D] where N are unique data elements of dimensionality D
algorithm: Whitening transform we want to apply, one of ['zca', 'pca', or 'cholesky']
stats: Dict containing precomputed whitening statistics (mean, eigenvectors, eigenvalues)
n_components: number of components to retain if computing stats
n_components: Number of principal components to keep. If None, keep all components.
If int, keep that many components. If float between 0 and 1,
keep components that explain that fraction of variance.
epsilon: Optional small constant to prevent division by zero
Returns
Expand All @@ -97,24 +69,44 @@ def whiten(X: torch.Tensor,
"""

if stats is None:
stats = compute_whitening_stats(X, n_components)
stats = compute_whitening_stats(X)

x_centered = X - stats.get('mean')

if algorithm == 'pca':
# For PCA: project onto eigenvectors and scale
scaling = torch.diag(1. / torch.sqrt(stats.get('eigenvalues') + epsilon))
W = scaling @ stats.get('eigenvectors').T
elif algorithm == 'zca':
# For ZCA: project, scale, and rotate back
scaling = torch.diag(1. / torch.sqrt(stats.get('eigenvalues') + epsilon))
W = (stats.get('eigenvectors') @
scaling @
stats.get('eigenvectors').T)
if algorithm == 'pca' or algorithm == 'zca':

scaling = 1. / torch.sqrt(stats.get('eigenvalues') + epsilon)

if n_components is not None:
if isinstance(n_components, float):
if not 0 < n_components <= 1:
raise ValueError("If n_components is float, it must be between 0 and 1")
explained_variance_ratio = stats.get('eigenvalues') / torch.sum(stats.get('eigenvalues'))
cumulative_variance_ratio = torch.cumsum(explained_variance_ratio, dim=0)
n_components = torch.sum(cumulative_variance_ratio <= n_components) + 1
elif isinstance(n_components, int):
if not 0 < n_components <= len(stats.get('eigenvalues')):
raise ValueError(f"n_components must be between 1 and {len(stats.get('eigenvalues'))}")
else:
raise ValueError("n_components must be int or float")

mask = torch.zeros_like(scaling)
mask[:n_components] = 1.0
scaling = scaling * mask

scaling = torch.diag(scaling)

if algorithm == 'pca':
# For PCA: project onto eigenvectors and scale
W = scaling @ stats.get('eigenvectors').T
else:
# For ZCA: project, scale, and rotate back
W = (stats.get('eigenvectors') @
scaling @
stats.get('eigenvectors').T)
elif algorithm == 'cholesky':
# Based on Cholesky decomp, also related to QR decomp
L = torch.linalg.cholesky(stats.get('covariance'))
# Whitening matrix is inverse of L
W = torch.linalg.inv(L)
else:
raise ValueError(f"Unknown whitening algorithm: {algorithm}, must be one of ['pca', 'zca', 'cholesky']")
Expand Down

0 comments on commit c4996a3

Please sign in to comment.