6 changes: 5 additions & 1 deletion modules/modelSetup/BaseChromaSetup.py
@@ -203,6 +203,10 @@ def predict(
config,
)

# E-TSDM: Map the timestep for the transformer's conditioning only.
# The physical noise addition and loss target remain based on the original 'timestep'.
model_timestep = self._apply_etsdm_timestep_mapping(timestep, config)

scaled_noisy_latent_image, sigma = self._add_noise_discrete(
scaled_latent_image,
latent_noise,
@@ -231,7 +235,7 @@ def predict(

packed_predicted_flow = model.transformer(
hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()),
timestep=timestep / 1000,
timestep=model_timestep / 1000,
encoder_hidden_states=text_encoder_output.to(dtype=model.train_dtype.torch_dtype()),
txt_ids=text_ids.to(dtype=model.train_dtype.torch_dtype()),
img_ids=image_ids.to(dtype=model.train_dtype.torch_dtype()),
6 changes: 5 additions & 1 deletion modules/modelSetup/BaseFluxSetup.py
@@ -287,10 +287,14 @@ def predict(
model.train_dtype.torch_dtype()
)

# E-TSDM: Map the timestep for the transformer's conditioning only.
# The physical noise addition and loss target remain based on the original 'timestep'.
model_timestep = self._apply_etsdm_timestep_mapping(timestep, config)

packed_latent_input = model.pack_latents(latent_input)
packed_predicted_flow = model.transformer(
hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()),
timestep=timestep / 1000,
timestep=model_timestep / 1000,
guidance=guidance.to(dtype=model.train_dtype.torch_dtype()),
pooled_projections=pooled_text_encoder_output.to(dtype=model.train_dtype.torch_dtype()),
encoder_hidden_states=text_encoder_output.to(dtype=model.train_dtype.torch_dtype()),
8 changes: 6 additions & 2 deletions modules/modelSetup/BaseStableDiffusionSetup.py
@@ -181,6 +181,10 @@ def predict(
config,
)

# E-TSDM: Map the timestep for U-Net conditioning only.
# The physical noise addition and loss target remain based on the original 'timestep'.
model_timestep = self._apply_etsdm_timestep_mapping(timestep, config)

latent_noise = self._create_noise(
scaled_latent_image,
config,
@@ -206,14 +210,14 @@
if config.model_type.has_depth_input():
predicted_latent_noise = model.unet(
latent_input.to(dtype=model.train_dtype.torch_dtype()),
timestep,
model_timestep,
text_encoder_output.to(dtype=model.train_dtype.torch_dtype()),
batch['latent_depth'].to(dtype=model.train_dtype.torch_dtype()),
).sample
else:
predicted_latent_noise = model.unet(
latent_input.to(dtype=model.train_dtype.torch_dtype()),
timestep,
model_timestep,
text_encoder_output.to(dtype=model.train_dtype.torch_dtype()),
).sample

6 changes: 5 additions & 1 deletion modules/modelSetup/BaseStableDiffusionXLSetup.py
@@ -232,6 +232,10 @@ def predict(
config,
)

# E-TSDM: Map the timestep for U-Net conditioning only.
# The physical noise addition and loss target remain based on the original 'timestep'.
model_timestep = self._apply_etsdm_timestep_mapping(timestep, config)

latent_noise = self._create_noise(
scaled_latent_image,
config,
@@ -279,7 +283,7 @@ def predict(
added_cond_kwargs = {"text_embeds": pooled_text_encoder_2_output, "time_ids": add_time_ids}
predicted_latent_noise = model.unet(
sample=latent_input.to(dtype=model.train_dtype.torch_dtype()),
timestep=timestep,
timestep=model_timestep,
encoder_hidden_states=text_encoder_output.to(dtype=model.train_dtype.torch_dtype()),
added_cond_kwargs=added_cond_kwargs,
).sample
69 changes: 68 additions & 1 deletion modules/modelSetup/mixin/ModelSetupNoiseMixin.py
@@ -5,6 +5,7 @@
from modules.util.enum.TimestepDistribution import TimestepDistribution

import torch
import torch.distributions
from torch import Generator, Tensor


@@ -118,6 +119,32 @@ def _create_noise(

return noise

def _apply_etsdm_timestep_mapping(self, timestep: Tensor, config: TrainConfig) -> Tensor:
"""
Applies Early Timestep-shared Diffusion Model (E-TSDM) logic.
Paper: Lipschitz Singularities in Diffusion Models (ICLR 2024)

Maps timesteps < t_tilde to shared values within n sub-intervals.
"""
if not getattr(config, 'etsdm_enabled', True):
return timestep

# Paper defaults: t_tilde=100, n=5 (for T=1000)
t_tilde = getattr(config, 'etsdm_t_tilde', 100)
n = getattr(config, 'etsdm_n', 5)

interval_size = t_tilde // n

# Create mask for steps within the "early" region [0, t_tilde)
mask = timestep < t_tilde

# Calculate shared timestep: floor(t / interval) * interval
# This maps [0, 19] -> 0, [20, 39] -> 20, etc.
mapped_timestep = (timestep // interval_size) * interval_size

# Apply mapping only where mask is True
return torch.where(mask, mapped_timestep, timestep)

def _get_timestep_discrete(
self,
num_train_timesteps: int,
@@ -145,7 +172,8 @@ def _get_timestep_discrete(
if config.timestep_distribution in [
TimestepDistribution.UNIFORM,
TimestepDistribution.LOGIT_NORMAL,
TimestepDistribution.HEAVY_TAIL
TimestepDistribution.HEAVY_TAIL,
TimestepDistribution.BETA
]:
# continuous implementations
if config.timestep_distribution == TimestepDistribution.UNIFORM:
@@ -168,6 +196,45 @@
)
u = 1.0 - u - scale * (torch.cos(math.pi / 2.0 * u) ** 2.0 - 1.0 + u)
timestep = u * num_timestep + min_timestep
elif config.timestep_distribution == TimestepDistribution.BETA:
# B-TTDM Configuration
# Noising Weight -> Alpha
# Noising Bias -> Beta
alpha = max(1e-4, config.noising_weight)
beta = max(1e-4, config.noising_bias)

# B-TTDM paper optimization (Section 3.3):
# The paper strictly recommends Beta=1 and Alpha < 1.
# When Beta=1, inverse transform sampling (CDF inversion) can be used,
# which allows torch.rand with the provided generator:
# CDF^(-1)(u) = u^(1/alpha)
if abs(beta - 1.0) < 1e-5:
u = torch.rand(batch_size, generator=generator, device=generator.device)
u = u.pow(1.0 / alpha)
timestep = u * num_timestep + min_timestep

# Inverse case: Alpha=1, Beta != 1
# x = 1 - u^(1/beta)
elif abs(alpha - 1.0) < 1e-5:
u = torch.rand(
batch_size,
generator=generator,
device=generator.device)
u = 1.0 - u.pow(1.0 / beta)
timestep = u * num_timestep + min_timestep

else:
# Fallback for arbitrary shape values (Alpha != 1 and Beta != 1).
# PyTorch's Beta distribution does not accept a generator directly, so this
# path is mathematically correct for the distribution shape but bypasses the
# generator seed (it uses the global device seed instead).
# Since B-TTDM recommends Beta=1, this path is rarely taken in practice.
m = torch.distributions.Beta(
torch.tensor(alpha, device=generator.device),
torch.tensor(beta, device=generator.device)
)
u = m.sample((batch_size,))
timestep = u * num_timestep + min_timestep

timestep = num_train_timesteps * shift * timestep / ((shift - 1) * timestep + num_train_timesteps)
else:
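For reference, here is a minimal sketch (not part of the diff) of how `_apply_etsdm_timestep_mapping` behaves with the paper defaults t_tilde=100 and n=5, which give interval_size = 20:

```python
import torch

# Illustrative only: reproduces the mapping above with t_tilde=100, n=5.
t_tilde, n = 100, 5
interval_size = t_tilde // n  # 20

timestep = torch.tensor([3, 19, 20, 57, 99, 100, 999])
mask = timestep < t_tilde
mapped = (timestep // interval_size) * interval_size
print(torch.where(mask, mapped, timestep))
# tensor([  0,   0,  20,  40,  80, 100, 999])
```

Timesteps at or above t_tilde pass through unchanged; only the early region is collapsed onto shared values.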
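And a small sketch (illustrative, not part of the diff) of why the Beta=1 branch can use plain `torch.rand` with the generator: for Beta(alpha, 1) the CDF is F(x) = x^alpha, so u^(1/alpha) with uniform u follows the same distribution as PyTorch's Beta sampler.

```python
import torch

# Illustrative only: inverse transform sampling for Beta(alpha, 1).
alpha = 0.5
gen = torch.Generator().manual_seed(0)

u = torch.rand(100_000, generator=gen)
samples = u.pow(1.0 / alpha)

# Reference sampler (no generator support, uses the global seed).
ref = torch.distributions.Beta(alpha, 1.0).sample((100_000,))

# Both means should be close to alpha / (alpha + 1) = 1/3.
print(samples.mean().item(), ref.mean().item())
```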
1 change: 1 addition & 0 deletions modules/ui/OptimizerParamsWindow.py
@@ -198,6 +198,7 @@ def create_dynamic_ui(
'approx_mars': {'title': 'Approx MARS-M', 'tooltip': 'Enables Approximated MARS-M, a variance reduction technique. It uses the previous step\'s gradient to correct the current update, leading to lower losses and improved convergence stability. This requires additional state to store the previous gradient.', 'type': 'bool'},
'kappa_p': {'title': 'Lion-K P-value', 'tooltip': 'Controls the Lp-norm geometry for the Lion update. 1.0 = Standard Lion (Sign update, coordinate-wise), best for Transformers. 2.0 = Spherical Lion (Normalized update, rotational invariant), best for Conv2d layers (in unet models). Values between 1.0 and 2.0 interpolate behavior between the two.', 'type': 'float'},
'auto_kappa_p': {'title': 'Auto Lion-K', 'tooltip': 'Automatically determines the optimal P-value based on layer dimensions. Uses p=2.0 (Spherical) for 4D (Conv) tensors for stability and rotational invariance, and p=1.0 (Sign) for 2D (Linear) tensors for sparsity. Overrides the manual P-value. Recommend for unet models.', 'type': 'bool'},
'compiled_optimizer': {'title': 'Compiled Optimizer', 'tooltip': 'Enables PyTorch compilation of the optimizer\'s internal step logic. This is intended to improve performance by allowing PyTorch to fuse operations and optimize the computational graph.', 'type': 'bool'},
}
# @formatter:on

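As a rough illustration of what the new 'Compiled Optimizer' toggle is meant to enable (a hypothetical sketch, not the adv_optm implementation), the idea is to wrap the per-parameter update math in torch.compile so the element-wise operations can be fused:

```python
import torch

# Hypothetical sketch only: a simplified AdamW-style update (no bias correction)
# whose element-wise math is compiled so PyTorch can fuse the operations.
@torch.compile
def compiled_step(p, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, weight_decay):
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    p.mul_(1 - lr * weight_decay)
    p.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr)
```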
2 changes: 2 additions & 0 deletions modules/util/config/TrainConfig.py
@@ -143,6 +143,7 @@ class TrainOptimizerConfig(BaseConfig):
approx_mars: False
kappa_p: float
auto_kappa_p: False
compiled_optimizer: False

def __init__(self, data: list[(str, Any, type, bool)]):
super().__init__(data)
@@ -261,6 +262,7 @@ def default_values():
data.append(("approx_mars", False, bool, False))
data.append(("kappa_p", None, float, True))
data.append(("auto_kappa_p", False, bool, False))
data.append(("compiled_optimizer", False, bool, False))

return TrainOptimizerConfig(data)

8 changes: 8 additions & 0 deletions modules/util/create.py
@@ -1080,6 +1080,7 @@ def create_optimizer(
alpha=optimizer_config.alpha if optimizer_config.alpha is not None else 5,
kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is not None else False,
k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0,
compiled_optimizer=optimizer_config.compiled_optimizer if optimizer_config.compiled_optimizer is not None else False,
)

# ADOPT_ADV Optimizer
@@ -1106,6 +1107,7 @@
alpha_grad=optimizer_config.alpha_grad if optimizer_config.alpha_grad is not None else 100,
kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is not None else False,
k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0,
compiled_optimizer=optimizer_config.compiled_optimizer if optimizer_config.compiled_optimizer is not None else False,
)

# PRODIGY_ADV Optimizer
@@ -1139,6 +1141,7 @@
alpha_grad=optimizer_config.alpha_grad if optimizer_config.alpha_grad is not None else 100,
kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is not None else False,
k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0,
compiled_optimizer=optimizer_config.compiled_optimizer if optimizer_config.compiled_optimizer is not None else False,
)

# SIMPLIFIED_AdEMAMix Optimizer
@@ -1161,6 +1164,7 @@
orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False,
kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is not None else False,
k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0,
compiled_optimizer=optimizer_config.compiled_optimizer if optimizer_config.compiled_optimizer is not None else False,
)

# LION_ADV Optimizer
@@ -1180,6 +1184,7 @@
orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False,
kappa_p=optimizer_config.kappa_p if optimizer_config.kappa_p is not None else 1.0,
auto_kappa_p=optimizer_config.auto_kappa_p if optimizer_config.auto_kappa_p is not None else False,
compiled_optimizer=optimizer_config.compiled_optimizer if optimizer_config.compiled_optimizer is not None else False,
)

# LION_PRODIGY_ADV Optimizer
@@ -1206,6 +1211,7 @@
orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False,
kappa_p=optimizer_config.kappa_p if optimizer_config.kappa_p is not None else 1.0,
auto_kappa_p=optimizer_config.auto_kappa_p if optimizer_config.auto_kappa_p is not None else False,
compiled_optimizer=optimizer_config.compiled_optimizer if optimizer_config.compiled_optimizer is not None else False,
)

# MUON_ADV Optimizer
@@ -1254,6 +1260,7 @@
accelerated_ns=optimizer_config.accelerated_ns if optimizer_config.accelerated_ns is not None else False,
orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False,
approx_mars=optimizer_config.approx_mars if optimizer_config.approx_mars is not None else False,
compiled_optimizer=optimizer_config.compiled_optimizer if optimizer_config.compiled_optimizer is not None else False,
**adam_kwargs
)

@@ -1307,6 +1314,7 @@
accelerated_ns=optimizer_config.accelerated_ns if optimizer_config.accelerated_ns is not None else False,
orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False,
approx_mars=optimizer_config.approx_mars if optimizer_config.approx_mars is not None else False,
compiled_optimizer=optimizer_config.compiled_optimizer if optimizer_config.compiled_optimizer is not None else False,
**adam_kwargs
)

1 change: 1 addition & 0 deletions modules/util/enum/TimestepDistribution.py
@@ -8,6 +8,7 @@ class TimestepDistribution(Enum):
HEAVY_TAIL = 'HEAVY_TAIL'
COS_MAP = 'COS_MAP'
INVERTED_PARABOLA = 'INVERTED_PARABOLA'
BETA = 'BETA'

def __str__(self):
return self.value
8 changes: 8 additions & 0 deletions modules/util/optimizer_util.py
@@ -457,6 +457,7 @@ def init_model_parameters(
"use_bias_correction": True,
"nnmf_factor": False,
"stochastic_rounding": True,
"compiled_optimizer": False,
"fused_back_pass": False,
"use_atan2": False,
"cautious_mask": False,
@@ -476,6 +477,7 @@
"weight_decay": 0.0,
"nnmf_factor": False,
"stochastic_rounding": True,
"compiled_optimizer": False,
"fused_back_pass": False,
"use_atan2": False,
"cautious_mask": False,
@@ -498,6 +500,7 @@
"weight_decay": 0.0,
"nnmf_factor": False,
"stochastic_rounding": True,
"compiled_optimizer": False,
"fused_back_pass": False,
"d0": 1e-6,
"d_coef": 1.0,
@@ -529,6 +532,7 @@
"use_bias_correction": True,
"nnmf_factor": False,
"stochastic_rounding": True,
"compiled_optimizer": False,
"fused_back_pass": False,
"orthogonal_gradient": False,
"kourkoutas_beta": False,
@@ -542,6 +546,7 @@
"clip_threshold": None,
"nnmf_factor": False,
"stochastic_rounding": True,
"compiled_optimizer": False,
"fused_back_pass": False,
"cautious_mask": False,
"orthogonal_gradient": False,
@@ -557,6 +562,7 @@
"clip_threshold": None,
"nnmf_factor": False,
"stochastic_rounding": True,
"compiled_optimizer": False,
"fused_back_pass": False,
"d0": 1e-6,
"d_coef": 1.0,
@@ -580,6 +586,7 @@
"rms_rescaling": True,
"nnmf_factor": False,
"stochastic_rounding": True,
"compiled_optimizer": False,
"fused_back_pass": False,
"MuonWithAuxAdam": True,
"muon_hidden_layers": None,
@@ -610,6 +617,7 @@
"rms_rescaling": True,
"nnmf_factor": False,
"stochastic_rounding": True,
"compiled_optimizer": False,
"fused_back_pass": False,
"MuonWithAuxAdam": True,
"muon_hidden_layers": None,
2 changes: 1 addition & 1 deletion requirements-global.txt
@@ -41,7 +41,7 @@ prodigyopt==1.1.2 # prodigy optimizer
schedulefree==1.4.1 # schedule-free optimizers
pytorch_optimizer==3.6.0 # pytorch optimizers
prodigy-plus-schedule-free==2.0.1 # Prodigy plus optimizer
adv_optm==1.4.1 # advanced optimizers
adv_optm==2.1.dev2 # advanced optimizers
-e git+https://github.com/KellerJordan/Muon.git@f90a42b#egg=muon-optimizer

# Profiling