25 changes: 12 additions & 13 deletions nemo/collections/diffusion/models/dit/dit_layer_spec.py
@@ -122,27 +122,26 @@ def __init__(
else:
self.ln = norm(config.hidden_size, elementwise_affine=False, eps=self.config.layernorm_epsilon)
self.n_adaln_chunks = n_adaln_chunks
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
ColumnParallelLinear(
config.hidden_size,
self.n_adaln_chunks * config.hidden_size,
config=config,
init_method=nn.init.normal_,
bias=modulation_bias,
gather_output=True,
),
self.activation = nn.SiLU()
self.linear = ColumnParallelLinear(
config.hidden_size,
self.n_adaln_chunks * config.hidden_size,
config=config,
init_method=nn.init.normal_,
bias=modulation_bias,
gather_output=True,
)
self.use_second_norm = use_second_norm
if self.use_second_norm:
self.ln2 = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6)
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.linear.weight, 0)

setattr(self.adaLN_modulation[-1].weight, "sequence_parallel", config.sequence_parallel)
setattr(self.linear.weight, "sequence_parallel", config.sequence_parallel)

@jit_fuser
def forward(self, timestep_emb):
output, bias = self.adaLN_modulation(timestep_emb)
timestep_emb = self.activation(timestep_emb)
output, bias = self.linear(timestep_emb)
output = output + bias if bias else output
return output.chunk(self.n_adaln_chunks, dim=-1)
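Splitting the former nn.Sequential into named activation and linear submodules lets code refer to self.linear directly instead of indexing adaLN_modulation[-1], which is what the updated checkpoint key mappings below rely on. For reference, here is a minimal single-GPU sketch of the refactored modulation path, with nn.Linear standing in for Megatron's ColumnParallelLinear (so there is no (output, bias) tuple to unpack); the sizes and chunk count are illustrative assumptions.

    import torch
    from torch import nn

    hidden_size, n_adaln_chunks = 512, 6           # illustrative sizes

    activation = nn.SiLU()
    linear = nn.Linear(hidden_size, n_adaln_chunks * hidden_size)
    nn.init.constant_(linear.weight, 0)            # zero-init, as in the layer spec above

    timestep_emb = torch.randn(4, hidden_size)     # (batch, hidden)
    out = linear(activation(timestep_emb))         # (batch, n_chunks * hidden)
    chunks = out.chunk(n_adaln_chunks, dim=-1)     # shift/scale/gate tensors for AdaLN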

12 changes: 6 additions & 6 deletions nemo/collections/diffusion/models/flux/model.py
@@ -816,10 +816,10 @@ def config(self) -> FluxConfig:
def convert_state(self, source, target):
# pylint: disable=C0301
mapping = {
'transformer_blocks.*.norm1.linear.weight': 'double_blocks.*.adaln.adaLN_modulation.1.weight',
'transformer_blocks.*.norm1.linear.bias': 'double_blocks.*.adaln.adaLN_modulation.1.bias',
'transformer_blocks.*.norm1_context.linear.weight': 'double_blocks.*.adaln_context.adaLN_modulation.1.weight',
'transformer_blocks.*.norm1_context.linear.bias': 'double_blocks.*.adaln_context.adaLN_modulation.1.bias',
'transformer_blocks.*.norm1.linear.weight': 'double_blocks.*.adaln.linear.weight',
'transformer_blocks.*.norm1.linear.bias': 'double_blocks.*.adaln.linear.bias',
'transformer_blocks.*.norm1_context.linear.weight': 'double_blocks.*.adaln_context.linear.weight',
'transformer_blocks.*.norm1_context.linear.bias': 'double_blocks.*.adaln_context.linear.bias',
'transformer_blocks.*.attn.norm_q.weight': 'double_blocks.*.self_attention.q_layernorm.weight',
'transformer_blocks.*.attn.norm_k.weight': 'double_blocks.*.self_attention.k_layernorm.weight',
'transformer_blocks.*.attn.norm_added_q.weight': 'double_blocks.*.self_attention.added_q_layernorm.weight',
@@ -836,8 +836,8 @@ def convert_state(self, source, target):
'transformer_blocks.*.ff_context.net.0.proj.bias': 'double_blocks.*.context_mlp.linear_fc1.bias',
'transformer_blocks.*.ff_context.net.2.weight': 'double_blocks.*.context_mlp.linear_fc2.weight',
'transformer_blocks.*.ff_context.net.2.bias': 'double_blocks.*.context_mlp.linear_fc2.bias',
'single_transformer_blocks.*.norm.linear.weight': 'single_blocks.*.adaln.adaLN_modulation.1.weight',
'single_transformer_blocks.*.norm.linear.bias': 'single_blocks.*.adaln.adaLN_modulation.1.bias',
'single_transformer_blocks.*.norm.linear.weight': 'single_blocks.*.adaln.linear.weight',
'single_transformer_blocks.*.norm.linear.bias': 'single_blocks.*.adaln.linear.bias',
'single_transformer_blocks.*.proj_mlp.weight': 'single_blocks.*.mlp.linear_fc1.weight',
'single_transformer_blocks.*.proj_mlp.bias': 'single_blocks.*.mlp.linear_fc1.bias',
'single_transformer_blocks.*.attn.norm_q.weight': 'single_blocks.*.self_attention.q_layernorm.weight',
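The mapping above uses '*' wildcards that NeMo's state-conversion machinery expands per block. The snippet below is only a hypothetical illustration of how one renamed entry resolves for a concrete block index under that assumption, not the actual converter code.

    import re

    src_pat = 'transformer_blocks.*.norm1.linear.weight'
    dst_pat = 'double_blocks.*.adaln.linear.weight'

    def translate(key, src, dst):
        # Build a regex from the source pattern, capturing the block index at '*'.
        regex = '^' + re.escape(src).replace(r'\*', r'(\d+)') + '$'
        match = re.match(regex, key)
        return dst.replace('*', match.group(1)) if match else None

    print(translate('transformer_blocks.3.norm1.linear.weight', src_pat, dst_pat))
    # double_blocks.3.adaln.linear.weight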
41 changes: 37 additions & 4 deletions nemo/collections/diffusion/models/flux/pipeline.py
@@ -163,7 +163,9 @@ def __init__(
self.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=scheduler_steps)
self.params = params

def load_from_pretrained(self, ckpt_path, do_convert_from_hf=True, save_converted_model_to=None):
def load_from_pretrained(
self, ckpt_path, do_convert_from_hf=True, save_converted_model_to=None, load_dist_ckpt=False
):
"""
Loads the model's weights from a checkpoint. If an HF checkpoint is provided, it is converted to NeMo
format and saved to a local folder.
@@ -175,11 +177,23 @@ def load_from_pretrained(self, ckpt_path, do_convert_from_hf=True, save_converte
Whether to convert the checkpoint from Hugging Face format before loading. Default is True.
save_converted_model_to (str, optional):
Path to save the converted checkpoint if `do_convert_from_hf` is True. Default is None.
load_dist_ckpt (bool, optional):
Whether to load the checkpoint from dist.ckpt format (NeMo2 checkpoint). Default is False.

Logs:
The function logs information about missing or unexpected keys during checkpoint loading.
"""
if do_convert_from_hf:
assert not (do_convert_from_hf and load_dist_ckpt), 'do_convert_from_hf and load_dist_ckpt cannot both be true'

if load_dist_ckpt:
from megatron.core import dist_checkpointing

sharded_state_dict = dict(state_dict=self.transformer.sharded_state_dict(prefix="module."))
loaded_state_dict = dist_checkpointing.load(
sharded_state_dict=sharded_state_dict, checkpoint_dir=ckpt_path
)
ckpt = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()}
elif do_convert_from_hf:
ckpt = flux_transformer_converter(ckpt_path, self.transformer.config)
if save_converted_model_to is not None:
save_path = os.path.join(save_converted_model_to, 'nemo_flux_transformer.safetensors')
@@ -196,6 +210,11 @@ def load_from_pretrained(self, ckpt_path, do_convert_from_hf=True, save_converte
f"please check the ckpt provided or the image quality may be compromised.\n {missing}"
)
logging.info(f"Found unexepected keys: \n {unexpected}")
if len(unexpected) > 0:
logging.info(
f"The following keys are unexpected during checkpoint loading, "
f"please check the ckpt provided or the image quality may be compromised.\n {unexpected}"
)
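
A hedged usage sketch of the new option: the pipeline variable name below is hypothetical and the checkpoint path is a placeholder, but the keyword arguments match the signature added above.

    # `pipe` stands for whatever pipeline object this module constructs.
    pipe.load_from_pretrained(
        "/path/to/nemo2_dist_ckpt/weights",  # placeholder: directory written by Megatron dist checkpointing
        do_convert_from_hf=False,            # must be False when load_dist_ckpt is True (see the assert above)
        load_dist_ckpt=True,
    )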

def encoder_prompt(
self,
@@ -685,12 +704,26 @@ def __init__(
self.flux_controlnet = FluxControlNet(contorlnet_config) if flux_controlnet is None else flux_controlnet

def load_from_pretrained(
self, flux_ckpt_path, controlnet_ckpt_path, do_convert_from_hf=True, save_converted_model_to=None
self,
flux_ckpt_path,
controlnet_ckpt_path,
do_convert_from_hf=True,
save_converted_model_to=None,
load_dist_ckpt=False,
):
'''
Converts both flux base model and flux controlnet ckpt into NeMo format.
'''
if do_convert_from_hf:
assert not (do_convert_from_hf and load_dist_ckpt), 'do_convert_from_hf and load_dist_ckpt cannot both be true'
if load_dist_ckpt:
from megatron.core import dist_checkpointing

sharded_state_dict = dict(state_dict=self.transformer.sharded_state_dict(prefix="module."))
loaded_state_dict = dist_checkpointing.load(
sharded_state_dict=sharded_state_dict, checkpoint_dir=flux_ckpt_path
)
flux_ckpt = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()}
elif do_convert_from_hf:
flux_ckpt = flux_transformer_converter(flux_ckpt_path, self.transformer.config)
flux_controlnet_ckpt = flux_transformer_converter(controlnet_ckpt_path, self.flux_controlnet.config)

15 changes: 6 additions & 9 deletions nemo/collections/diffusion/utils/flux_ckpt_converter.py
@@ -80,10 +80,10 @@ def _import_qkv(transformer_config, q, k, v):

flux_key_mapping = {
'double_blocks': {
'norm1.linear.weight': 'adaln.adaLN_modulation.1.weight',
'norm1.linear.bias': 'adaln.adaLN_modulation.1.bias',
'norm1_context.linear.weight': 'adaln_context.adaLN_modulation.1.weight',
'norm1_context.linear.bias': 'adaln_context.adaLN_modulation.1.bias',
'norm1.linear.weight': 'adaln.linear.weight',
'norm1.linear.bias': 'adaln.linear.bias',
'norm1_context.linear.weight': 'adaln_context.linear.weight',
'norm1_context.linear.bias': 'adaln_context.linear.bias',
'attn.norm_q.weight': 'self_attention.q_layernorm.weight',
'attn.norm_k.weight': 'self_attention.k_layernorm.weight',
'attn.norm_added_q.weight': 'self_attention.added_q_layernorm.weight',
@@ -102,8 +102,8 @@ def _import_qkv(transformer_config, q, k, v):
'ff_context.net.2.bias': 'context_mlp.linear_fc2.bias',
},
'single_blocks': {
'norm.linear.weight': 'adaln.adaLN_modulation.1.weight',
'norm.linear.bias': 'adaln.adaLN_modulation.1.bias',
'norm.linear.weight': 'adaln.linear.weight',
'norm.linear.bias': 'adaln.linear.bias',
'proj_mlp.weight': 'mlp.linear_fc1.weight',
'proj_mlp.bias': 'mlp.linear_fc1.bias',
# 'proj_out.weight': 'proj_out.weight',
@@ -219,8 +219,5 @@ def flux_transformer_converter(ckpt_path=None, transformer_config=None):
new_state_dict[f'single_blocks.{str(i)}.mlp.linear_fc2.bias'] = (
diffuser_state_dict[f'single_transformer_blocks.{str(i)}.proj_out.bias'].detach().clone()
)
new_state_dict[f'single_blocks.{str(i)}.self_attention.linear_proj.bias'] = (
diffuser_state_dict[f'single_transformer_blocks.{str(i)}.proj_out.bias'].detach().clone()
)

return new_state_dict
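
A hedged usage sketch of the converter exported by this module: the import path is assumed from the file location, `transformer` is assumed to be an already-constructed NeMo Flux transformer, and the checkpoint path is a placeholder.

    from nemo.collections.diffusion.utils.flux_ckpt_converter import flux_transformer_converter

    # Convert an HF-format Flux transformer checkpoint into NeMo key names,
    # then load it non-strictly so missing/unexpected keys can be inspected.
    nemo_state_dict = flux_transformer_converter(
        ckpt_path="/path/to/hf_flux_transformer",    # placeholder path
        transformer_config=transformer.config,
    )
    missing, unexpected = transformer.load_state_dict(nemo_state_dict, strict=False)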
7 changes: 7 additions & 0 deletions scripts/flux/flux_controlnet_infer.py
@@ -64,6 +64,12 @@ def parse_args():
default=False,
help="Must be true if provided checkpoint is not already converted to NeMo version",
)
parser.add_argument(
"--load_dist_ckpt",
action='store_true',
default=False,
help="Load distributed checkpoint for Flux",
)
parser.add_argument(
"--save_converted_model_to",
type=str,
@@ -143,6 +149,7 @@ def parse_args():
args.controlnet_ckpt,
do_convert_from_hf=args.do_convert_from_hf,
save_converted_model_to=args.save_converted_model_to,
load_dist_ckpt=args.load_dist_ckpt,
)
dtype = torch.float32
text = args.prompts.split(',')
7 changes: 7 additions & 0 deletions scripts/flux/flux_infer.py
@@ -55,6 +55,12 @@ def parse_args():
default=False,
help="Must be true if provided checkpoint is not already converted to NeMo version",
)
parser.add_argument(
"--load_dist_ckpt",
action='store_true',
default=False,
help="Load distributed checkpoint for Flux",
)
parser.add_argument(
"--save_converted_model_to",
type=str,
@@ -118,6 +124,7 @@ def parse_args():
args.flux_ckpt,
do_convert_from_hf=args.do_convert_from_hf,
save_converted_model_to=args.save_converted_model_to,
load_dist_ckpt=args.load_dist_ckpt,
)
dtype = torch.float32
text = args.prompts.split(',')