168 changes: 129 additions & 39 deletions lora_diffusion/cli_lora_pti.py
@@ -46,6 +46,8 @@
prepare_clip_model_sets,
evaluate_pipe,
UNET_EXTENDED_TARGET_REPLACE,
parse_safeloras_embeds,
apply_learned_embed_in_clip,
)

def preview_training_batch(train_dataloader, mode, n_imgs = 40):
@@ -67,6 +69,52 @@ def preview_training_batch(train_dataloader, mode, n_imgs = 40):
print(f"\nSaved {imgs_saved} preview training imgs to {outdir}")
return

def sim_matrix(a, b, eps=1e-8):
"""
added eps for numerical stability
"""
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
return sim_mt
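
A minimal sanity check of sim_matrix (editor's illustrative sketch, not part of this patch): identical rows should give cosine similarity ~1, orthogonal rows ~0, and rescaling a row should leave the result unchanged.

import torch

a = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
print(sim_matrix(a, a))        # ~ tensor([[1., 0.], [0., 1.]])
print(sim_matrix(a, 2.0 * a))  # same values: cosine similarity ignores scale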


def compute_pairwise_distances(x, y):
# compute the squared L2 distance from each row in x to each row in y (both are torch tensors)
# x is a torch tensor of shape (m, d)
# y is a torch tensor of shape (n, d)
# returns a torch tensor of shape (m, n)

n = y.shape[0]
m = x.shape[0]
d = x.shape[1]

x = x.unsqueeze(1).expand(m, n, d)
y = y.unsqueeze(0).expand(m, n, d)

return torch.pow(x - y, 2).sum(2)
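
Note that compute_pairwise_distances returns squared Euclidean distances (no square root is taken), which is why the log line below labels them as squared L2. A tiny illustrative check (not part of the patch):

import torch

x = torch.tensor([[0.0, 0.0]])
y = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
print(compute_pairwise_distances(x, y))  # tensor([[ 0., 25.]]) since 3**2 + 4**2 = 25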


def print_most_similar_tokens(tokenizer, optimized_token, text_encoder, n=10):
with torch.no_grad():
# get all the token embeddings:
token_embeds = text_encoder.get_input_embeddings().weight.data

# Compute the cosine-similarity between the optimized tokens and all the other tokens
similarity = sim_matrix(optimized_token.unsqueeze(0), token_embeds).squeeze()
similarity = similarity.detach().cpu().numpy()

distances = compute_pairwise_distances(optimized_token.unsqueeze(0), token_embeds).squeeze()
distances = distances.detach().cpu().numpy()

# print similarity for the most similar tokens:
most_similar_tokens = np.argsort(similarity)[::-1]

print(f"{tokenizer.decode(most_similar_tokens[0])} --> mean: {optimized_token.mean().item():.3f}, std: {optimized_token.std().item():.3f}, norm: {optimized_token.norm():.4f}")
for token_id in most_similar_tokens[1:n+1]:
print(f"sim of {similarity[token_id]:.3f} & L2 of {distances[token_id]:.3f} with \"{tokenizer.decode(token_id)}\"")


def get_models(
pretrained_model_name_or_path,
@@ -139,19 +187,21 @@ def get_models(
pretrained_vae_name_or_path or pretrained_model_name_or_path,
subfolder=None if pretrained_vae_name_or_path else "vae",
revision=None if pretrained_vae_name_or_path else revision,
local_files_only=True,
)
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path,
subfolder="unet",
revision=revision,
local_files_only=True,
)

return (
text_encoder.to(device),
vae.to(device),
unet.to(device),
tokenizer,
placeholder_token_ids,
placeholder_token_ids
)


@@ -477,12 +527,13 @@ def train_inversion(

if global_step % accum_iter == 0:
# print gradient of text encoder embedding
print(
text_encoder.get_input_embeddings()
.weight.grad[index_updates, :]
.norm(dim=-1)
.mean()
)
if False:  # disabled debug print
print(
text_encoder.get_input_embeddings()
.weight.grad[index_updates, :]
.norm(dim=-1)
.mean()
)
optimizer.step()
optimizer.zero_grad()

@@ -517,8 +568,10 @@ def train_inversion(
index_no_updates
] = orig_embeds_params[index_no_updates]

for i, t in enumerate(optimizing_embeds):
print(f"token {i} --> mean: {t.mean().item():.3f}, std: {t.std().item():.3f}, norm: {t.norm():.4f}")
if global_step % 50 == 0:
print("------------------------------")
for i, t in enumerate(optimizing_embeds):
print_most_similar_tokens(tokenizer, t, text_encoder)

global_step += 1
progress_bar.update(1)
@@ -537,7 +590,7 @@ def train_inversion(
placeholder_token_ids=placeholder_token_ids,
placeholder_tokens=placeholder_tokens,
save_path=os.path.join(
save_path, f"step_inv_{global_step}.safetensors"
save_path, f"step_inv_{global_step:04d}.safetensors"
),
save_lora=False,
)
@@ -583,7 +636,7 @@ def train_inversion(
return

import matplotlib.pyplot as plt
def plot_loss_curve(losses, name, moving_avg=20):
def plot_loss_curve(losses, name, moving_avg=5):
losses = np.array(losses)
losses = np.convolve(losses, np.ones(moving_avg)/moving_avg, mode='valid')
plt.plot(losses)
@@ -654,7 +707,7 @@ def perform_tuning(
vae,
text_encoder,
scheduler,
optimized_embeddings = text_encoder.get_input_embeddings().weight[:, :],
optimized_embeddings = text_encoder.get_input_embeddings().weight[~index_no_updates, :],
train_inpainting=train_inpainting,
t_mutliplier=0.8,
mixed_precision=True,
@@ -683,6 +736,12 @@ def perform_tuning(
index_no_updates
] = orig_embeds_params[index_no_updates]

if global_step % 100 == 0:
optimizing_embeds = text_encoder.get_input_embeddings().weight[~index_no_updates]
print("------------------------------")
for i, t in enumerate(optimizing_embeds):
print_most_similar_tokens(tokenizer, t, text_encoder)


global_step += 1

@@ -696,7 +755,7 @@ def perform_tuning(
placeholder_token_ids=placeholder_token_ids,
placeholder_tokens=placeholder_tokens,
save_path=os.path.join(
save_path, f"step_{global_step}.safetensors"
save_path, f"step_{global_step:04d}.safetensors"
),
target_replace_module_text=lora_clip_target_modules,
target_replace_module_unet=lora_unet_target_modules,
@@ -706,16 +765,15 @@ def perform_tuning(
.mean()
.item()
)

print("LORA Unet Moved", moved)

moved = (
torch.tensor(
list(itertools.chain(*inspect_lora(text_encoder).values()))
)
.mean()
.item()
)

print("LORA CLIP Moved", moved)

if log_wandb:
@@ -778,6 +836,7 @@ def train(
placeholder_tokens: str = "",
placeholder_token_at_data: Optional[str] = None,
initializer_tokens: Optional[str] = None,
load_pretrained_inversion_embeddings_path: Optional[str] = None,
seed: int = 42,
resolution: int = 512,
color_jitter: bool = True,
@@ -788,7 +847,8 @@ def train(
save_steps: int = 100,
gradient_accumulation_steps: int = 4,
gradient_checkpointing: bool = False,
lora_rank: int = 4,
lora_rank_unet: int = 4,
lora_rank_text_encoder: int = 4,
lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
lora_clip_target_modules={"CLIPAttention"},
lora_dropout_p: float = 0.0,
@@ -825,6 +885,10 @@ def train(
script_start_time = time.time()
torch.manual_seed(seed)

if use_template == "person" and not use_face_segmentation_condition:
print("### WARNING ### : Using person template without face segmentation condition")
print("When training people, it is highly recommended to use face segmentation condition!!")

# Get a dict with all the arguments:
args_dict = locals()

@@ -841,7 +905,7 @@ def train(

if output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
# print(placeholder_tokens, initializer_tokens)

if len(placeholder_tokens) == 0:
placeholder_tokens = []
print("PTI : Placeholder Tokens not given, using null token")
@@ -874,6 +938,7 @@ def train(

print("PTI : Placeholder Tokens", placeholder_tokens)
print("PTI : Initializer Tokens", initializer_tokens)
print("PTI : Token Map: ", token_map)

# get the models
text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
@@ -886,7 +951,8 @@ def train(
)

noise_scheduler = DDPMScheduler.from_config(
pretrained_model_name_or_path, subfolder="scheduler"
pretrained_model_name_or_path, subfolder="scheduler",
local_files_only=True,
)

if gradient_checkpointing:
@@ -925,8 +991,6 @@ def train(
train_inpainting=train_inpainting,
)

train_dataset.blur_amount = 200

if train_inpainting:
assert not cached_latents, "Cached latents not supported for inpainting"

Expand Down Expand Up @@ -963,7 +1027,7 @@ def train(
vae = None

# STEP 1 : Perform Inversion
if perform_inversion and not cached_latents:
if perform_inversion and not cached_latents and (load_pretrained_inversion_embeddings_path is None):
preview_training_batch(train_dataloader, "inversion")

print("PTI : Performing Inversion")
@@ -1014,34 +1078,44 @@ def train(
del ti_optimizer
print("############### Inversion Done ###############")

elif load_pretrained_inversion_embeddings_path is not None:

print("PTI : Loading pretrained inversion embeddings..")
from safetensors.torch import safe_open
# Load the pretrained embeddings from the lora file:
safeloras = safe_open(load_pretrained_inversion_embeddings_path, framework="pt", device="cpu")
#monkeypatch_or_replace_safeloras(pipe, safeloras)
tok_dict = parse_safeloras_embeds(safeloras)
apply_learned_embed_in_clip(
tok_dict,
text_encoder,
tokenizer,
idempotent=True,
)
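
For context (editor's note, not from this patch): parse_safeloras_embeds returns a plain dict mapping each placeholder token to its learned embedding tensor; idempotent=True is presumably used here because the same placeholder tokens were already added to the tokenizer in get_models. A minimal inspection sketch, with the file name assumed:

from safetensors.torch import safe_open

safeloras = safe_open("step_inv_0500.safetensors", framework="pt", device="cpu")  # hypothetical path
learned = parse_safeloras_embeds(safeloras)
print({tok: tuple(emb.shape) for tok, emb in learned.items()})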

# Next perform Tuning with LoRA:
if not use_extended_lora:
unet_lora_params, _ = inject_trainable_lora(
unet,
r=lora_rank,
r=lora_rank_unet,
target_replace_module=lora_unet_target_modules,
dropout_p=lora_dropout_p,
scale=lora_scale,
)
print("PTI : not use_extended_lora...")
print("PTI : Will replace modules: ", lora_unet_target_modules)
else:
print("PTI : USING EXTENDED UNET!!!")
lora_unet_target_modules = (
lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
)
print("PTI : Will replace modules: ", lora_unet_target_modules)
unet_lora_params, _ = inject_trainable_lora_extended(
unet, r=lora_rank, target_replace_module=lora_unet_target_modules
unet, r=lora_rank_unet, target_replace_module=lora_unet_target_modules
)

n_optimizable_unet_params = sum(
[el.numel() for el in itertools.chain(*unet_lora_params)]
)
print("PTI : n_optimizable_unet_params: ", n_optimizable_unet_params)

print(f"PTI : has {len(unet_lora_params)} lora")
print("PTI : Before training:")
inspect_lora(unet)
#n_optimizable_unet_params = sum([el.numel() for el in itertools.chain(*unet_lora_params)])
#print("PTI : Number of optimizable UNET parameters: ", n_optimizable_unet_params)

params_to_optimize = [
{"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
@@ -1073,15 +1147,15 @@ def train(
text_encoder_lora_params, _ = inject_trainable_lora(
text_encoder,
target_replace_module=lora_clip_target_modules,
r=lora_rank,
r=lora_rank_text_encoder,
)
params_to_optimize += [
{
"params": itertools.chain(*text_encoder_lora_params),
"lr": text_encoder_lr,
}
{"params": itertools.chain(*text_encoder_lora_params),
"lr": text_encoder_lr}
]
inspect_lora(text_encoder)

#n_optimizable_text_Encoder_params = sum( [el.numel() for el in itertools.chain(*text_encoder_lora_params)])
#print("PTI : Number of optimizable text-encoder parameters: ", n_optimizable_text_Encoder_params)

lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)

@@ -1090,8 +1164,6 @@ def train(
print("Training text encoder!")
text_encoder.train()

train_dataset.blur_amount = 70

lr_scheduler_lora = get_scheduler(
lr_scheduler_lora,
optimizer=lora_optimizers,
@@ -1101,6 +1173,22 @@ def train(
if not cached_latents:
preview_training_batch(train_dataloader, "tuning")

#print("PTI : n_optimizable_unet_params: ", n_optimizable_unet_params)
print(f"PTI : has {len(unet_lora_params)} lora")
print("PTI : Before training:")

moved = (
torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
.mean().item())
print(f"LORA Unet Moved {moved:.6f}")


moved = (
torch.tensor(
list(itertools.chain(*inspect_lora(text_encoder).values()))
).mean().item())
print(f"LORA CLIP Moved {moved:.6f}")

perform_tuning(
unet,
vae,
@@ -1132,6 +1220,8 @@ def train(
training_time = time.time() - script_start_time
print(f"Training time: {training_time/60:.1f} minutes")
args_dict["training_time_s"] = int(training_time)
args_dict["n_epochs"] = math.ceil(max_train_steps_tuning / len(train_dataloader.dataset))
args_dict["n_training_imgs"] = len(train_dataloader.dataset)

# Save the args_dict to the output directory as a json file:
with open(os.path.join(output_dir, "lora_training_args.json"), "w") as f: