Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion modules/modelSetup/mixin/ModelSetupNoiseMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from modules.util.enum.TimestepDistribution import TimestepDistribution

import torch
import torch.distributions
from torch import Generator, Tensor


Expand Down Expand Up @@ -145,7 +146,8 @@ def _get_timestep_discrete(
if config.timestep_distribution in [
TimestepDistribution.UNIFORM,
TimestepDistribution.LOGIT_NORMAL,
TimestepDistribution.HEAVY_TAIL
TimestepDistribution.HEAVY_TAIL,
TimestepDistribution.BETA
]:
# continuous implementations
if config.timestep_distribution == TimestepDistribution.UNIFORM:
Expand All @@ -168,6 +170,45 @@ def _get_timestep_discrete(
)
u = 1.0 - u - scale * (torch.cos(math.pi / 2.0 * u) ** 2.0 - 1.0 + u)
timestep = u * num_timestep + min_timestep
elif config.timestep_distribution == TimestepDistribution.BETA:
# B-TTDM Configuration
# Noising Weight -> Alpha
# Noising Bias -> Beta
alpha = max(1e-4, config.noising_weight)
beta = max(1e-4, config.noising_bias)

# B-TTDM Paper optimization (Section 3.3):
# They strictly recommend Beta=1 and Alpha < 1.
# When Beta=1, we can use Inverse Transform Sampling (CDF inversion)
# which allows us to use torch.rand with the generator
# CDF^(-1)(u) = u^(1/alpha)
if abs(beta - 1.0) < 1e-5:
u = torch.rand(batch_size, generator=generator, device=generator.device)
u = u.pow(1.0 / alpha)
timestep = u * num_timestep + min_timestep

# Inverse case: Alpha=1, Beta != 1
# x = 1 - u^(1/beta)
elif abs(alpha - 1.0) < 1e-5:
u = torch.rand(
batch_size,
generator=generator,
device=generator.device)
u = 1.0 - u.pow(1.0 / beta)
timestep = u * num_timestep + min_timestep

else:
# Fallback for arbitrary Beta values (Beta != 1 and Alpha != 1).
# PyTorch's Beta distribution does not accept a generator directly.
# This path is mathematically correct for distribution shape, but
# technically bypasses the generator seed (uses global device seed).
# Since B-TTDM requires Beta=1, this path is rarely taken for this specific paper.
m = torch.distributions.Beta(
torch.tensor(alpha, device=generator.device),
torch.tensor(beta, device=generator.device)
)
u = m.sample((batch_size,))
timestep = u * num_timestep + min_timestep

timestep = num_train_timesteps * shift * timestep / ((shift - 1) * timestep + num_train_timesteps)
else:
Expand Down
1 change: 1 addition & 0 deletions modules/util/enum/TimestepDistribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class TimestepDistribution(Enum):
HEAVY_TAIL = 'HEAVY_TAIL'
COS_MAP = 'COS_MAP'
INVERTED_PARABOLA = 'INVERTED_PARABOLA'
BETA = 'BETA'

def __str__(self):
return self.value