Commit
lfv1 training code and model and stuff
cloneofsimo committed Jul 1, 2024
1 parent 72feb0c commit 4fc10e0
Showing 13 changed files with 278 additions and 1,946 deletions.
533 changes: 0 additions & 533 deletions advanced/main.py

This file was deleted.

219 changes: 129 additions & 90 deletions advanced/main_t2i.py
@@ -3,6 +3,7 @@
import math
import os
import random
import time
from typing import Any

import click
@@ -22,7 +23,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import get_scheduler

from collections import defaultdict
import wandb
from mmdit import MMDiT
import pandas as pd
@@ -51,7 +52,7 @@ def log_dif(model_cur_sd, model_prev_sd):
param_count = param.numel()

# Determine color based on the criteria using regex
layer_match = re.match(r".*\.layers\.(\d+)(?:\..*)?$", name)
layer_match = re.match(r".*\.single_layers\.(\d+)(?:\..*)?$", name)

if layer_match:
layer_num = int(layer_match.group(1))
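A quick aside on the regex above: the layer-colour bucketing in log_dif now keys off the model's single_layers blocks rather than layers, so any parameter whose name contains ".single_layers.<idx>." is grouped by that index. A minimal sketch of the extraction, using hypothetical parameter names rather than the real keys from rf.named_parameters():

import re

names = [
    "model.single_layers.0.attn.q.weight",   # hypothetical names for illustration
    "model.single_layers.11.mlp.fc1.weight",
    "model.cond_seq_linear.weight",          # no layer index -> falls through
]
for name in names:
    m = re.match(r".*\.single_layers\.(\d+)(?:\..*)?$", name)
    print(name, "->", int(m.group(1)) if m else None)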
@@ -242,12 +243,12 @@ def extract_model_state_dict_deepspeed(model, global_rank):
with deepspeed.zero.GatheredParameters(
_z3_params_to_fetch([v]), enabled=True
):
v_p = v.data.cpu()
v_p = v.data.cpu().half()
else:
v_p = v.cpu()
v_p = v.cpu().half()

if global_rank == 0:
output_state_dict[k] = v_p.detach()
output_state_dict[k] = v_p.detach().half()

return output_state_dict
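This hunk keeps the existing ZeRO-3 gather pattern but casts each gathered tensor to fp16 before it is stored, roughly halving the memory held by the checkpoint and EMA copies. A hedged sketch of the pattern, assuming a _z3_params_to_fetch helper like the one defined elsewhere in this file that selects only parameters whose data is actually partitioned:

import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

def _z3_params_to_fetch(param_list):
    # Gather only parameters whose full data lives on other ranks under ZeRO-3.
    return [p for p in param_list
            if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE]

def fp16_state_dict_on_rank0(model, global_rank):
    out = {}
    for k, v in model.named_parameters():
        with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([v]), enabled=True):
            v_p = v.data.cpu().half()   # full tensor is materialized inside the context
        if global_rank == 0:
            out[k] = v_p.detach()
    return out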

@@ -308,6 +309,11 @@ def decode(self, data: bytes) -> Any:
default=256,
help="Hidden dim, this will mainly determine the model size",
)
@click.option(
"--n_heads",
default=16,
help="Number of heads, this will mainly determine the model size",
)
@click.option(
"--n_layers",
default=12,
@@ -328,6 +334,17 @@ def decode(self, data: bytes) -> Any:
@click.option(
"--cond_seq_dim", default=2048, help="Conditional sequence dimension, like T5"
)
@click.option(
"--init_ckpt_path", default=None, help="Path to initial checkpoint"
)
@click.option(
"--t_cutoff_tokens", default=64, help="T cutoff tokens for T5 embeddings"
)
@click.option(
"--modify_resolution_at_initialization",
default=False,
help="Modify resolution at initialization",
)
def main(
local_rank,
train_batch_size,
@@ -342,6 +359,7 @@ def main(
skipped_ema_step=16,
weight_decay=0.1,
hidden_dim=256,
n_heads=16,
n_layers=12,
save_dir="./ckpt",
lr_frozen_factor=1.0,
@@ -350,6 +368,9 @@ def main(
vae_col="vae_256x256_latents",
t5_col="t5_xl_embeddings",
cond_seq_dim=2048,
init_ckpt_path=None,
t_cutoff_tokens=64,
modify_resolution_at_initialization=True,
):

# first, set the seed
@@ -403,29 +424,35 @@ def main(
dim=hidden_dim,
global_conddim=hidden_dim,
n_layers=n_layers,
n_heads=8,
n_heads=n_heads,
cond_seq_dim=cond_seq_dim,
max_seq = (vaeres//2 )**2,
),
True,
).cuda()
statedict = torch.load(
"/home/ubuntu/ckpts_36L_5/model_69633/pytorch_model.bin",
map_location="cpu",
)
# remove model.layers.23.modC.1.weight
# statedict.pop("model.layers.31.modC.1.weight")

if init_ckpt_path is not None:
statedict = torch.load(
init_ckpt_path,
map_location="cpu",
)
# # remove model.layers.23.modC.1.weight
# # statedict.pop("model.layers.31.modC.1.weight")

rf.load_state_dict(
statedict,
strict=False,
)
rf.load_state_dict(
statedict,
strict=False,
)

if modify_resolution_at_initialization:
rf.model.extend_pe((16, 16), (vaeres//2, vaeres//2))
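The diff does not include extend_pe itself, so purely as an illustration of one common way to adapt a learned 2D positional table from the old 16x16 latent grid to the new (vaeres//2)^2 grid (not necessarily what MMDiT.extend_pe actually does), bicubic interpolation over the grid axes would look like this:

import torch
import torch.nn.functional as F

def interpolate_2d_pe(pe: torch.Tensor, old_hw, new_hw) -> torch.Tensor:
    # pe: (old_h * old_w, dim) learned positional table -- an assumed layout.
    (old_h, old_w), (new_h, new_w) = old_hw, new_hw
    dim = pe.shape[-1]
    grid = pe.reshape(1, old_h, old_w, dim).permute(0, 3, 1, 2)   # (1, dim, H, W)
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(new_h * new_w, dim)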

ema_state_dict1 = extract_model_state_dict_deepspeed(rf, global_rank)
ema_state_dict2 = {
k: v.detach().cpu().float().clone() for k, v in ema_state_dict1.items()
k: v.clone() for k, v in ema_state_dict1.items()
}
prv_state_dict = {
k: v.detach().cpu().float().clone() for k, v in ema_state_dict1.items()
k: v.clone() for k, v in ema_state_dict1.items()
}

total_params = sum(p.numel() for p in rf.parameters())
@@ -457,59 +484,67 @@ def main(
assert os.environ.get(varname) is not None, f"{varname} is not set"
print(f"{varname}: {os.environ.get(varname)}")

locdir = f"/tmp/mdstemp_0"
locdir = f"/scratch/simo"
# # cleanup if rank0
# if local_rank == 0:
# try:
# os.system(f"rm -rf {locdir}")
# #os.system(f"rm -rf /tmp/mdstemp_0")

# # make
# os.makedirs(locdir, exist_ok=True)
# except:
# pass

locdir = f"/scratch/simo"
#locdir = f"/tmp/mdstemp_0"
util.clean_stale_shared_memory()
# cleanup if rank0
if local_rank == 0:
os.system(f"rm -rf {locdir}")
# make
os.makedirs(locdir, exist_ok=True)
try:
os.system(f"rm -rf {locdir}")
# make
os.makedirs(locdir, exist_ok=True)
except:
print("Failed to cleanup")

torch.distributed.barrier()
if True:
train_dataset = StreamingDataset(
local=locdir,
remote=train_dir,
split=None,
shuffle=True,
shuffle_algo="py1e",
num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8),
batch_size=per_device_train_batch_size,
shuffle_seed=seed,
shuffle_block_size=10 * 4000,
cache_limit="100gb",
)
else:
streams = []
for idx in range(8):
locdir = f"/tmp/mdstemp_{idx}"
# cleanup if rank0
if local_rank == 0:
os.system(f"rm -rf {locdir}")
# make
os.makedirs(locdir, exist_ok=True)

stream = Stream(
remote=f"/jfs/datacomp-1b-0-10k/{idx}",
local=locdir,
proportion=1 / 8,
)

streams.append(stream)

train_dataset = StreamingDataset(
streams=streams,
shuffle=True,
shuffle_algo="py1e",
num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8),
batch_size=per_device_train_batch_size,
shuffle_seed=seed,
shuffle_block_size=20 * 4000,
cache_limit="100gb",
predownload=2048,
)

# train_dataset = StreamingDataset(
# local=locdir,
# remote=train_dir,
# split=None,
# shuffle=False,
# shuffle_algo="py1s",
# num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8),
# batch_size=per_device_train_batch_size,
# shuffle_block_size=5000,
# cache_limit="50gb",
# predownload=512 * per_device_train_batch_size,
# download_retry=4,
# download_timeout=300,
# )

train_dataset = StreamingDataset(
local=train_dir,
remote=None,
# local=locdir,
# remote=train_dir,
split=None,
shuffle=True,
shuffle_algo="py1s",
shuffle_seed=seed,
num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8),
batch_size=per_device_train_batch_size,
cache_limit="300gb",
download_retry=2,
download_timeout=300,
shuffle_block_size=3000,
)
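For context on the block above: the dataloading path reads pre-encoded MDS shards (VAE latents plus quantized T5 embeddings per sample, the latter restored later by dequantize_t5) through MosaicML streaming's StreamingDataset, with the local/remote split deciding whether shards are cached to scratch or read in place. A minimal, hedged sketch of reading such a dataset directly, with a made-up path and the default column names from the CLI options:

from streaming import StreamingDataset

ds = StreamingDataset(local="/path/to/mds_shards", remote=None, shuffle=True, batch_size=8)
sample = ds[0]  # one sample dict, e.g. keys 'vae_256x256_latents' and 't5_xl_embeddings'
print(len(ds), sorted(sample.keys()))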


right_pad = lambda x: torch.cat(
[x, torch.zeros((77 * 3) - x.size(0), cond_seq_dim).to(x.device)], dim=0
[x, torch.zeros((256) - x.size(0), cond_seq_dim).to(x.device)], dim=0
)

def dequantize_t5(tensor):
@@ -519,7 +554,7 @@ def dequantize_t5(tensor):
dataloader = DataLoader(
train_dataset,
batch_size=per_device_train_batch_size,
num_workers=8,
num_workers=32,
collate_fn=lambda x: {
vae_col: torch.stack([torch.tensor(xx[vae_col]) for xx in x]),
t5_col: torch.stack(
@@ -533,6 +568,7 @@ def dequantize_t5(tensor):
]
),
},
drop_last = True
)

torch.distributed.barrier()
@@ -549,38 +585,40 @@ def dequantize_t5(tensor):
optimizer_grouped_parameters = []
final_optimizer_settings = {}

param_groups = defaultdict(lambda: {"params": [], "weight_decay": None, "lr": None})

for n, p in rf.named_parameters():
group_parameters = {}
if p.requires_grad:
if any(ndnl in n for ndnl in no_decay_name_list):
group_parameters["weight_decay"] = 0.0
weight_decay_value = 0.0
else:
group_parameters["weight_decay"] = weight_decay
weight_decay_value = weight_decay

# Define learning rate for specific types of params

if any(ndnl in n for ndnl in no_decay_name_list):
group_parameters["lr"] = learning_rate * 0.033
lr_value = learning_rate * 0.033
elif any(ipt in n for ipt in custom_lr_set.keys()):
input_dim = p.shape[-1]
ipt = [ipt for ipt in custom_lr_set.keys() if ipt in n][0]
group_parameters["lr"] = learning_rate * (
custom_lr_set[ipt] / input_dim
)

lr_value = learning_rate * (custom_lr_set[ipt] / input_dim)
else:
group_parameters["lr"] = learning_rate * (32 / hidden_dim)
lr_value = learning_rate * (32 / hidden_dim)

if any(ndnl in n for ndnl in small_train_name_list):
group_parameters["lr"] = group_parameters["lr"] * lr_frozen_factor
lr_value = lr_value * lr_frozen_factor

group_key = (lr_value, weight_decay_value)
param_groups[group_key]["params"].append(p)
param_groups[group_key]["weight_decay"] = weight_decay_value
param_groups[group_key]["lr"] = lr_value

group_parameters["params"] = [p]
final_optimizer_settings[n] = {
"lr": group_parameters["lr"],
"wd": group_parameters["weight_decay"],
"lr": lr_value,
"wd": weight_decay_value,
"shape": str(list(p.shape)),
}
optimizer_grouped_parameters.append(group_parameters)

optimizer_grouped_parameters = [v for v in param_groups.values()]
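The rewrite above replaces the old one-group-per-parameter layout with one optimizer group per unique (lr, weight_decay) pair, keeping the same per-tensor overrides (reduced lr and no decay for the no_decay_name_list entries, a 32/hidden_dim width-scaled lr elsewhere) while giving AdamW far fewer groups to track. A condensed sketch of that bucketing idea, with a hypothetical no-decay list:

from collections import defaultdict
import torch

def build_param_groups(model, lr, wd, hidden_dim, no_decay=("bias", "norm")):
    buckets = defaultdict(list)
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if any(k in name for k in no_decay):
            key = (lr * 0.033, 0.0)              # low-lr, no-decay bucket
        else:
            key = (lr * (32 / hidden_dim), wd)   # width-scaled lr bucket
        buckets[key].append(p)
    return [{"params": ps, "lr": g_lr, "weight_decay": g_wd}
            for (g_lr, g_wd), ps in buckets.items()]

model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.LayerNorm(256))
groups = build_param_groups(model, lr=1e-3, wd=0.1, hidden_dim=256)
optimizer = torch.optim.AdamW(groups, betas=(0.9, 0.95))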

if global_rank == 0:
# mkdir and dump optimizer settings
@@ -594,11 +632,11 @@ def dequantize_t5(tensor):
AdamOptimizer = torch.optim.AdamW

optimizer = AdamOptimizer(
rf.parameters(), lr=learning_rate * (32 / hidden_dim), betas=(0.9, 0.95)
optimizer_grouped_parameters, betas=(0.9, 0.95)
)

lr_scheduler = get_scheduler(
name="linear", optimizer=optimizer, num_warmup_steps=300, num_training_steps=1e6
name="linear", optimizer=optimizer, num_warmup_steps=100, num_training_steps=1e6
)

rf.train()
@@ -648,7 +686,7 @@ def dequantize_t5(tensor):
)

cond = (
(batch[t5_col].reshape(-1, 77 * 3, cond_seq_dim))
(batch[t5_col].reshape(-1, 256, cond_seq_dim))[:, :t_cutoff_tokens, :]
.to(device)
.to(torch.bfloat16)
)
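The conditioning tensor is now built from embeddings that were right-padded to a fixed 256 tokens at collate time and then truncated to the first t_cutoff_tokens here, which bounds the text-branch sequence length. A small sketch of that pad-then-truncate round trip with made-up shapes:

import torch

cond_seq_dim, t_cutoff_tokens = 2048, 64

def right_pad(x, max_tokens=256):
    # x: (n_tokens, cond_seq_dim) variable-length T5 embedding
    return torch.cat([x, torch.zeros(max_tokens - x.size(0), x.size(1))], dim=0)

emb = torch.randn(41, cond_seq_dim)                  # a 41-token caption embedding
batch = torch.stack([right_pad(emb)])                # (1, 256, cond_seq_dim)
cond = batch.reshape(-1, 256, cond_seq_dim)[:, :t_cutoff_tokens, :]
print(cond.shape)                                    # torch.Size([1, 64, 2048])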
@@ -696,7 +734,7 @@ def dequantize_t5(tensor):
log_dif(current_state_dict, prv_state_dict)

prv_state_dict = {
k: v.detach().cpu().float().clone()
k: v.detach().cpu().half().clone()
for k, v in current_state_dict.items()
}
# update ema.
@@ -715,10 +753,10 @@

for k, v in current_state_dict.items():
ema_state_dict1[k] = (
BETA1 * ema_state_dict1[k] + (1 - BETA1) * v
BETA1 * ema_state_dict1[k] + (1 - BETA1) * v.half()
)
ema_state_dict2[k] = (
BETA2 * ema_state_dict2[k] + (1 - BETA2) * v
BETA2 * ema_state_dict2[k] + (1 - BETA2) * v.half()
)
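Both EMA copies are kept as fp16 CPU tensors and updated with different decay rates, so one average tracks the weights over a shorter horizon and the other over a longer one. A tiny numeric sketch of the update rule; the BETA values are assumptions, since the real constants are defined outside this hunk:

import torch

BETA1, BETA2 = 0.999, 0.9999                       # assumed decay rates for illustration

ema1 = {"w": torch.zeros(4, dtype=torch.float16)}  # shorter-horizon average
ema2 = {"w": torch.zeros(4, dtype=torch.float16)}  # longer-horizon average
current = {"w": torch.ones(4)}                     # freshly extracted weights

for k, v in current.items():
    ema1[k] = BETA1 * ema1[k] + (1 - BETA1) * v.half()
    ema2[k] = BETA2 * ema2[k] + (1 - BETA2) * v.half()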
# log 1st value for sanity check
if value is None:
@@ -734,8 +772,9 @@ def dequantize_t5(tensor):
f"norm: {norm}, loss: {loss.item()}, global_step: {global_step}"
)

if global_step % 4096 == 1:

if global_step % 8192 == 10:
print(f"Starting EMA save at {global_step}")
t1 = time.time()
os.makedirs(f"{save_dir}/model_{global_step}", exist_ok=True)
save_zero_three_model(
model_engine,
@@ -753,7 +792,7 @@ def dequantize_t5(tensor):
ema_state_dict2, f"{save_dir}/model_{global_step}/ema2.pt"
)

print(f"Model saved at {global_step}, Global Rank {global_rank}")
print(f"Model saved at {global_step}, Global Rank {global_rank}, Time: {time.time() - t1}")

# sync
torch.distributed.barrier()
(Diffs for the remaining changed files in this commit are not shown.)
