Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions modules/modelSetup/BaseChromaSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,11 @@ def predict(
text_encoder_dropout_probability=config.text_encoder.dropout_probability,
)

if config.cep_enabled:
text_encoder_output = self._apply_conditional_embedding_perturbation(
text_encoder_output, config.cep_gamma, generator
)

latent_image = batch['latent_image']
scaled_latent_image = (latent_image - vae_shift_factor) * vae_scaling_factor

Expand Down
8 changes: 8 additions & 0 deletions modules/modelSetup/BaseFluxSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ def predict(
apply_attention_mask=config.transformer.attention_mask,
)

if config.cep_enabled:
text_encoder_output = self._apply_conditional_embedding_perturbation(
text_encoder_output, config.cep_gamma, generator
)
pooled_text_encoder_output = self._apply_conditional_embedding_perturbation(
pooled_text_encoder_output, config.cep_gamma, generator
)

latent_image = batch['latent_image']
scaled_latent_image = (latent_image - vae_shift_factor) * vae_scaling_factor

Expand Down
11 changes: 11 additions & 0 deletions modules/modelSetup/BaseHiDreamSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,17 @@ def predict(
apply_attention_mask=config.transformer.attention_mask,
))

if config.cep_enabled:
text_encoder_3_output = self._apply_conditional_embedding_perturbation(
text_encoder_3_output, config.cep_gamma, generator
)
text_encoder_4_output = self._apply_conditional_embedding_perturbation(
text_encoder_4_output, config.cep_gamma, generator
)
pooled_text_encoder_output = self._apply_conditional_embedding_perturbation(
pooled_text_encoder_output, config.cep_gamma, generator
)

latent_image = batch['latent_image']
scaled_latent_image = (latent_image - vae_shift_factor) * vae_scaling_factor

Expand Down
8 changes: 8 additions & 0 deletions modules/modelSetup/BaseHunyuanVideoSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,14 @@ def predict(
text_encoder_2_dropout_probability=config.text_encoder_2.dropout_probability,
)

if config.cep_enabled:
text_encoder_output = self._apply_conditional_embedding_perturbation(
text_encoder_output, config.cep_gamma, generator
)
pooled_text_encoder_output = self._apply_conditional_embedding_perturbation(
pooled_text_encoder_output, config.cep_gamma, generator
)

latent_image = batch['latent_image']
scaled_latent_image = latent_image * vae_scaling_factor

Expand Down
5 changes: 5 additions & 0 deletions modules/modelSetup/BasePixArtAlphaSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,11 @@ def predict(
text_encoder_dropout_probability=config.text_encoder.dropout_probability,
)

if config.cep_enabled:
text_encoder_output = self._apply_conditional_embedding_perturbation(
text_encoder_output, config.cep_gamma, generator
)

latent_image = batch['latent_image']
scaled_latent_image = latent_image * vae_scaling_factor

Expand Down
5 changes: 5 additions & 0 deletions modules/modelSetup/BaseQwenSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def predict(
text_encoder_dropout_probability=config.text_encoder.dropout_probability,
)

if config.cep_enabled:
text_encoder_output = self._apply_conditional_embedding_perturbation(
text_encoder_output, config.cep_gamma, generator
)

latent_image = batch['latent_image']
scaled_latent_image = model.scale_latents(latent_image)
latent_noise = self._create_noise(scaled_latent_image, config, generator)
Expand Down
5 changes: 5 additions & 0 deletions modules/modelSetup/BaseSanaSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ def predict(
text_encoder_dropout_probability=config.text_encoder.dropout_probability,
)

if config.cep_enabled:
text_encoder_output = self._apply_conditional_embedding_perturbation(
text_encoder_output, config.cep_gamma, generator
)

latent_image = batch['latent_image']
scaled_latent_image = latent_image * vae_scaling_factor

Expand Down
8 changes: 8 additions & 0 deletions modules/modelSetup/BaseStableDiffusion3Setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,14 @@ def predict(
apply_attention_mask=config.transformer.attention_mask,
))

if config.cep_enabled:
text_encoder_output = self._apply_conditional_embedding_perturbation(
text_encoder_output, config.cep_gamma, generator
)
pooled_text_encoder_output = self._apply_conditional_embedding_perturbation(
pooled_text_encoder_output, config.cep_gamma, generator
)

latent_image = batch['latent_image']
scaled_latent_image = (latent_image - vae_shift_factor) * vae_scaling_factor

Expand Down
5 changes: 5 additions & 0 deletions modules/modelSetup/BaseStableDiffusionSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ def predict(
text_encoder_dropout_probability=config.text_encoder.dropout_probability,
)

if config.cep_enabled:
text_encoder_output = self._apply_conditional_embedding_perturbation(
text_encoder_output, config.cep_gamma, generator
)

latent_image = batch['latent_image']
scaled_latent_image = latent_image * vae_scaling_factor

