Training Performance Optimization for flux_controlnet #12097

Open · wants to merge 2 commits into main
46 changes: 32 additions & 14 deletions nemo/collections/diffusion/models/dit/dit_layer_spec.py
@@ -20,6 +20,7 @@
 import torch.nn as nn
 from einops import rearrange
 from megatron.core.jit import jit_fuser
+from megatron.core.tensor_parallel.layers import ColumnParallelLinear
 from megatron.core.transformer.attention import (
     CrossAttention,
     CrossAttentionSubmodules,
@@ -97,7 +98,15 @@ def __init__(
         self.ln = norm(config.hidden_size, elementwise_affine=False, eps=self.config.layernorm_epsilon)
         self.n_adaln_chunks = n_adaln_chunks
         self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(), nn.Linear(config.hidden_size, self.n_adaln_chunks * config.hidden_size, bias=modulation_bias)
+            nn.SiLU(),
+            ColumnParallelLinear(
+                config.hidden_size,
+                self.n_adaln_chunks * config.hidden_size,
+                config=config,
+                init_method=nn.init.normal_,
+                bias=modulation_bias,
+                gather_output=True,
+            ),
         )
         self.use_second_norm = use_second_norm
         if self.use_second_norm:
@@ -106,18 +115,21 @@

         setattr(self.adaLN_modulation[-1].weight, "sequence_parallel", config.sequence_parallel)

+    @jit_fuser
     def forward(self, timestep_emb):
-        return self.adaLN_modulation(timestep_emb).chunk(self.n_adaln_chunks, dim=-1)
+        output, bias = self.adaLN_modulation(timestep_emb)
+        output = output + bias if bias is not None else output
+        return output.chunk(self.n_adaln_chunks, dim=-1)

-    # @jit_fuser
+    @jit_fuser
     def modulate(self, x, shift, scale):
         return x * (1 + scale) + shift

-    # @jit_fuser
+    @jit_fuser
     def scale_add(self, residual, x, gate):
         return residual + gate * x

-    # @jit_fuser
+    @jit_fuser
     def modulated_layernorm(self, x, shift, scale, layernorm_idx=0):
         if self.use_second_norm and layernorm_idx == 1:
             layernorm = self.ln2
@@ -542,15 +554,17 @@ def __init__(
         hidden_size = config.hidden_size
         super().__init__(config=config, submodules=submodules, layer_number=layer_number)

-        self.adaln = AdaLN(config, modulation_bias=True, n_adaln_chunks=6, use_second_norm=True)
+        self.adaln = AdaLN(config, norm=TENorm, modulation_bias=True, n_adaln_chunks=6, use_second_norm=True)

         self.context_pre_only = context_pre_only
         context_norm_type = "ada_norm_continuous" if context_pre_only else "ada_norm_zero"

         if context_norm_type == "ada_norm_continuous":
             self.adaln_context = AdaLNContinuous(config, hidden_size, modulation_bias=True, norm_type="layer_norm")
         elif context_norm_type == "ada_norm_zero":
-            self.adaln_context = AdaLN(config, modulation_bias=True, n_adaln_chunks=6, use_second_norm=True)
+            self.adaln_context = AdaLN(
+                config, norm=TENorm, modulation_bias=True, n_adaln_chunks=6, use_second_norm=True
+            )
         else:
             raise ValueError(
                 f"Unknown context_norm_type: {context_norm_type}, "
@@ -641,7 +655,11 @@ def __init__(
     ):
         super().__init__(config=config, submodules=submodules, layer_number=layer_number)
         self.adaln = AdaLN(
-            config=config, n_adaln_chunks=n_adaln_chunks, modulation_bias=modulation_bias, use_second_norm=False
+            config=config,
+            norm=TENorm,
+            n_adaln_chunks=n_adaln_chunks,
+            modulation_bias=modulation_bias,
+            use_second_norm=False,
         )

     def forward(
@@ -835,8 +853,8 @@ def get_flux_single_transformer_engine_spec() -> ModuleSpec:
             submodules=SelfAttentionSubmodules(
                 linear_qkv=TEColumnParallelLinear,
                 core_attention=TEDotProductAttention,
-                q_layernorm=RMSNorm,
-                k_layernorm=RMSNorm,
+                q_layernorm=TENorm,
+                k_layernorm=TENorm,
                 linear_proj=TERowParallelLinear,
             ),
         ),
@@ -859,10 +877,10 @@ def get_flux_double_transformer_engine_spec() -> ModuleSpec:
             module=JointSelfAttention,
             params={"attn_mask_type": AttnMaskType.no_mask},
             submodules=JointSelfAttentionSubmodules(
-                q_layernorm=RMSNorm,
-                k_layernorm=RMSNorm,
-                added_q_layernorm=RMSNorm,
-                added_k_layernorm=RMSNorm,
+                q_layernorm=TENorm,
+                k_layernorm=TENorm,
+                added_q_layernorm=TENorm,
+                added_k_layernorm=TENorm,
                 linear_qkv=TEColumnParallelLinear,
                 added_linear_qkv=TEColumnParallelLinear,
                 core_attention=TEDotProductAttention,
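For context on the AdaLN change above: unlike nn.Linear, Megatron-Core's ColumnParallelLinear returns a (output, bias) tuple from forward, which is why AdaLN.forward now unpacks two values and only adds the bias when one is returned; gather_output=True keeps the full hidden dimension visible to the subsequent chunk(). The sketch below mimics that calling convention with a hypothetical ToyParallelLinear stand-in so it runs without initializing tensor parallelism.

    # Minimal sketch of the (output, bias) calling convention used above.
    # ToyParallelLinear is a hypothetical stand-in for ColumnParallelLinear,
    # used only so the example runs without setting up model parallelism.
    import torch
    import torch.nn as nn

    class ToyParallelLinear(nn.Linear):
        def forward(self, x):
            # Megatron parallel linears return the bias separately (it may be None).
            return super().forward(x), None

    adaLN_modulation = nn.Sequential(nn.SiLU(), ToyParallelLinear(8, 6 * 8))
    output, bias = adaLN_modulation(torch.randn(2, 8))
    output = output + bias if bias is not None else output
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = output.chunk(6, dim=-1)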
7 changes: 5 additions & 2 deletions nemo/collections/diffusion/models/flux/layers.py
@@ -25,10 +25,13 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
     """
     assert dim % 2 == 0, "The dimension must be even."

+    # RoPE should be batch size independent (lifuz)
+    seq = torch.arange(pos.size()[1], device=pos.device, dtype=torch.float64)
+
     scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
     omega = 1.0 / (theta**scale)

-    out = torch.einsum("...n,d->...nd", pos, omega)
+    out = torch.einsum("...n,d->...nd", seq, omega).unsqueeze(1)

     return out.float()

@@ -52,7 +55,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
             [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
             dim=-1,
         )
-        emb = emb.unsqueeze(1).permute(2, 0, 1, 3)
+        emb = emb.unsqueeze(1)  # .permute(2, 0, 1, 3)
         return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1)


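The rope() change above computes the rotary angles from a plain sequence index (torch.arange over the sequence length) instead of the per-sample position ids, and broadcasts the resulting [seq, 1, dim/2] table over the batch, so the embedding no longer grows with batch size. Below is a small shape sketch, under the assumption made explicit by the new comment that every sample uses positions 0..seq_len-1; rope_sketch is a simplified, hypothetical reimplementation for illustration, not the repo's function.

    # Sketch of the batch-size-independent RoPE table described above.
    import torch

    def rope_sketch(seq_len: int, dim: int, theta: int = 10_000) -> torch.Tensor:
        assert dim % 2 == 0, "The dimension must be even."
        seq = torch.arange(seq_len, dtype=torch.float64)
        scale = torch.arange(0, dim, 2, dtype=torch.float64) / dim
        omega = 1.0 / (theta ** scale)
        # [seq_len, dim // 2] -> [seq_len, 1, dim // 2]; the singleton axis broadcasts over batch.
        return torch.einsum("n,d->nd", seq, omega).unsqueeze(1).float()

    print(rope_sketch(seq_len=16, dim=8).shape)  # torch.Size([16, 1, 4])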
4 changes: 2 additions & 2 deletions nemo/collections/diffusion/models/flux/model.py
@@ -79,7 +79,7 @@ class FluxConfig(TransformerConfig, io.IOMixin):
     patch_size: int = 1
     guidance_embed: bool = False
     vec_in_dim: int = 768
-    rotary_interleaved: bool = True
+    rotary_interleaved: bool = False
     layernorm_epsilon: float = 1e-06
     hidden_dropout: float = 0
     attention_dropout: float = 0
@@ -729,7 +729,7 @@ def config(self) -> FluxConfig:
             patch_size=source_config.patch_size,
             guidance_embed=source_config.guidance_embeds,
             vec_in_dim=source_config.pooled_projection_dim,
-            rotary_interleaved=True,
+            rotary_interleaved=False,
             layernorm_epsilon=1e-06,
             hidden_dropout=0,
             attention_dropout=0,
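rotary_interleaved switches between the two common RoPE channel layouts: interleaved rotation pairs adjacent channels (x0, x1), (x2, x3), ..., while the non-interleaved ("half") layout pairs channel i with channel i + d/2. Flipping the default to False presumably keeps the attention kernels consistent with the frequency layout produced in layers.py after the change above. The snippet below is a hedged illustration of the two conventions only; the exact kernel behaviour is defined by Megatron-Core / Transformer Engine.

    # Illustrative only: the two rotate operations the interleaved and
    # non-interleaved RoPE layouts are usually built on.
    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Non-interleaved layout: channel i is paired with channel i + d/2.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rotate_interleaved(x: torch.Tensor) -> torch.Tensor:
        # Interleaved layout: channels are paired as (0, 1), (2, 3), ...
        x1, x2 = x[..., 0::2], x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).reshape_as(x)

    x = torch.randn(1, 4, 8)
    print(rotate_half(x).shape, rotate_interleaved(x).shape)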
33 changes: 28 additions & 5 deletions nemo/collections/diffusion/models/flux_controlnet/model.py
@@ -18,6 +18,7 @@
 import torch
 import torch.nn as nn
 from megatron.core.models.common.vision_module.vision_module import VisionModule
+from megatron.core.tensor_parallel.layers import ColumnParallelLinear
 from megatron.core.transformer.transformer_config import TransformerConfig
 from torch.nn import functional as F

@@ -88,7 +89,7 @@ class FluxControlNetConfig(TransformerConfig, io.IOMixin):
     num_mode: int = None
     model_channels: int = 256
     conditioning_embedding_channels: int = None
-    rotary_interleaved: bool = True
+    rotary_interleaved: bool = False
     layernorm_epsilon: float = 1e-06
     hidden_dropout: float = 0
     attention_dropout: float = 0
@@ -161,11 +162,31 @@ def __init__(self, config: FluxControlNetConfig):
         # ContolNet Blocks
         self.controlnet_double_blocks = nn.ModuleList()
         for _ in range(config.num_joint_layers):
-            self.controlnet_double_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
+            self.controlnet_double_blocks.append(
+                zero_module(
+                    ColumnParallelLinear(
+                        self.hidden_size,
+                        self.hidden_size,
+                        config=config,
+                        init_method=nn.init.normal_,
+                        gather_output=True,
+                    )
+                )
+            )

         self.controlnet_single_blocks = nn.ModuleList()
         for _ in range(config.num_single_layers):
-            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
+            self.controlnet_single_blocks.append(
+                zero_module(
+                    ColumnParallelLinear(
+                        self.hidden_size,
+                        self.hidden_size,
+                        config=config,
+                        init_method=nn.init.normal_,
+                        gather_output=True,
+                    )
+                )
+            )

         if config.conditioning_embedding_channels is not None:
             self.input_hint_block = ControlNetConditioningEmbedding(
@@ -267,12 +288,14 @@ def forward(

         controlnet_double_block_samples = ()
         for double_block_sample, control_block in zip(double_block_samples, self.controlnet_double_blocks):
-            double_block_sample = control_block(double_block_sample)
+            double_block_sample, bias = control_block(double_block_sample)
+            double_block_sample = double_block_sample + bias if bias is not None else double_block_sample
             controlnet_double_block_samples += (double_block_sample,)

         controlnet_single_block_samples = ()
         for single_block_sample, control_block in zip(single_block_samples, self.controlnet_single_blocks):
-            single_block_sample = control_block(single_block_sample)
+            single_block_sample, bias = control_block(single_block_sample)
+            single_block_sample = single_block_sample + bias if bias is not None else single_block_sample
             controlnet_single_block_samples += (single_block_sample,)

         controlnet_double_block_samples = [sample * conditioning_scale for sample in controlnet_double_block_samples]
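The ControlNet projection blocks keep the usual zero-initialization trick (zero_module) so the ControlNet branch contributes nothing at the start of training, but now use tensor-parallel linears; as in the AdaLN change, each call returns (output, bias) and the bias is only added when one is actually returned. Below is a hedged sketch of the zero-init pattern; this zero_module is a stand-in for the repo's helper and may differ in detail.

    # Sketch of the zero-initialized projection pattern used for the ControlNet blocks.
    import torch
    import torch.nn as nn

    def zero_module(module: nn.Module) -> nn.Module:
        # Zero every parameter so the block is a no-op until training moves it away from zero.
        for p in module.parameters():
            nn.init.zeros_(p)
        return module

    proj = zero_module(nn.Linear(16, 16))
    out = proj(torch.randn(2, 16))
    assert torch.allclose(out, torch.zeros_like(out))  # the branch starts silent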
6 changes: 3 additions & 3 deletions scripts/flux/flux_controlnet_training.py
@@ -204,12 +204,12 @@ def full_model_tp2_dp4_mock() -> run.Partial:
     recipe.model.flux_params.clip_params = None
     recipe.model.flux_params.vae_config = None
     recipe.model.flux_params.device = 'cuda'
-    recipe.trainer.strategy.tensor_model_parallel_size = 2
+    recipe.trainer.strategy.tensor_model_parallel_size = 1
     recipe.trainer.devices = 8
     recipe.data.global_batch_size = 8
     recipe.trainer.callbacks.append(run.Config(NsysCallback, start_step=10, end_step=11, gen_shape=True))
-    recipe.model.flux_controlnet_config.num_single_layers = 10
-    recipe.model.flux_controlnet_config.num_joint_layers = 4
+    recipe.model.flux_controlnet_config.num_single_layers = 38
+    recipe.model.flux_controlnet_config.num_joint_layers = 19
     return recipe


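The benchmark recipe now runs without tensor parallelism (TP=1 on 8 devices, so pure data parallelism) and scales the mock ControlNet up to 19 joint and 38 single layers, which appears to match the 19 double-stream and 38 single-stream blocks of the full FLUX transformer; note the helper keeps its full_model_tp2_dp4_mock name even though it now sets TP=1. A hypothetical smoke check of the updated fields, assuming the script directory is importable (e.g. the NeMo repository root is on PYTHONPATH):

    # Hypothetical check of the updated recipe fields.
    from scripts.flux.flux_controlnet_training import full_model_tp2_dp4_mock

    recipe = full_model_tp2_dp4_mock()
    assert recipe.trainer.strategy.tensor_model_parallel_size == 1
    assert recipe.data.global_batch_size == 8
    assert recipe.model.flux_controlnet_config.num_joint_layers == 19
    assert recipe.model.flux_controlnet_config.num_single_layers == 38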