Making VQ training ready (#137)
* adding training compatible code for rvq

* fixing stuff

* fix

* fix

* fix

* Update moshi/moshi/quantization/core_vq.py

Co-authored-by: Ishita Mediratta <[email protected]>

* plop

---------

Co-authored-by: Ishita Mediratta <[email protected]>
adefossez and ishitamed19 authored Oct 14, 2024
1 parent be9fe49 commit 2df3a38
Showing 2 changed files with 137 additions and 40 deletions.
139 changes: 129 additions & 10 deletions moshi/moshi/quantization/core_vq.py
@@ -8,9 +8,10 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp

from einops import rearrange
from einops import rearrange, repeat
import torch
from torch import nn
from torch import distributed
@@ -34,12 +35,6 @@ def _ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> N
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def _uniform_init(*shape: int) -> torch.Tensor:
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t


def _sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor:
num_samples, device = samples.shape[0], samples.device

@@ -65,6 +60,29 @@ def _is_distributed() -> bool:
return distributed.is_initialized() and distributed.get_world_size() > 1


def _run_kmeans(samples: torch.Tensor, num_clusters: int, num_iters: int = 50) -> tp.Tuple[torch.Tensor, torch.Tensor]:
# Kmeans algorithm used to initialize the codebooks.
dim = samples.shape[-1]
means = _sample_vectors(samples, num_clusters)
bins = None

for _ in range(num_iters):
dists = torch.cdist(samples[None], means[None], p=2)[0]
buckets = dists.argmin(dim=-1)
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins.clamp_(min=1)

new_means = torch.zeros_like(means)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means /= bins[..., None]
resampled = _sample_vectors(samples, num_clusters)
means = torch.where(zero_mask[..., None], resampled, new_means)

assert bins is not None
return means, bins


def zero_scalar(device) -> torch.Tensor:
"""Returns a 0. value on the given device without introducing a synchronization point."""
return torch.zeros([1], device=device)[0]
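
As a rough illustration of what the k-means initialization above does, here is a minimal, self-contained sketch along the same lines (the function name and the toy data are mine; it mirrors the Lloyd iterations with dead-cluster resampling, independent of the rest of the repository):

import torch
from einops import repeat

def kmeans_sketch(samples: torch.Tensor, num_clusters: int, num_iters: int = 50):
    # samples: [N, D] with N >= num_clusters. Start from randomly picked vectors,
    # in the spirit of _sample_vectors above.
    dim = samples.shape[-1]
    means = samples[torch.randperm(samples.shape[0])[:num_clusters]]
    bins = torch.ones(num_clusters, dtype=torch.long)
    for _ in range(num_iters):
        # Assign every sample to its closest centroid.
        buckets = torch.cdist(samples[None], means[None], p=2)[0].argmin(dim=-1)
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins = bins.clamp(min=1)
        # Recompute centroids as per-cluster means.
        new_means = torch.zeros_like(means)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means /= bins[:, None]
        # Dead clusters get fresh vectors resampled from the data instead of staying put.
        resampled = samples[torch.randperm(samples.shape[0])[:num_clusters]]
        means = torch.where(zero_mask[:, None], resampled, new_means)
    return means, bins

# For instance: a 16-entry codebook initialized from 1024 random 8-dim vectors.
centroids, usage = kmeans_sketch(torch.randn(1024, 8), num_clusters=16)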
@@ -106,7 +124,6 @@ def __init__(
):
super().__init__()
self.decay = decay
embedding = torch.zeros(codebook_size, dim)

self.dim = dim
self.codebook_size = codebook_size
@@ -116,12 +133,13 @@ def __init__(
self.replaced_usage_ratio = replaced_usage_ratio
self.check_unused_every = check_unused_every
self._next_unused_check = check_unused_every
self._cached_initialized = False

self.register_buffer("_initialized", torch.tensor([False], dtype=torch.float))
self.register_buffer("cluster_usage", torch.ones(codebook_size))
embedding = torch.zeros(codebook_size, dim)
self.register_buffer("embedding_sum", embedding)
self.register_buffer("_embedding", None, persistent=False)
self._cached_initialized = False

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs) -> None:
# Mapping old names to new names
@@ -149,6 +167,42 @@ def embedding(self) -> torch.Tensor:
return embedding
return self._embedding

@property
def initialized(self) -> bool:
"""Cached version of self._initialized.
This assumes that once the module is initialized, it will never go back to the uninitialized state."""
if not self._cached_initialized:
self._cached_initialized = self._initialized.item()
return self._cached_initialized

def _init_embedding(self, data: torch.Tensor) -> None:
# Initialize the codebook, e.g. using kmeans.
if self.initialized:
return

rank = 0
if _is_distributed():
rank = distributed.get_rank()
# First gather the shapes, in case not all GPUs have the same effective batch size,
# then gather the actual content.
if rank == 0:
other_shapes: tp.List[torch.Size] = [None] * distributed.get_world_size() # type: ignore
distributed.gather_object(data.shape, other_shapes)
other_data: tp.List[torch.Tensor] = [
torch.empty(shape, device=data.device, dtype=data.dtype) for shape in other_shapes]
distributed.gather(data, other_data)
data = torch.cat(other_data, dim=0)
else:
distributed.gather_object(data.shape)
distributed.gather(data)
if rank == 0:
embedding, cluster_usage = _run_kmeans(data, self.codebook_size)
self.embedding_sum.data.copy_(embedding * cluster_usage[:, None])
self.cluster_usage.data.copy_(cluster_usage)
self._initialized.data.fill_(1)
# Make sure all buffers across workers are in sync after initialization
self._broadcast_buffers()
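
Note that the codebook is never stored directly: the buffers hold a running sum of assigned vectors (embedding_sum) and a running count per entry (cluster_usage), and the centroid is their ratio, which is why the k-means result is folded in as embedding * cluster_usage[:, None]. A small sketch of that bookkeeping (my own illustration, not the exact embedding property of this class):

import torch

codebook_size, dim = 4, 3
embedding_sum = torch.zeros(codebook_size, dim)   # EMA of the summed assigned vectors
cluster_usage = torch.ones(codebook_size)         # EMA of the assignment counts

# After kmeans, the buffers are seeded so that sum / usage reproduces the centroids.
centroids = torch.randn(codebook_size, dim)
usage = torch.tensor([5.0, 2.0, 1.0, 8.0])
embedding_sum.copy_(centroids * usage[:, None])
cluster_usage.copy_(usage)

# The effective codebook is the per-entry mean (the real property presumably guards
# against tiny usage values; the clamp here is only illustrative).
embedding = embedding_sum / cluster_usage.clamp(min=1e-5)[:, None]
print(torch.allclose(embedding, centroids, atol=1e-6))   # True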

def _broadcast_buffers(self) -> None:
if _is_distributed():
for buffer in self.buffers():
@@ -168,6 +222,25 @@ def _replace_expired_codes(self, samples: torch.Tensor, mask: torch.Tensor) -> N
mask, replace_cluster_usage, self.cluster_usage
)

def _check_expired_codes(self, batch_samples: torch.Tensor) -> torch.Tensor:
# Check whether some centroids are underutilized and replace them if necessary.
if not self.initialized:
return zero_scalar(batch_samples.device)

self._next_unused_check -= 1
if self._next_unused_check > 0:
return zero_scalar(batch_samples.device)
# we don't check every iteration to avoid having too many sync points.
self._next_unused_check = self.check_unused_every
threshold_cluster_usage = self.threshold_usage_ratio * self.cluster_usage.sum() / self.codebook_size
expired_codes = self.cluster_usage < threshold_cluster_usage

assert batch_samples.dim() == 2
self._replace_expired_codes(batch_samples, mask=expired_codes)
self._broadcast_buffers()

return expired_codes.float().mean()
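
For intuition on the expiry criterion: an entry is considered underused when its EMA usage falls below threshold_usage_ratio times the average usage across the codebook. A small numeric sketch (illustrative values only):

import torch

codebook_size = 8
threshold_usage_ratio = 0.1
cluster_usage = torch.tensor([40.0, 0.2, 35.0, 0.05, 20.0, 15.0, 0.5, 10.0])

# An entry expires when its EMA usage is below a fraction of the average usage.
threshold = threshold_usage_ratio * cluster_usage.sum() / codebook_size
expired = cluster_usage < threshold

print(threshold)               # tensor(1.5094)
print(expired)                 # entries 1, 3 and 6 are flagged for replacement
print(expired.float().mean())  # tensor(0.3750), the fraction reported as 'rvq_expired'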

def _reshape_input(self, x: torch.Tensor) -> torch.Tensor:
# Flattens all the dimensions but the last one, e.g. return a vector of shape `[N, D]`.
x = rearrange(x, "... d -> (...) d")
@@ -211,11 +284,38 @@ def forward(
shape = x.shape
x = self._reshape_input(x)

if self.training and initialize:
# If initialize is False, we are not allowed to initialize this layer,
# and the rest of the code will operate on a zero-filled codebook.
# This happens when previous layers have already used this batch for their kmeans init,
# so the residuals are mostly zeros.
self._init_embedding(x.detach())

flat_codes = self._quantize(x)
codes = self._reshape_codes(flat_codes, shape)
quantized = self.decode(codes)
metrics: tp.Dict[str, torch.Tensor] = {}

if self.training:
# We expire the unused codes at this point, as the buffers are in sync
# and all the workers will make the same decision.
expired = self._check_expired_codes(x)
metrics['rvq_expired'] = expired
cluster_usage = torch.zeros_like(self.cluster_usage)
cluster_usage.scatter_add_(
0, flat_codes, torch.ones_like(flat_codes, dtype=cluster_usage.dtype))
_ema_inplace(self.cluster_usage, cluster_usage, self.decay)

if self.initialized:
# We report the entropy normalized by that of the uniform distribution.
# This means the codebook is optimally used when entropy=1.
metrics['rvq_entropy'] = _compute_entropy(self.cluster_usage) / math.log(self.codebook_size)

embedding_sum = torch.zeros_like(self.embedding_sum)
embedding_sum.scatter_add_(0, repeat(flat_codes, "n -> n d", d=self.dim), x)
_ema_inplace(self.embedding_sum, embedding_sum, self.decay)
self.register_buffer('_embedding', None)

return _CodebookForwardResult(quantized, codes, metrics)
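
The training branch above is classic EMA codebook learning: per-batch counts and per-batch sums are accumulated with scatter_add_ and folded into the buffers with an exponential moving average, and codebook usage is reported as a normalized entropy. A compact, self-contained sketch of one such update (helper names are mine, and _compute_entropy is assumed to be the Shannon entropy of the usage distribution):

import math
import torch
from einops import repeat

def ema_inplace(moving_avg, new, decay):
    moving_avg.mul_(decay).add_(new, alpha=1 - decay)

codebook_size, dim, decay = 4, 2, 0.99
cluster_usage = torch.ones(codebook_size)
embedding_sum = torch.zeros(codebook_size, dim)

x = torch.randn(16, dim)                       # flattened batch, [N, D]
flat_codes = torch.randint(0, codebook_size, (16,))

# Per-batch usage counts and per-batch sums of the vectors assigned to each entry.
batch_usage = torch.zeros_like(cluster_usage)
batch_usage.scatter_add_(0, flat_codes, torch.ones_like(flat_codes, dtype=batch_usage.dtype))
batch_sum = torch.zeros_like(embedding_sum)
batch_sum.scatter_add_(0, repeat(flat_codes, "n -> n d", d=dim), x)

# Fold the batch statistics into the running buffers with an exponential moving average.
ema_inplace(cluster_usage, batch_usage, decay)
ema_inplace(embedding_sum, batch_sum, decay)

# Usage entropy, normalized so that a perfectly uniform codebook reports 1.0.
probs = cluster_usage / cluster_usage.sum()
entropy = -(probs * probs.clamp(min=1e-12).log()).sum()
print(entropy / math.log(codebook_size))       # what would be logged as 'rvq_entropy'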


@@ -274,6 +374,10 @@ def __init__(
def embedding(self):
return self._codebook.embedding

@property
def initialized(self):
return self._codebook.initialized

def _rearrange_input(self, x):
x = rearrange(x, "b d n -> b n d")
return x
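
On shape conventions: the quantizer consumes [B, D, T] tensors, swaps to channels-last, and the codebook then flattens everything but the channel dimension. A quick check of those rearranges (illustrative values only):

import torch
from einops import rearrange

x = torch.randn(2, 16, 50)                    # [B, D, T], as fed to VectorQuantization
x_bnd = rearrange(x, "b d n -> b n d")        # [B, T, D], what the codebook receives
flat = rearrange(x_bnd, "... d -> (...) d")   # [(B*T), D], what _reshape_input produces
print(x_bnd.shape, flat.shape)                # torch.Size([2, 50, 16]) torch.Size([100, 16])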
@@ -300,7 +404,11 @@ def forward(self, x: torch.Tensor, initialize: bool = True) -> _VQForwardResult:
x = self._rearrange_input(x)
quantized, codes, metrics = self._codebook(x, initialize=initialize)

loss = zero_scalar(x.device)
if self.training:
quantized = x + (quantized - x).detach()
loss = F.mse_loss(x, quantized.detach())
else:
loss = zero_scalar(x.device)

quantized = self.project_out(quantized)
quantized = self._rearrange_output(quantized)
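
The two training-only lines above are the straight-through estimator and the commitment loss: gradients flow to x as if quantization were the identity, while the MSE pulls the encoder output toward the (detached) quantized values. A tiny sketch of the resulting gradient, under those assumptions:

import torch
import torch.nn.functional as F

x = torch.randn(5, 3, requires_grad=True)
quantized = torch.randn(5, 3)                  # stand-in for the codebook output (no grad)

# Straight-through estimator: the forward value is `quantized`,
# but the gradient with respect to x is the identity.
ste = x + (quantized - x).detach()
commit_loss = F.mse_loss(x, ste.detach())      # pulls the encoder output toward the codes

(ste.sum() + commit_loss).backward()
expected = torch.ones_like(x) + 2 * (x.detach() - quantized) / x.numel()
print(torch.allclose(x.grad, expected))        # True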
@@ -341,9 +449,16 @@ def forward(
previous_layer_is_initialized = True

for i, layer in enumerate(self.layers[:n_q]): # type: ignore
if self.training:
this_layer_is_initialized = layer.initialized
# We only allow the kmeans initialization if the previous layer was already initialized from previous
# iterations; this avoids learning the subsequent kmeans on the same batch, which would eventually
# lead to its exhaustion and to running kmeans on 0 values.
quantized, codes, loss, metrics = layer(
residual, initialize=previous_layer_is_initialized
)
if self.training:
previous_layer_is_initialized = this_layer_is_initialized # type: ignore

quantized = quantized.detach()
residual = residual - quantized
@@ -359,6 +474,10 @@ def forward(
all_metrics[key] = value / n_q
all_metrics[key + f"_{i + self.codebook_offset}"] = value

if self.training:
# Solving a subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
quantized_out = x + (quantized_out - x).detach()

out_losses, out_codes = map(torch.stack, (all_losses, all_codes))
return _VQForwardResult(quantized_out, out_codes, out_losses, all_metrics)
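
Stepping back, the residual loop with the outer straight-through fix has roughly this shape; a simplified sketch with a stand-in quantize function, not the actual layer API:

import torch

def toy_quantize(residual):
    # Stand-in for one VectorQuantization layer: round to a coarse grid as a toy "codebook".
    return (residual * 10).round() / 10

def rvq_sketch(x, n_q=4, training=True):
    quantized_out = torch.zeros_like(x)
    residual = x
    for _ in range(n_q):
        quantized = toy_quantize(residual)
        # Each layer's output is detached before the residual update, so gradients
        # do not flow backwards through the chain of quantizers.
        quantized = quantized.detach()
        residual = residual - quantized
        quantized_out = quantized_out + quantized
    if training:
        # Straight-through over the whole stack, re-attaching the gradient to x
        # (the fix discussed in the EnCodec issue referenced above).
        quantized_out = x + (quantized_out - x).detach()
    return quantized_out

out = rvq_sketch(torch.randn(2, 8, requires_grad=True))
out.sum().backward()  # gradients reach x thanks to the straight-through trick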

38 changes: 8 additions & 30 deletions moshi/moshi/quantization/vq.py
@@ -9,6 +9,7 @@
# LICENSE file in the root directory of this source tree.

import math
import random
import typing as tp

import torch
@@ -48,35 +49,27 @@ def __init__(
output_dimension: tp.Optional[int] = None,
n_q: int = 8,
q_dropout: bool = False,
q_first_only_proba: float = 0.0,
no_quantization_rate: float = 0.0,
bins: int = 1024,
decay: float = 0.99,
threshold_usage_ratio: float = 0.1,
replaced_usage_ratio: float = 1.0,
codebook_offset: int = 0,
force_projection: bool = False,
generator_seed: tp.Optional[int] = None,
):
super().__init__()
self.max_n_q = n_q
self.n_q = n_q
self.q_dropout = q_dropout
self.no_quantization_rate = no_quantization_rate
self.q_first_only_proba = q_first_only_proba
self.dimension = dimension
self.input_dimension = input_dimension or dimension
self.output_dimension = output_dimension or dimension
self.bins = bins
self.decay = decay
self.rng_dropout = random.Random(1234)
self.input_proj: torch.nn.Module
self.output_proj: torch.nn.Module
self.generator = None
if generator_seed is not None:
self.generator = torch.Generator(
device="cuda" if torch.cuda.is_available() else "cpu"
)
self.generator.manual_seed(generator_seed)
if self.input_dimension == self.dimension and not force_projection:
self.input_proj = torch.nn.Identity()
else:
@@ -116,17 +109,19 @@ def forward(self, x: torch.Tensor, frame_rate: int):
"""
n_q = self.n_q
x = self.input_proj(x)

if self.training and self.q_dropout:
n_q = self.rng_dropout.randint(1, self.n_q)
bw_per_q = math.log2(self.bins) * frame_rate / 1000
quantized, codes, commit_loss, metrics = self.vq(x, n_q=n_q)
B, _, _ = quantized.shape
if self.training and self.no_quantization_rate > 0:
mask = (torch.rand(B, 1, 1, device=x.device) <= self.no_quantization_rate).float()
quantized = x * mask + (1 - mask) * quantized
quantized = self.output_proj(quantized)
codes = codes.transpose(0, 1)
# codes is [B, K, T], with T the number of frames and K the number of codebooks.
bw = torch.tensor(n_q * bw_per_q).to(x)
return QuantizedResult(
quantized, codes, bw, penalty=torch.mean(commit_loss), metrics=metrics
)
return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss), metrics=metrics)
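
The reported bandwidth is just bits-per-codebook times frame rate: each codebook index carries log2(bins) bits per frame, and quantizer dropout samples how many codebooks are kept for this batch. A quick worked example (hypothetical settings, not necessarily the model's actual configuration; the fixed RNG seed presumably keeps distributed workers sampling the same n_q):

import math
import random

rng_dropout = random.Random(1234)               # fixed seed, as above
bins, frame_rate, max_n_q = 1024, 50, 8         # illustrative values

# Quantizer dropout: during training, only a random prefix of the codebooks is used.
n_q = rng_dropout.randint(1, max_n_q)

# Each codebook index carries log2(bins) bits per frame; bandwidth is reported in kbps.
bw_per_q = math.log2(bins) * frame_rate / 1000
print(n_q, bw_per_q, n_q * bw_per_q)            # with these numbers, 0.5 kbps per codebook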

def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Encode a given input tensor with the specified frame rate at the given bandwidth.
Expand Down Expand Up @@ -174,18 +169,13 @@ class SplitResidualVectorQuantizer(BaseQuantizer):
Args:
n_q (int): Number of residual vector quantizers used.
n_q_semantic (int): Number of residual vector quantizers used for the semantic quantizer.
no_quantization_mode (str): if 'true_skip', when doing no quantization, the input will not go
through the sub quantizers. If `independent`, independent decisions are taken by
the semantic and acoustic quantizers. If `same` (the default), the same decision is taken by both.
**kwargs: Arguments to the constructor of `ResidualVectorQuantizer` that are shared between both.
"""

def __init__(
self,
*,
n_q: int = 8,
no_quantization_rate: float = 0.0,
no_quantization_mode: str = "same",
n_q_semantic: int = 1,
**kwargs,
):
@@ -197,15 +187,6 @@ def __init__(
self.max_n_q = n_q
self.n_q_semantic = n_q_semantic
self.n_q_acoustic = n_q - n_q_semantic
if no_quantization_mode == "true_skip":
self.no_quantization_rate = no_quantization_rate
# Setting to zero for the underlying RVQ.
no_quantization_rate = 0.0
else:
self.no_quantization_rate = 0.0
if no_quantization_mode == "same":
kwargs["generator_seed"] = 1234
kwargs["no_quantization_rate"] = no_quantization_rate
q_dropout = kwargs.pop("q_dropout", False)
self.rvq_first = ResidualVectorQuantizer(
n_q=n_q_semantic, force_projection=True, q_dropout=False, **kwargs
Expand All @@ -217,9 +198,6 @@ def __init__(
q_dropout=q_dropout,
**kwargs,
)
if no_quantization_mode == "true_skip":
assert self.rvq_first.input_dimension == self.rvq_first.output_dimension
assert self.rvq_rest.input_dimension == self.rvq_rest.output_dimension

def _renorm_and_add(
self,