Expand Down
8 changes: 8 additions & 0 deletions modules/modelSetup/BaseStableDiffusionXLSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ def predict(
text_encoder_2_dropout_probability=config.text_encoder_2.dropout_probability,
))

if config.cep_enabled:
text_encoder_output = self._apply_conditional_embedding_perturbation(
text_encoder_output, config.cep_gamma, generator
)
pooled_text_encoder_2_output = self._apply_conditional_embedding_perturbation(
pooled_text_encoder_2_output, config.cep_gamma, generator
)

latent_image = batch['latent_image']
scaled_latent_image = latent_image * vae_scaling_factor

Expand Down
8 changes: 8 additions & 0 deletions modules/modelSetup/BaseWuerstchenSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,14 @@ def predict(
text_encoder_dropout_probability=config.text_encoder.dropout_probability,
)

if config.cep_enabled:
text_embedding = self._apply_conditional_embedding_perturbation(
text_embedding, config.cep_gamma, generator
)
pooled_text_text_embedding = self._apply_conditional_embedding_perturbation(
pooled_text_text_embedding, config.cep_gamma, generator
)

latent_input = scaled_noisy_latent_image

if model.model_type.is_wuerstchen_v2():
Expand Down
6 changes: 6 additions & 0 deletions modules/modelSetup/BaseZImageSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ def predict(
text_encoder_output=batch.get('text_encoder_hidden_state'),
text_encoder_dropout_probability=config.text_encoder.dropout_probability,
)

if config.cep_enabled:
text_encoder_output = self._apply_conditional_embedding_perturbation(
text_encoder_output, config.cep_gamma, generator
)

scaled_latent_image = model.scale_latents(batch['latent_image'])

latent_noise = self._create_noise(scaled_latent_image, config, generator)
Expand Down
30 changes: 30 additions & 0 deletions modules/modelSetup/mixin/ModelSetupNoiseMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,36 @@ def _create_noise(

return noise

def _apply_conditional_embedding_perturbation(
self,
embedding: Tensor,
gamma: float,
generator: Generator
) -> Tensor:
"""
Applies Conditional Embedding Perturbation (CEP) as per Equation (8).
Paper: "Slight Corruption in Pre-training Data Makes Better Diffusion Models"

delta ~ U(-sqrt(gamma/d), sqrt(gamma/d)) or N(0, sqrt(gamma/d))
"""
# d denotes the dimension of c_theta(y)
d = embedding.shape[-1]

# gamma controls perturbation magnitude (Paper uses gamma=1.0 as default baseline)
# Calculate scaling factor: sqrt(gamma / d)
scale = math.sqrt(gamma / d)

# CEP-U (Uniform) scheme
noise = torch.rand(
embedding.shape,
generator=generator,
device=embedding.device,
dtype=embedding.dtype
)
perturbation = (noise * 2.0 - 1.0) * scale

return embedding + perturbation

def _get_timestep_discrete(
self,
num_train_timesteps: int,
Expand Down
10 changes: 9 additions & 1 deletion modules/ui/TrainingTab.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,15 @@ def __create_noise_frame(self, master, row, supports_generalized_offset_noise: b
tooltip="Dynamically shift the timestep distribution based on resolution.")
components.switch(frame, 9, 1, self.ui_state, "dynamic_timestep_shifting")


# Conditional Embedding Perturbation (CEP)
cep_label = components.label(frame, 10, 0, "Conditional Embedding Perturbation (CEP)",
tooltip="Inject a slight noise into the TEs outputs to enhance the quality, diversity, and fidelity of the generated images.")
cep_label.configure(wraplength=130, justify="left")
components.switch(frame, 10, 1, self.ui_state, "cep_enabled")

components.label(frame, 11, 0, "CEP Gamma",
tooltip="Gamma controls perturbation noise magnitude, paper's default is 1. Only has an effect if CEP is enabled")
components.entry(frame, 11, 1, self.ui_state, "cep_gamma")

def __create_masked_frame(self, master, row):
frame = ctk.CTkFrame(master=master, corner_radius=5)
Expand Down
4 changes: 4 additions & 0 deletions modules/util/config/TrainConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,8 @@ class TrainConfig(BaseConfig):
timestep_distribution: TimestepDistribution
min_noising_strength: float
max_noising_strength: float
cep_enabled: bool
cep_gamma: float

noising_weight: float
noising_bias: float
Expand Down Expand Up @@ -1032,6 +1034,8 @@ def default_values() -> 'TrainConfig':
data.append(("noising_bias", 0.0, float, False))
data.append(("timestep_shift", 1.0, float, False))
data.append(("dynamic_timestep_shifting", False, bool, False))
data.append(("cep_enabled", False, bool, False))
data.append(("cep_gamma", 1.0, float, False))


# unet
Expand Down