From 4fc10e0cc8ba976152c7936a1af9717209f03e18 Mon Sep 17 00:00:00 2001 From: Simo Ryu Date: Mon, 1 Jul 2024 13:26:04 -0700 Subject: [PATCH] lfv1 training code and model and stuff --- advanced/main.py | 533 ------------ advanced/main_t2i.py | 219 +++-- advanced/main_t2i_highres.py | 819 ------------------- advanced/mmdit.py | 227 ++--- advanced/run.sh | 20 - advanced/{run_multi_node.sh => run_multi.sh} | 23 +- advanced/run_multi_node_resize.sh | 39 - advanced/run_paral.sh | 30 - advanced/run_single.sh | 12 - advanced/run_single_small.sh | 26 - advanced/run_single_t2i.sh | 14 - advanced/test.py | 178 ---- advanced/upload_stuff_hf.py | 84 -- 13 files changed, 278 insertions(+), 1946 deletions(-) delete mode 100644 advanced/main.py delete mode 100644 advanced/main_t2i_highres.py delete mode 100644 advanced/run.sh rename advanced/{run_multi_node.sh => run_multi.sh} (52%) delete mode 100644 advanced/run_multi_node_resize.sh delete mode 100644 advanced/run_paral.sh delete mode 100644 advanced/run_single.sh delete mode 100644 advanced/run_single_small.sh delete mode 100644 advanced/run_single_t2i.sh delete mode 100644 advanced/test.py delete mode 100644 advanced/upload_stuff_hf.py diff --git a/advanced/main.py b/advanced/main.py deleted file mode 100644 index 24db682..0000000 --- a/advanced/main.py +++ /dev/null @@ -1,533 +0,0 @@ -# Mostly Copy-paste from https://github.com/cloneofsimo/min-max-in-dit. - -import math -import os -import random -from typing import Any - -import click -import deepspeed -import numpy as np -import streaming.base.util as util -import torch - -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True - -from deepspeed import get_accelerator -from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam -from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus -from streaming import StreamingDataset -from streaming.base.format.mds.encodings import Encoding, _encodings -from torch.utils.data import DataLoader -from tqdm import tqdm -from transformers import get_scheduler - -import wandb -from mmdit import MMDiT_for_IN1K - - -class RF(torch.nn.Module): - def __init__(self, model, ln=True): - super().__init__() - self.model = model - self.ln = ln - self.stratified = False - - def forward(self, x, cond): - b = x.size(0) - if self.ln: - if self.stratified: - # stratified sampling of normals - # first stratified sample from uniform - quantiles = torch.linspace(0, 1, b + 1).to(x.device) - z = quantiles[:-1] + torch.rand((b,)).to(x.device) / b - # now transform to normal - z = torch.erfinv(2 * z - 1) * math.sqrt(2) - t = torch.sigmoid(z) - else: - nt = torch.randn((b,)).to(x.device) - t = torch.sigmoid(nt) - else: - t = torch.rand((b,)).to(x.device) - texp = t.view([b, *([1] * len(x.shape[1:]))]) - z1 = torch.randn_like(x) - zt = (1 - texp) * x + texp * z1 - - # make t, zt into same dtype as x - zt, t = zt.to(x.dtype), t.to(x.dtype) - vtheta = self.model(zt, t, cond) - batchwise_mse = ((z1 - x - vtheta) ** 2).mean(dim=list(range(1, len(x.shape)))) - tlist = batchwise_mse.detach().cpu().reshape(-1).tolist() - ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)] - return batchwise_mse.mean(), {"batchwise_loss": ttloss} - - @torch.no_grad() - def sample(self, z, cond, null_cond=None, sample_steps=50, cfg=2.0): - b = z.size(0) - dt = 1.0 / sample_steps - dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))]) - images = [z] - for i in range(sample_steps, 0, -1): - t = i / sample_steps - t = torch.tensor([t] * 
b).to(z.device) - - vc = self.model(z, t, cond) - if null_cond is not None: - vu = self.model(z, t, null_cond) - vc = vu + cfg * (vc - vu) - - z = z - dt * vc - images.append(z) - return images - - @torch.no_grad() - def sample_with_xps(self, z, cond, null_cond=None, sample_steps=50, cfg=2.0): - b = z.size(0) - dt = 1.0 / sample_steps - dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))]) - images = [z] - for i in range(sample_steps, 0, -1): - t = i / sample_steps - t = torch.tensor([t] * b).to(z.device) - - vc = self.model(z, t, cond) - if null_cond is not None: - vu = self.model(z, t, null_cond) - vc = vu + cfg * (vc - vu) - x = z - i * dt * vc - z = z - dt * vc - images.append(x) - return images - - @torch.no_grad() - def sample_with_xps_tff( - self, z, cond, null_cond=None, sample_steps=50, cfg=2.0, tff=lambda x: x - ): - b = z.size(0) - dt = 1.0 / sample_steps - dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))]) - images = [z] - for i in range(sample_steps, 0, -1): - t = i / sample_steps - t = tff(t) - t = torch.tensor([t] * b).to(z.device) - - vc = self.model(z, t, cond) - if null_cond is not None: - vu = self.model(z, t, null_cond) - vc = vu + cfg * (vc - vu) - x = z - i * dt * vc - z = z - dt * vc - images.append(x) - return images - - -def _z3_params_to_fetch(param_list): - return [ - p - for p in param_list - if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE - ] - - -def save_zero_three_model(model, global_rank, save_dir, zero_stage=0): - zero_stage_3 = zero_stage == 3 - os.makedirs(save_dir, exist_ok=True) - WEIGHTS_NAME = "pytorch_model.bin" - output_model_file = os.path.join(save_dir, WEIGHTS_NAME) - - model_to_save = model.module if hasattr(model, "module") else model - if not zero_stage_3: - if global_rank == 0: - torch.save(model_to_save.state_dict(), output_model_file) - else: - output_state_dict = {} - for k, v in model_to_save.named_parameters(): - if hasattr(v, "ds_id"): - with deepspeed.zero.GatheredParameters( - _z3_params_to_fetch([v]), enabled=zero_stage_3 - ): - v_p = v.data.cpu() - else: - v_p = v.cpu() - if global_rank == 0 and "lora" not in k: - output_state_dict[k] = v_p - if global_rank == 0: - torch.save(output_state_dict, output_model_file) - del output_state_dict - - -@torch.no_grad() -def extract_model_state_dict_deepspeed(model, global_rank): - output_state_dict = {} - for k, v in model.named_parameters(): - if hasattr(v, "ds_id"): - with deepspeed.zero.GatheredParameters( - _z3_params_to_fetch([v]), enabled=True - ): - v_p = v.data.cpu() - else: - v_p = v.cpu() - - if global_rank == 0: - output_state_dict[k] = v_p.detach() - - return output_state_dict - - -def set_seed(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -class uint8(Encoding): - def encode(self, obj: Any) -> bytes: - return obj.tobytes() - - def decode(self, data: bytes) -> Any: - x = np.frombuffer(data, np.uint8).astype(np.float32) - return (x / 255.0 - 0.5) * 24.0 - - -_encodings["uint8"] = uint8 - - -@click.command() -@click.option("--local_rank", default=-1, help="Local rank") -@click.option("--num_train_epochs", default=2, help="Number of training epochs") -@click.option("--learning_rate", default=3e-3, help="Learning rate") -@click.option("--offload", default=False, help="Offload") -@click.option("--train_batch_size", default=256, help="Total Train batch size") -@click.option( - "--per_device_train_batch_size", default=128, help="Per device 
train batch size" -) -@click.option("--zero_stage", default=2, help="Zero stage, from 0 to 3") -@click.option("--seed", default=42, help="Seed for rng") -@click.option("--run_name", default=None, help="Run name that will be used for wandb") -@click.option("--train_dir", default="./vae_mds", help="Train dir that MDS can read") -@click.option( - "--skipped_ema_step", - default=1024, - help="Skipped EMA step. Karras EMA will save model every skipped_ema_step", -) -@click.option("--weight_decay", default=0.1, help="Weight decay") -@click.option( - "--hidden_dim", - default=256, - help="Hidden dim, this will mainly determine the model size", -) -@click.option( - "--n_layers", - default=12, - help="Number of layers, this will mainly determine the model size", -) -@click.option("--save_dir", default="./ckpt", help="Save dir for model") -def main( - local_rank, - train_batch_size, - per_device_train_batch_size, - num_train_epochs, - learning_rate, - offload=False, - zero_stage=2, - seed=42, - run_name=None, - train_dir="./vae_mds", - skipped_ema_step=16, - weight_decay=0.1, - hidden_dim=256, - n_layers=12, - save_dir="./ckpt", -): - - # first, set the seed - set_seed(seed) - - if run_name is None: - run_name = ( - f"LR:{learning_rate}__num_train_epochs:{num_train_epochs}_offload:{offload}" - ) - - if local_rank == -1: - device = torch.device(get_accelerator().device_name()) - else: - get_accelerator().set_device(local_rank) - device = torch.device(get_accelerator().device_name(), local_rank) - # Initializes the distributed backend which will take care of sychronizing nodes/GPUs - deepspeed.init_distributed() - - # set LOCAL_WORLD_SIZE to 8 - - os.environ["LOCAL_WORLD_SIZE"] = str(os.environ.get("WORLD_SIZE")) - - offload_device = "cpu" if offload else "none" - - ds_config = { - "train_micro_batch_size_per_gpu": per_device_train_batch_size, - "train_batch_size": train_batch_size, - "zero_optimization": { - "stage": zero_stage, - "offload_param": {"device": offload_device}, - "offload_optimizer": {"device": offload_device}, - "stage3_param_persistence_threshold": 1e4, - "stage3_max_live_parameters": 3e7, - "stage3_prefetch_bucket_size": 3e7, - "memory_efficient_linear": False, - }, - "bfloat16": {"enabled": True}, - "gradient_clipping": 0.3, - } - - torch.distributed.barrier() - - global_rank = torch.distributed.get_rank() - - ##### DEFINE model, dataset, sampler, dataloader, optim, schedular - - with deepspeed.zero.Init(enabled=(zero_stage == 3)): - - rf = RF( - MMDiT_for_IN1K( - in_channels=4, - out_channels=4, - dim=hidden_dim, - global_conddim=hidden_dim, - n_layers=n_layers, - n_heads=8, - ), - True, - ).cuda() - # rf.load_state_dict(torch.load("/home/host/simo/ckpts/5b_2/model_57345/ema1.pt", map_location="cpu")) - - ema_state_dict1 = extract_model_state_dict_deepspeed(rf, global_rank) - ema_state_dict2 = extract_model_state_dict_deepspeed(rf, global_rank) - - total_params = sum(p.numel() for p in rf.parameters()) - size_in_bytes = total_params * 4 - size_in_gb = size_in_bytes / (1024**3) - print( - f"Model Size: {size_in_bytes}, {size_in_gb} GB, Total Param Count: {total_params / 1e6} M" - ) - - util.clean_stale_shared_memory() - # barrier - torch.distributed.barrier() - - train_dataset = StreamingDataset( - local=train_dir, - remote=None, - split=None, - shuffle=True, - shuffle_algo="naive", - num_canonical_nodes=1, - batch_size=per_device_train_batch_size, - ) - - print(f"\n\n######-----Dataset loaded: {len(train_dataset)}") - print( - f"Rank: {os.environ.get('RANK')}, Local Rank: 
{os.environ.get('LOCAL_WORLD_SIZE')}, Global Rank: {global_rank}" - ) - - dataloader = DataLoader( - train_dataset, - batch_size=per_device_train_batch_size, - num_workers=8, - ) - - torch.distributed.barrier() - ## Config muP-learning rate. - no_decay_name_list = [ - "bias", - "norm", - "c_vec_embedder", - "cond_seq_linear", - "init_x_linear", - ] - - optimizer_grouped_parameters = [] - final_optimizer_settings = {} - - for n, p in rf.named_parameters(): - group_parameters = {} - if p.requires_grad: - if any(ndnl in n for ndnl in no_decay_name_list): - group_parameters["weight_decay"] = 0.0 - else: - group_parameters["weight_decay"] = weight_decay - - # Define learning rate for specific types of params - - if "embed" in n or any(ndnl in n for ndnl in no_decay_name_list): - group_parameters["lr"] = learning_rate * 0.1 - else: - group_parameters["lr"] = learning_rate * (32 / hidden_dim) - - group_parameters["params"] = [p] - final_optimizer_settings[n] = { - "lr": group_parameters["lr"], - "wd": group_parameters["weight_decay"], - } - optimizer_grouped_parameters.append(group_parameters) - - AdamOptimizer = torch.optim.AdamW - - optimizer = AdamOptimizer( - optimizer_grouped_parameters, lr=learning_rate, betas=(0.9, 0.999) - ) - - lr_scheduler = get_scheduler( - name="linear", optimizer=optimizer, num_warmup_steps=300, num_training_steps=1e6 - ) - - rf.train() - - model_engine, optimizer, _, lr_scheduler = deepspeed.initialize( - model=rf, config=ds_config, lr_scheduler=lr_scheduler, optimizer=optimizer - ) - - global_step = 0 - - if global_rank == 0: - lossbin = {i: 0 for i in range(10)} - losscnt = {i: 1e-6 for i in range(10)} - ##### actual training loop - - if global_rank == 0: - wandb.init( - project="rf_in1k_mup", - name=run_name, - config={ - "num_train_epochs": num_train_epochs, - "learning_rate": learning_rate, - "offload": offload, - "train_batch_size": train_batch_size, - "per_device_train_batch_size": per_device_train_batch_size, - "zero_stage": zero_stage, - "seed": seed, - "train_dir": train_dir, - "skipped_ema_step": skipped_ema_step, - "weight_decay": weight_decay, - "hidden_dim": hidden_dim, - }, - ) - - for i in range(num_train_epochs): - pbar = tqdm(dataloader) - - for batch in pbar: - - x = ( - batch["vae_output"].reshape(-1, 4, 32, 32).to(device).to(torch.bfloat16) - * 0.13025 - ) - - y = torch.tensor(list(map(int, batch["label"]))).long().to(x.device) - # randomly make y into index 1000, with prob 0.1 - y = torch.where( - torch.rand_like(y.float()) < 0.1, - (torch.ones_like(y) * 1000).long(), - y, - ) - - loss, info = model_engine(x, y) - model_engine.backward(loss) - model_engine.step() - - norm = model_engine.get_global_grad_norm() - - global_step += 1 - if global_rank == 0: - batchwise_loss = info["batchwise_loss"] - # check t-wise loss - for t, l in batchwise_loss: - lossbin[int(t * 10)] += l - losscnt[int(t * 10)] += 1 - - if global_step % 16 == 0: - wandb.log( - { - "train_loss": loss.item(), - "train_grad_norm": norm, - "value": value, - "ema1_of_value": ema1_of_value, - "ema2_of_value": ema2_of_value, - **{ - f"lossbin_{i}": lossbin[i] / losscnt[i] - for i in range(10) - }, - } - ) - # reset - lossbin = {i: 0 for i in range(10)} - losscnt = {i: 1e-6 for i in range(10)} - - if global_step % skipped_ema_step == 1: - - current_state_dict = extract_model_state_dict_deepspeed(rf, global_rank) - - if global_rank == 0: - - # update ema. 
- BETA1 = (1 - 1 / global_step) ** (1 + 16) - BETA2 = (1 - 1 / global_step) ** (1 + 9) - - # adjust beta for skipped-ema - BETA1 = 1 - (1 - BETA1) * skipped_ema_step - BETA1 = max(0, BETA1) - BETA2 = 1 - (1 - BETA2) * skipped_ema_step - BETA2 = max(0, BETA2) - - value = None - ema1_of_value = None - ema2_of_value = None - - for k, v in sorted(current_state_dict.items()): - ema_state_dict1[k] = ( - BETA1 * ema_state_dict1[k] + (1 - BETA1) * v - ) - ema_state_dict2[k] = ( - BETA2 * ema_state_dict2[k] + (1 - BETA2) * v - ) - # log 1st value for sanity check - if value is None: - value = v.half().flatten()[0].item() - ema1_of_value = ( - ema_state_dict1[k].half().flatten()[0].item() - ) - ema2_of_value = ( - ema_state_dict2[k].half().flatten()[0].item() - ) - - pbar.set_description( - f"norm: {norm}, loss: {loss.item()}, global_step: {global_step}" - ) - - if global_step % 4096 == 1: - # make save_dir - os.makedirs(save_dir, exist_ok=True) - - os.makedirs(f"{save_dir}/model_{global_step}", exist_ok=True) - save_zero_three_model( - model_engine, - global_rank, - f"{save_dir}/model_{global_step}/", - zero_stage=zero_stage, - ) - - # save ema weights - if global_rank == 0: - torch.save( - ema_state_dict1, f"{save_dir}/model_{global_step}/ema1.pt" - ) - torch.save( - ema_state_dict2, f"{save_dir}/model_{global_step}/ema2.pt" - ) - - print(f"Model saved at {global_step}") - - -if __name__ == "__main__": - main() diff --git a/advanced/main_t2i.py b/advanced/main_t2i.py index f3aa6d9..44812ce 100644 --- a/advanced/main_t2i.py +++ b/advanced/main_t2i.py @@ -3,6 +3,7 @@ import math import os import random +import time from typing import Any import click @@ -22,7 +23,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm from transformers import get_scheduler - +from collections import defaultdict import wandb from mmdit import MMDiT import pandas as pd @@ -51,7 +52,7 @@ def log_dif(model_cur_sd, model_prev_sd): param_count = param.numel() # Determine color based on the criteria using regex - layer_match = re.match(r".*\.layers\.(\d+)(?:\..*)?$", name) + layer_match = re.match(r".*\.single_layers\.(\d+)(?:\..*)?$", name) if layer_match: layer_num = int(layer_match.group(1)) @@ -242,12 +243,12 @@ def extract_model_state_dict_deepspeed(model, global_rank): with deepspeed.zero.GatheredParameters( _z3_params_to_fetch([v]), enabled=True ): - v_p = v.data.cpu() + v_p = v.data.cpu().half() else: - v_p = v.cpu() + v_p = v.cpu().half() if global_rank == 0: - output_state_dict[k] = v_p.detach() + output_state_dict[k] = v_p.detach().half() return output_state_dict @@ -308,6 +309,11 @@ def decode(self, data: bytes) -> Any: default=256, help="Hidden dim, this will mainly determine the model size", ) +@click.option( + "--n_heads", + default=16, + help="Number of heads, this will mainly determine the model size", +) @click.option( "--n_layers", default=12, @@ -328,6 +334,17 @@ def decode(self, data: bytes) -> Any: @click.option( "--cond_seq_dim", default=2048, help="Conditional sequence dimension, like T5" ) +@click.option( + "--init_ckpt_path", default=None, help="Path to initial checkpoint" +) +@click.option( + "--t_cutoff_tokens", default=64, help="T cutoff tokens for T5 embeddings" +) +@click.option( + "--modify_resolution_at_initialization", + default=False, + help="Modify resolution at initialization", +) def main( local_rank, train_batch_size, @@ -342,6 +359,7 @@ def main( skipped_ema_step=16, weight_decay=0.1, hidden_dim=256, + n_heads=16, n_layers=12, save_dir="./ckpt", lr_frozen_factor=1.0, @@ 
-350,6 +368,9 @@ def main( vae_col="vae_256x256_latents", t5_col="t5_xl_embeddings", cond_seq_dim=2048, + init_ckpt_path=None, + t_cutoff_tokens=64, + modify_resolution_at_initialization=True, ): # first, set the seed @@ -403,29 +424,35 @@ def main( dim=hidden_dim, global_conddim=hidden_dim, n_layers=n_layers, - n_heads=8, + n_heads=n_heads, cond_seq_dim=cond_seq_dim, + max_seq = (vaeres//2 )**2, ), True, ).cuda() - statedict = torch.load( - "/home/ubuntu/ckpts_36L_5/model_69633/pytorch_model.bin", - map_location="cpu", - ) - # remove model.layers.23.modC.1.weight - # statedict.pop("model.layers.31.modC.1.weight") + + if init_ckpt_path is not None: + statedict = torch.load( + init_ckpt_path, + map_location="cpu", + ) + # # remove model.layers.23.modC.1.weight + # # statedict.pop("model.layers.31.modC.1.weight") - rf.load_state_dict( - statedict, - strict=False, - ) + rf.load_state_dict( + statedict, + strict=False, + ) + + if modify_resolution_at_initialization: + rf.model.extend_pe((16, 16), (vaeres//2, vaeres//2)) ema_state_dict1 = extract_model_state_dict_deepspeed(rf, global_rank) ema_state_dict2 = { - k: v.detach().cpu().float().clone() for k, v in ema_state_dict1.items() + k: v.clone() for k, v in ema_state_dict1.items() } prv_state_dict = { - k: v.detach().cpu().float().clone() for k, v in ema_state_dict1.items() + k: v.clone() for k, v in ema_state_dict1.items() } total_params = sum(p.numel() for p in rf.parameters()) @@ -457,59 +484,67 @@ def main( assert os.environ.get(varname) is not None, f"{varname} is not set" print(f"{varname}: {os.environ.get(varname)}") - locdir = f"/tmp/mdstemp_0" + locdir = f"/scratch/simo" + # # cleanup if rank0 + # if local_rank == 0: + # try: + # os.system(f"rm -rf {locdir}") + # #os.system(f"rm -rf /tmp/mdstemp_0") + + # # make + # os.makedirs(locdir, exist_ok=True) + # except: + # pass + + locdir = f"/scratch/simo" + #locdir = f"/tmp/mdstemp_0" + util.clean_stale_shared_memory() # cleanup if rank0 if local_rank == 0: - os.system(f"rm -rf {locdir}") - # make - os.makedirs(locdir, exist_ok=True) + try: + os.system(f"rm -rf {locdir}") + # make + os.makedirs(locdir, exist_ok=True) + except: + print("Failed to cleanup") torch.distributed.barrier() - if True: - train_dataset = StreamingDataset( - local=locdir, - remote=train_dir, - split=None, - shuffle=True, - shuffle_algo="py1e", - num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8), - batch_size=per_device_train_batch_size, - shuffle_seed=seed, - shuffle_block_size=10 * 4000, - cache_limit="100gb", - ) - else: - streams = [] - for idx in range(8): - locdir = f"/tmp/mdstemp_{idx}" - # cleanup if rank0 - if local_rank == 0: - os.system(f"rm -rf {locdir}") - # make - os.makedirs(locdir, exist_ok=True) - - stream = Stream( - remote=f"/jfs/datacomp-1b-0-10k/{idx}", - local=locdir, - proportion=1 / 8, - ) - - streams.append(stream) - - train_dataset = StreamingDataset( - streams=streams, - shuffle=True, - shuffle_algo="py1e", - num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8), - batch_size=per_device_train_batch_size, - shuffle_seed=seed, - shuffle_block_size=20 * 4000, - cache_limit="100gb", - predownload=2048, - ) + + # train_dataset = StreamingDataset( + # local=locdir, + # remote=train_dir, + # split=None, + # shuffle=False, + # shuffle_algo="py1s", + # num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8), + # batch_size=per_device_train_batch_size, + # shuffle_block_size=5000, + # cache_limit="50gb", + # predownload=512 * per_device_train_batch_size, + # download_retry=4, + # 
download_timeout=300, + # ) + + train_dataset = StreamingDataset( + local=train_dir, + remote=None, + # local=locdir, + # remote=train_dir, + split=None, + shuffle=True, + shuffle_algo="py1s", + shuffle_seed=seed, + num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8), + batch_size=per_device_train_batch_size, + cache_limit="300gb", + download_retry=2, + download_timeout=300, + shuffle_block_size=3000, + ) + right_pad = lambda x: torch.cat( - [x, torch.zeros((77 * 3) - x.size(0), cond_seq_dim).to(x.device)], dim=0 + [x, torch.zeros((256) - x.size(0), cond_seq_dim).to(x.device)], dim=0 ) def dequantize_t5(tensor): @@ -519,7 +554,7 @@ def dequantize_t5(tensor): dataloader = DataLoader( train_dataset, batch_size=per_device_train_batch_size, - num_workers=8, + num_workers=32, collate_fn=lambda x: { vae_col: torch.stack([torch.tensor(xx[vae_col]) for xx in x]), t5_col: torch.stack( @@ -533,6 +568,7 @@ def dequantize_t5(tensor): ] ), }, + drop_last = True ) torch.distributed.barrier() @@ -549,38 +585,40 @@ def dequantize_t5(tensor): optimizer_grouped_parameters = [] final_optimizer_settings = {} + param_groups = defaultdict(lambda: {"params": [], "weight_decay": None, "lr": None}) + for n, p in rf.named_parameters(): - group_parameters = {} if p.requires_grad: if any(ndnl in n for ndnl in no_decay_name_list): - group_parameters["weight_decay"] = 0.0 + weight_decay_value = 0.0 else: - group_parameters["weight_decay"] = weight_decay + weight_decay_value = weight_decay # Define learning rate for specific types of params - if any(ndnl in n for ndnl in no_decay_name_list): - group_parameters["lr"] = learning_rate * 0.033 + lr_value = learning_rate * 0.033 elif any(ipt in n for ipt in custom_lr_set.keys()): input_dim = p.shape[-1] ipt = [ipt for ipt in custom_lr_set.keys() if ipt in n][0] - group_parameters["lr"] = learning_rate * ( - custom_lr_set[ipt] / input_dim - ) - + lr_value = learning_rate * (custom_lr_set[ipt] / input_dim) else: - group_parameters["lr"] = learning_rate * (32 / hidden_dim) + lr_value = learning_rate * (32 / hidden_dim) if any(ndnl in n for ndnl in small_train_name_list): - group_parameters["lr"] = group_parameters["lr"] * lr_frozen_factor + lr_value = lr_value * lr_frozen_factor + + group_key = (lr_value, weight_decay_value) + param_groups[group_key]["params"].append(p) + param_groups[group_key]["weight_decay"] = weight_decay_value + param_groups[group_key]["lr"] = lr_value - group_parameters["params"] = [p] final_optimizer_settings[n] = { - "lr": group_parameters["lr"], - "wd": group_parameters["weight_decay"], + "lr": lr_value, + "wd": weight_decay_value, "shape": str(list(p.shape)), } - optimizer_grouped_parameters.append(group_parameters) + + optimizer_grouped_parameters = [v for v in param_groups.values()] if global_rank == 0: # mkdir and dump optimizer settings @@ -594,11 +632,11 @@ def dequantize_t5(tensor): AdamOptimizer = torch.optim.AdamW optimizer = AdamOptimizer( - rf.parameters(), lr=learning_rate * (32 / hidden_dim), betas=(0.9, 0.95) + optimizer_grouped_parameters, betas=(0.9, 0.95) ) lr_scheduler = get_scheduler( - name="linear", optimizer=optimizer, num_warmup_steps=300, num_training_steps=1e6 + name="linear", optimizer=optimizer, num_warmup_steps=100, num_training_steps=1e6 ) rf.train() @@ -648,7 +686,7 @@ def dequantize_t5(tensor): ) cond = ( - (batch[t5_col].reshape(-1, 77 * 3, cond_seq_dim)) + (batch[t5_col].reshape(-1, 256, cond_seq_dim))[:, :t_cutoff_tokens, :] .to(device) .to(torch.bfloat16) ) @@ -696,7 +734,7 @@ def dequantize_t5(tensor): 
log_dif(current_state_dict, prv_state_dict) prv_state_dict = { - k: v.detach().cpu().float().clone() + k: v.detach().cpu().half().clone() for k, v in current_state_dict.items() } # update ema. @@ -715,10 +753,10 @@ def dequantize_t5(tensor): for k, v in current_state_dict.items(): ema_state_dict1[k] = ( - BETA1 * ema_state_dict1[k] + (1 - BETA1) * v + BETA1 * ema_state_dict1[k] + (1 - BETA1) * v.half() ) ema_state_dict2[k] = ( - BETA2 * ema_state_dict2[k] + (1 - BETA2) * v + BETA2 * ema_state_dict2[k] + (1 - BETA2) * v.half() ) # log 1st value for sanity check if value is None: @@ -734,8 +772,9 @@ def dequantize_t5(tensor): f"norm: {norm}, loss: {loss.item()}, global_step: {global_step}" ) - if global_step % 4096 == 1: - + if global_step % 8192 == 10: + print(f"Starting EMA save at {global_step}") + t1 = time.time() os.makedirs(f"{save_dir}/model_{global_step}", exist_ok=True) save_zero_three_model( model_engine, @@ -753,7 +792,7 @@ def dequantize_t5(tensor): ema_state_dict2, f"{save_dir}/model_{global_step}/ema2.pt" ) - print(f"Model saved at {global_step}, Global Rank {global_rank}") + print(f"Model saved at {global_step}, Global Rank {global_rank}, Time: {time.time() - t1}") # sync torch.distributed.barrier() diff --git a/advanced/main_t2i_highres.py b/advanced/main_t2i_highres.py deleted file mode 100644 index c6bad8c..0000000 --- a/advanced/main_t2i_highres.py +++ /dev/null @@ -1,819 +0,0 @@ -# Mostly Copy-paste from https://github.com/cloneofsimo/min-max-in-dit. - -import math -import os -import random -from typing import Any - -import click -import deepspeed -import numpy as np -import streaming.base.util as util -import torch - -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True - -from deepspeed import get_accelerator -from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam -from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus -from streaming import StreamingDataset, Stream -from streaming.base.format.mds.encodings import Encoding, _encodings -from torch.utils.data import DataLoader -from tqdm import tqdm -from transformers import get_scheduler - -import wandb -from mmdit import MMDiT -import pandas as pd -import plotly.express as px -import re - - -@torch.no_grad() -def log_dif(model_cur_sd, model_prev_sd): - # Initialize a new run - - # Create lists to store data for the plot - layer_names = [] - std_devs = [] - l1_norms = [] - param_counts = [] - colors = [] - markers = [] - - # Iterate over the parameters and compute necessary metrics - for name, param in model_cur_sd.items(): - if name in model_prev_sd: - prev_param = model_prev_sd[name] - std_dev = param.std().item() - l1_norm = torch.abs(param - prev_param).mean().item() - param_count = param.numel() - - # Determine color based on the criteria using regex - layer_match = re.match(r".*\.layers\.(\d+)(?:\..*)?$", name) - - if layer_match: - layer_num = int(layer_match.group(1)) - colors.append(layer_num) - else: - colors.append(-1) - - # Determine marker type - if param.ndim == 1: - markers.append("x") - else: - markers.append("circle") - - layer_names.append(name) - std_devs.append(std_dev) - l1_norms.append(np.log1p(l1_norm)) # log(1 + x) transformation - param_counts.append(np.log(param_count)) - - # Create a DataFrame for the plot - df = pd.DataFrame( - { - "Layer Name": layer_names, - "Standard Deviation": std_devs, - "L1 Norm of Changes (log scale)": l1_norms, - "Parameter Count (log)": param_counts, - "Color": colors, - "Marker": markers, - } - ) - - # Determine 
the number of layers - max_layer_num = df[df["Color"] != -1]["Color"].max() - - # Create a color scale for the layers (yellow to red) - color_scale = px.colors.sequential.YlOrRd - color_discrete_map = { - i: color_scale[int(i * (len(color_scale) - 1) / max_layer_num)] - for i in range(int(max_layer_num) + 1) - } - color_discrete_map[-1] = "blue" # Blue for non-layer parameters - - # Create Plotly figure - fig = px.scatter( - df, - x="Standard Deviation", - y="L1 Norm of Changes (log scale)", - size="Parameter Count (log)", - color="Color", - hover_name="Layer Name", - title="Model Weight Distribution and Changes", - symbol="Marker", - color_discrete_map=color_discrete_map, - opacity=0.7, - ) - - # - - table = wandb.Table(columns=["plotly_figure"]) - - # Create path for Plotly figure - path_to_plotly_html = "./plotly_figure.html" - - # Write Plotly figure to HTML - fig.write_html(path_to_plotly_html, auto_play=False) - - # Add Plotly figure as HTML file into Table - table.add_data(wandb.Html(path_to_plotly_html)) - - # Log Table - wandb.log({"weight_distribution_changes": table}) - - -class RF(torch.nn.Module): - def __init__(self, model, ln=True): - super().__init__() - self.model = model - self.ln = ln - self.stratified = False - self.t_transform = lambda t : (math.sqrt(3) * t / (1 + (math.sqrt(3) - 1) *t)) - - def forward(self, x, cond, randomly_augment_x_latent=False): - - if randomly_augment_x_latent: - # this will take B, C, H, W latent and crop it so they are ~ 33% of the original size. - b, c, h, w = x.size() - if random.random() < -1: - new_w = random.randint(int(w * 0.3333), w) - new_h = random.randint(int(h * 0.3333), h) - # We dont want very small spatiality. We priotize uniform distibution on w, but h should be large enough. - # if new_w = 0.333w, you should get max h. - new_h = max(new_h, int(0.3333 * w * h / new_w)) - # this way, we are making sure we don't get very small spatiality. - new_w, new_h = min(new_w, w), min(new_h, h) - # make it even - new_w, new_h = new_w - new_w % 4, new_h - new_h % 4 - - # now make b maximal. Following will gurantee that new_b, new_h, new_w are similar accross all devices, - # and also, they are correct. This comes with the tradeoff that in worst case we drop 66% of the batches. 
- - new_b = int(0.18 * b * h * w / (new_h * new_w)) * 2 - new_b = max(new_b, 1) - x = x[ - :new_b, - :, - h // 2 - new_h // 2 : h // 2 + new_h // 2, - w // 2 - new_w // 2 : w // 2 + new_w // 2, - ] - else: - new_b = max(int(b * 0.18) * 2, 1) - x = x[:new_b] - - b = x.size(0) - #print(x.size()) - if self.ln: - if self.stratified: - # stratified sampling of normals - # first stratified sample from uniform - quantiles = torch.linspace(0, 1, b + 1).to(x.device) - z = quantiles[:-1] + torch.rand((b,)).to(x.device) / b - # now transform to normal - z = torch.erfinv(2 * z - 1) * math.sqrt(2) - t = torch.sigmoid(z) - else: - nt = torch.randn((b,)).to(x.device) - t = torch.sigmoid(nt) - else: - t = torch.rand((b,)).to(x.device) - texp = t.view([b, *([1] * len(x.shape[1:]))]) - texp = self.t_transform(texp) - z1 = torch.randn_like(x) - zt = (1 - texp) * x + texp * z1 - - # make t, zt into same dtype as x - zt, t = zt.to(x.dtype), t.to(x.dtype) - vtheta = self.model(zt, t, cond) - batchwise_mse = ((z1 - x - vtheta) ** 2).mean(dim=list(range(1, len(x.shape)))) - tlist = batchwise_mse.detach().cpu().reshape(-1).tolist() - ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)] - return batchwise_mse.mean(), {"batchwise_loss": ttloss} - - @torch.no_grad() - def sample(self, z, cond, null_cond=None, sample_steps=50, cfg=2.0): - b = z.size(0) - dt = 1.0 / sample_steps - dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))]) - images = [z] - for i in range(sample_steps, 0, -1): - t = i / sample_steps - t = torch.tensor([t] * b).to(z.device) - - vc = self.model(z, t, cond) - if null_cond is not None: - vu = self.model(z, t, null_cond) - vc = vu + cfg * (vc - vu) - - z = z - dt * vc - images.append(z) - return images - - @torch.no_grad() - def sample_with_xps(self, z, cond, null_cond=None, sample_steps=50, cfg=2.0): - b = z.size(0) - dt = 1.0 / sample_steps - dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))]) - images = [z] - for i in range(sample_steps, 0, -1): - t = i / sample_steps - t = torch.tensor([t] * b).to(z.device) - - vc = self.model(z, t, cond) - if null_cond is not None: - vu = self.model(z, t, null_cond) - vc = vu + cfg * (vc - vu) - x = z - i * dt * vc - z = z - dt * vc - images.append(x) - return images - - -def _z3_params_to_fetch(param_list): - return [ - p - for p in param_list - if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE - ] - - -def save_zero_three_model(model, global_rank, save_dir, zero_stage=0): - zero_stage_3 = zero_stage == 3 - os.makedirs(save_dir, exist_ok=True) - WEIGHTS_NAME = "pytorch_model.bin" - output_model_file = os.path.join(save_dir, WEIGHTS_NAME) - - model_to_save = model.module if hasattr(model, "module") else model - if not zero_stage_3: - if global_rank == 0: - torch.save(model_to_save.state_dict(), output_model_file) - else: - output_state_dict = {} - for k, v in model_to_save.named_parameters(): - if hasattr(v, "ds_id"): - with deepspeed.zero.GatheredParameters( - _z3_params_to_fetch([v]), enabled=zero_stage_3 - ): - v_p = v.data.cpu() - else: - v_p = v.cpu() - if global_rank == 0 and "lora" not in k: - output_state_dict[k] = v_p - if global_rank == 0: - torch.save(output_state_dict, output_model_file) - del output_state_dict - - -@torch.no_grad() -def extract_model_state_dict_deepspeed(model, global_rank): - output_state_dict = {} - for k, v in model.named_parameters(): - if hasattr(v, "ds_id"): - with deepspeed.zero.GatheredParameters( - _z3_params_to_fetch([v]), enabled=True - ): - v_p = 
v.data.cpu() - else: - v_p = v.cpu() - - if global_rank == 0: - output_state_dict[k] = v_p.detach() - - return output_state_dict - - -def set_seed(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -class uint8(Encoding): - def encode(self, obj: Any) -> bytes: - return obj.tobytes() - - def decode(self, data: bytes) -> Any: - x = np.frombuffer(data, np.uint8).astype(np.float32) - return x - - -class np16(Encoding): - def encode(self, obj: Any) -> bytes: - return obj.tobytes() - - def decode(self, data: bytes) -> Any: - return np.frombuffer(data, np.float16) - - -_encodings["np16"] = np16 -_encodings["uint8"] = uint8 - - -@click.command() -@click.option("--local_rank", default=-1, help="Local rank") -@click.option("--num_train_epochs", default=2, help="Number of training epochs") -@click.option("--learning_rate", default=3e-3, help="Learning rate") -@click.option("--offload", default=False, help="Offload") -@click.option("--train_batch_size", default=256, help="Total Train batch size") -@click.option( - "--per_device_train_batch_size", default=128, help="Per device train batch size" -) -@click.option("--zero_stage", default=2, help="Zero stage, from 0 to 3") -@click.option("--seed", default=42, help="Seed for rng") -@click.option("--run_name", default=None, help="Run name that will be used for wandb") -@click.option( - "--train_dir", - default="/home/host/simo/capfusion_mds", - help="Train dir that MDS can read", -) -@click.option( - "--skipped_ema_step", - default=1024, - help="Skipped EMA step. Karras EMA will save model every skipped_ema_step", -) -@click.option("--weight_decay", default=0.1, help="Weight decay") -@click.option( - "--hidden_dim", - default=256, - help="Hidden dim, this will mainly determine the model size", -) -@click.option( - "--n_layers", - default=12, - help="Number of layers, this will mainly determine the model size", -) -@click.option("--save_dir", default="./ckpt", help="Save dir for model") -@click.option( - "--lr_frozen_factor", - default=1.0, - help="Learning rate for (nearly) frozen layers. You would want this less than 1.", -) -@click.option("--note", default="hi", help="Note for wandb") -@click.option("--vaeres", default=32, help="VAE resolution. 
32 x 32 by default") -@click.option( - "--vae_col", default="vae_1024x1024_latents", help="Column name for VAE data" -) -@click.option("--t5_col", default="t5_xl_embeddings", help="Column name for T5 data") -@click.option( - "--cond_seq_dim", default=2048, help="Conditional sequence dimension, like T5" -) -@click.option( - "--resize_pe_at_initialization", - default=False, - help="Resize positional encoding at initialization", -) -def main( - local_rank, - train_batch_size, - per_device_train_batch_size, - num_train_epochs, - learning_rate, - offload=False, - zero_stage=2, - seed=42, - run_name=None, - train_dir="./vae_mds", - skipped_ema_step=16, - weight_decay=0.1, - hidden_dim=256, - n_layers=12, - save_dir="./ckpt", - lr_frozen_factor=1.0, - note="hi", - vaeres=32, - vae_col="vae_256x256_latents", - t5_col="t5_xl_embeddings", - cond_seq_dim=2048, - resize_pe_at_initialization=False, -): - - # first, set the seed - set_seed(seed) - - if run_name is None: - run_name = ( - f"LR:{learning_rate}__num_train_epochs:{num_train_epochs}_offload:{offload}" - ) - - if local_rank == -1: - device = torch.device(get_accelerator().device_name()) - else: - get_accelerator().set_device(local_rank) - device = torch.device(get_accelerator().device_name(), local_rank) - # Initializes the distributed backend which will take care of sychronizing nodes/GPUs - deepspeed.init_distributed() - - # set LOCAL_WORLD_SIZE to 8 - - offload_device = "cpu" if offload else "none" - - ds_config = { - "train_micro_batch_size_per_gpu": per_device_train_batch_size, - "train_batch_size": train_batch_size, - "zero_optimization": { - "stage": zero_stage, - "offload_param": {"device": offload_device}, - "offload_optimizer": {"device": offload_device}, - "stage3_param_persistence_threshold": 1e4, - "stage3_max_live_parameters": 3e7, - "stage3_prefetch_bucket_size": 3e7, - "memory_efficient_linear": False, - }, - "bfloat16": {"enabled": True}, - "gradient_clipping": 1.0, - } - - torch.distributed.barrier() - - global_rank = torch.distributed.get_rank() - - ##### DEFINE model, dataset, sampler, dataloader, optim, schedular - - with deepspeed.zero.Init(enabled=(zero_stage == 3)): - - rf = RF( - MMDiT( - in_channels=4, - out_channels=4, - dim=hidden_dim, - global_conddim=hidden_dim, - n_layers=n_layers, - n_heads=8, - cond_seq_dim=cond_seq_dim, - max_seq= 96 * 96 - ), - True, - ).cuda() - if True: - statedict = torch.load( - "/home/ubuntu/ckpts_36L_2/model_102401/ema1.pt", - #"/home/ubuntu/ckpts_36L_2_highres_freezemost/model_12288/ema1.pt", - map_location="cpu", - ) - # remove model.layers.23.modC.1.weight - # statedict.pop("model.layers.31.modC.1.weight") - - rf.load_state_dict( - statedict, - strict=False, - ) - if resize_pe_at_initialization: - rf.model.extend_pe((16, 16), (vaeres // 2, vaeres // 2)) - - ema_state_dict1 = extract_model_state_dict_deepspeed(rf, global_rank) - ema_state_dict2 = { - k: v.detach().cpu().float().clone() for k, v in ema_state_dict1.items() - } - prv_state_dict = { - k: v.detach().cpu().float().clone() for k, v in ema_state_dict1.items() - } - - total_params = sum(p.numel() for p in rf.parameters()) - size_in_bytes = total_params * 4 - size_in_gb = size_in_bytes / (1024**3) - print( - f"Model Size: {size_in_bytes}, {size_in_gb} GB, Total Param Count: {total_params / 1e6} M" - ) - - util.clean_stale_shared_memory() - # barrier - torch.distributed.barrier() - - os.environ["LOCAL_WORLD_SIZE"] = str(min(8, int(os.environ.get("WORLD_SIZE")))) - # WORLD_SIZE: Total number of processes to launch across all 
nodes. - # LOCAL_WORLD_SIZE: Total number of processes to launch for each node. - # RANK: Rank of the current process, which is the range between 0 to WORLD_SIZE - 1. - # MASTER_ADDR: The hostname for the rank-zero process. - # MASTER_PORT: The port for the rank-zero process. - - for varname in [ - "RANK", - "LOCAL_WORLD_SIZE", - "WORLD_SIZE", - "MASTER_ADDR", - "MASTER_PORT", - ]: - assert os.environ.get(varname) is not None, f"{varname} is not set" - print(f"{varname}: {os.environ.get(varname)}") - - locdir = f"/tmp/mdstemp_0" - # cleanup if rank0 - if local_rank == 0: - os.system(f"rm -rf {locdir}") - # make - os.makedirs(locdir, exist_ok=True) - - torch.distributed.barrier() - if True: - train_dataset = StreamingDataset( - local=locdir, - remote=train_dir, - split=None, - shuffle=True, - shuffle_algo="py1e", - num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8), - batch_size=per_device_train_batch_size, - shuffle_seed=seed, - shuffle_block_size=10 * 4000, - cache_limit="100gb", - ) - else: - streams = [] - for idx in range(8): - locdir = f"/tmp/mdstemp_{idx}" - # cleanup if rank0 - if local_rank == 0: - os.system(f"rm -rf {locdir}") - # make - os.makedirs(locdir, exist_ok=True) - - stream = Stream( - remote=f"/jfs/datacomp-1b-0-10k/{idx}", - local=locdir, - proportion=1 / 8, - ) - - streams.append(stream) - - train_dataset = StreamingDataset( - streams=streams, - shuffle=True, - shuffle_algo="py1e", - num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8), - batch_size=per_device_train_batch_size, - shuffle_seed=seed, - shuffle_block_size=20 * 4000, - cache_limit="100gb", - predownload=2048, - ) - - right_pad = lambda x: torch.cat( - [x, torch.zeros((256) - x.size(0), cond_seq_dim).to(x.device)], dim=0 - ) - - def dequantize_t5(tensor): - denorm_tensor = tensor.to(torch.float16) / 255 - return (denorm_tensor * 0.5) - 0.25 - - dataloader = DataLoader( - train_dataset, - batch_size=per_device_train_batch_size, - num_workers=8, - collate_fn=lambda x: { - vae_col: torch.stack([torch.tensor(xx[vae_col]) for xx in x]), - t5_col: torch.stack( - [ - right_pad( - dequantize_t5( - torch.tensor(xx[t5_col]).reshape(-1, cond_seq_dim) - ) - ) - for xx in x - ] - ), - }, - ) - - torch.distributed.barrier() - ## Config muP-learning rate. 
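# The per-parameter lr / weight-decay settings computed below (and logged to
# optimizer_settings.txt) follow the same muP-style scheme as main_t2i.py:
# params matched by no_decay_name_list get weight_decay 0 and a fixed
# learning_rate * 0.033; projections named in custom_lr_set get
# learning_rate * (factor / fan_in), with fan_in taken from p.shape[-1];
# every other matrix gets learning_rate * (32 / hidden_dim), so the effective
# update stays roughly width-invariant relative to a base width of 32.
# Params matched by small_train_name_list are additionally scaled by
# lr_frozen_factor.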
- no_decay_name_list = ["bias", "norm", "positional_encoding", "register_tokens"] - - small_train_name_list = ["w2q", "w2k", "w2v", "w2o", "mlpX", "modX"] - - custom_lr_set = { - "init_x_linear": 4.0, - "cond_seq_linear": 32.0, - } - - optimizer_grouped_parameters = [] - final_optimizer_settings = {} - - # requires grad for first 2 and last 2 layer - # for n, p in rf.named_parameters(): - # if "layers" in n: - # if any(layername in n for layername in ["layers.0.", "layers.1.", "layers.34.", "layers.35."]): - # p.requires_grad = True - # else: - # p.requires_grad = False - # else: - # p.requires_grad = True - - for n, p in rf.named_parameters(): - group_parameters = {} - if p.requires_grad: - if any(ndnl in n for ndnl in no_decay_name_list): - group_parameters["weight_decay"] = 0.0 - else: - group_parameters["weight_decay"] = weight_decay - - # Define learning rate for specific types of params - - if any(ndnl in n for ndnl in no_decay_name_list): - group_parameters["lr"] = learning_rate * 0.033 - elif any(ipt in n for ipt in custom_lr_set.keys()): - input_dim = p.shape[-1] - ipt = [ipt for ipt in custom_lr_set.keys() if ipt in n][0] - group_parameters["lr"] = learning_rate * ( - custom_lr_set[ipt] / input_dim - ) - - else: - group_parameters["lr"] = learning_rate * (32 / hidden_dim) - - if any(ndnl in n for ndnl in small_train_name_list): - group_parameters["lr"] = group_parameters["lr"] * lr_frozen_factor - - group_parameters["params"] = [p] - final_optimizer_settings[n] = { - "lr": group_parameters["lr"], - "wd": group_parameters["weight_decay"], - "shape": str(list(p.shape)), - } - optimizer_grouped_parameters.append(group_parameters) - - if global_rank == 0: - # mkdir and dump optimizer settings - os.makedirs(save_dir, exist_ok=True) - - with open(f"{save_dir}/optimizer_settings.txt", "w") as f: - # format - for k, v in sorted(final_optimizer_settings.items()): - f.write(f"{k}: {v}\n") - - AdamOptimizer = torch.optim.AdamW - - optimizer = AdamOptimizer( - #optimizer_grouped_parameters, - rf.parameters(), lr=learning_rate * (32 / hidden_dim), - betas=(0.9, 0.95) - ) - - lr_scheduler = get_scheduler( - name="linear", optimizer=optimizer, num_warmup_steps=300, num_training_steps=1e6 - ) - - rf.train() - - model_engine, optimizer, _, lr_scheduler = deepspeed.initialize( - model=rf, config=ds_config, lr_scheduler=lr_scheduler, optimizer=optimizer - ) - - global_step = 0 - - if global_rank == 0: - lossbin = {i: 0 for i in range(10)} - losscnt = {i: 1e-6 for i in range(10)} - ##### actual training loop - - if global_rank == 0: - wandb.init( - project="6.5b_t2i_mup", - name=run_name, - config={ - "num_train_epochs": num_train_epochs, - "learning_rate": learning_rate, - "offload": offload, - "train_batch_size": train_batch_size, - "per_device_train_batch_size": per_device_train_batch_size, - "zero_stage": zero_stage, - "seed": seed, - "train_dir": train_dir, - "skipped_ema_step": skipped_ema_step, - "weight_decay": weight_decay, - "hidden_dim": hidden_dim, - }, - notes=note, - ) - - for i in range(num_train_epochs): - pbar = tqdm(dataloader) - - for batch in pbar: - - x = ( - batch[vae_col] - .reshape(-1, 4, vaeres, vaeres) - .to(device) - .to(torch.bfloat16) - * 0.13025 - ) - - cond = ( - (batch[t5_col].reshape(-1, 256, cond_seq_dim)) - .to(device) - .to(torch.bfloat16) - ) - - loss, info = model_engine( - x, {"c_seq": cond}, randomly_augment_x_latent=True - ) - model_engine.backward(loss) - model_engine.step() - - norm = model_engine.get_global_grad_norm() - - global_step += 1 - if global_rank 
== 0: - batchwise_loss = info["batchwise_loss"] - # check t-wise loss - for t, l in batchwise_loss: - lossbin[int(t * 10)] += l - losscnt[int(t * 10)] += 1 - - if global_step % 64 == 0: - wandb.log( - { - "train/avg_loss": sum(lossbin.values()) - / sum(losscnt.values()), - "train/grad_norm": norm, - "value/rawval": value, - "value/ema1val": ema1_of_value, - "value/ema2val": ema2_of_value, - **{ - f"loss/bin_{i}": lossbin[i] / losscnt[i] - for i in range(10) - }, - } - ) - # reset - lossbin = {i: 0 for i in range(10)} - losscnt = {i: 1e-6 for i in range(10)} - - if global_step % skipped_ema_step == 1: - - current_state_dict = extract_model_state_dict_deepspeed(rf, global_rank) - - if global_rank == 0: - - # log - log_dif(current_state_dict, prv_state_dict) - - prv_state_dict = { - k: v.detach().cpu().float().clone() - for k, v in current_state_dict.items() - } - # update ema. - BETA1 = (1 - 1 / global_step) ** (1 + 16) - BETA2 = (1 - 1 / global_step) ** (1 + 9) - - # adjust beta for skipped-ema - BETA1 = 1 - (1 - BETA1) * skipped_ema_step - BETA1 = max(0, BETA1) - BETA2 = 1 - (1 - BETA2) * skipped_ema_step - BETA2 = max(0, BETA2) - - value = None - ema1_of_value = None - ema2_of_value = None - - for k, v in current_state_dict.items(): - ema_state_dict1[k] = ( - BETA1 * ema_state_dict1[k] + (1 - BETA1) * v - ) - ema_state_dict2[k] = ( - BETA2 * ema_state_dict2[k] + (1 - BETA2) * v - ) - # log 1st value for sanity check - if value is None: - value = v.half().flatten()[0].item() - ema1_of_value = ( - ema_state_dict1[k].half().flatten()[0].item() - ) - ema2_of_value = ( - ema_state_dict2[k].half().flatten()[0].item() - ) - - pbar.set_description( - f"norm: {norm}, loss: {loss.item()}, global_step: {global_step}" - ) - - if global_step % 4096 == 0: - - os.makedirs(f"{save_dir}/model_{global_step}", exist_ok=True) - save_zero_three_model( - model_engine, - global_rank, - f"{save_dir}/model_{global_step}/", - zero_stage=zero_stage, - ) - - # save ema weights - if global_rank == 0: - torch.save( - ema_state_dict1, f"{save_dir}/model_{global_step}/ema1.pt" - ) - torch.save( - ema_state_dict2, f"{save_dir}/model_{global_step}/ema2.pt" - ) - - print(f"Model saved at {global_step}, Global Rank {global_rank}") - - # sync - torch.distributed.barrier() - - -if __name__ == "__main__": - main() diff --git a/advanced/mmdit.py b/advanced/mmdit.py index 82ca6e8..6e03c9c 100644 --- a/advanced/mmdit.py +++ b/advanced/mmdit.py @@ -71,6 +71,54 @@ def forward(self, hidden_states): hidden_states = self.weight.to(torch.float32) * hidden_states return hidden_states.to(input_dtype) +class SingleAttention(nn.Module): + def __init__(self, dim, n_heads, mh_qknorm=False): + super().__init__() + + self.n_heads = n_heads + self.head_dim = dim // n_heads + + # this is for cond + self.w1q = nn.Linear(dim, dim, bias=False) + self.w1k = nn.Linear(dim, dim, bias=False) + self.w1v = nn.Linear(dim, dim, bias=False) + self.w1o = nn.Linear(dim, dim, bias=False) + + self.q_norm1 = ( + MultiHeadLayerNorm((self.n_heads, self.head_dim)) + if mh_qknorm + else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False) + ) + self.k_norm1 = ( + MultiHeadLayerNorm((self.n_heads, self.head_dim)) + if mh_qknorm + else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False) + ) + + @torch.compile() + def forward(self, c): + + bsz, seqlen1, _ = c.shape + + q, k, v = self.w1q(c), self.w1k(c), self.w1v(c) + q = q.view(bsz, seqlen1, self.n_heads, self.head_dim) + k = k.view(bsz, seqlen1, self.n_heads, self.head_dim) + v = 
v.view(bsz, seqlen1, self.n_heads, self.head_dim) + q, k = self.q_norm1(q), self.k_norm1(k) + + output = F.scaled_dot_product_attention( + q.permute(0, 2, 1, 3), + k.permute(0, 2, 1, 3), + v.permute(0, 2, 1, 3), + dropout_p=0.0, + is_causal=False, + scale=1 / self.head_dim**0.5, + ).permute(0, 2, 1, 3) + output = output.flatten(-2) + c = self.w1o(output) + return c + + class DoubleAttention(nn.Module): def __init__(self, dim, n_heads, mh_qknorm=False): @@ -94,24 +142,26 @@ def __init__(self, dim, n_heads, mh_qknorm=False): self.q_norm1 = ( MultiHeadLayerNorm((self.n_heads, self.head_dim)) if mh_qknorm - else Fp32LayerNorm(self.head_dim, bias=False) + else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False) ) self.k_norm1 = ( MultiHeadLayerNorm((self.n_heads, self.head_dim)) if mh_qknorm - else Fp32LayerNorm(self.head_dim, bias=False) + else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False) ) self.q_norm2 = ( MultiHeadLayerNorm((self.n_heads, self.head_dim)) if mh_qknorm - else Fp32LayerNorm(self.head_dim, bias=False) + else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False) ) self.k_norm2 = ( MultiHeadLayerNorm((self.n_heads, self.head_dim)) if mh_qknorm - else Fp32LayerNorm(self.head_dim, bias=False) + else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False) ) + + @torch.compile() def forward(self, c, x): @@ -150,6 +200,7 @@ def forward(self, c, x): c, x = output.split([seqlen1, seqlen2], dim=1) c = self.w1o(c) x = self.w2o(x) + return c, x @@ -157,8 +208,8 @@ class MMDiTBlock(nn.Module): def __init__(self, dim, heads=8, global_conddim=1024, is_last=False): super().__init__() - self.normC1 = Fp32LayerNorm(dim, bias=False) - self.normC2 = Fp32LayerNorm(dim, bias=False) + self.normC1 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False) + self.normC2 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False) if not is_last: self.mlpC = MLP(dim, hidden_dim=dim * 4) self.modC = nn.Sequential( @@ -171,8 +222,8 @@ def __init__(self, dim, heads=8, global_conddim=1024, is_last=False): nn.Linear(global_conddim, 2 * dim, bias=False), ) - self.normX1 = Fp32LayerNorm(dim, bias=False) - self.normX2 = Fp32LayerNorm(dim, bias=False) + self.normX1 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False) + self.normX2 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False) self.mlpX = MLP(dim, hidden_dim=dim * 4) self.modX = nn.Sequential( nn.SiLU(), @@ -186,14 +237,11 @@ def __init__(self, dim, heads=8, global_conddim=1024, is_last=False): def forward(self, c, x, global_cond, **kwargs): cres, xres = c, x - # cpath - if not self.is_last: - cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = ( - self.modC(global_cond).chunk(6, dim=1) - ) - else: - cshift_msa, cscale_msa = self.modC(global_cond).chunk(2, dim=1) - + + cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = ( + self.modC(global_cond).chunk(6, dim=1) + ) + c = modulate(self.normC1(c), cshift_msa, cscale_msa) # xpath @@ -204,13 +252,12 @@ def forward(self, c, x, global_cond, **kwargs): x = modulate(self.normX1(x), xshift_msa, xscale_msa) # attention - c, x = self.attn(c, x) - if not self.is_last: - c = self.normC2(cres + cgate_msa.unsqueeze(1) * c) - c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp)) - c = cres + c + + c = self.normC2(cres + cgate_msa.unsqueeze(1) * c) + c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp)) + c = cres + c x = self.normX2(xres + xgate_msa.unsqueeze(1) * x) x = 
xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp)) @@ -218,6 +265,39 @@ def forward(self, c, x, global_cond, **kwargs): return c, x +class DiTBlock(nn.Module): + # like MMDiTBlock, but it only has X + def __init__(self, dim, heads=8, global_conddim=1024): + super().__init__() + + self.norm1 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False) + self.norm2 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False) + + self.modCX = nn.Sequential( + nn.SiLU(), + nn.Linear(global_conddim, 6 * dim, bias=False), + ) + + self.attn = SingleAttention(dim, heads) + self.mlp = MLP(dim, hidden_dim=dim * 4) + + @torch.compile() + def forward(self, cx, global_cond, **kwargs): + cxres = cx + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX( + global_cond + ).chunk(6, dim=1) + cx = modulate(self.norm1(cx), shift_msa, scale_msa) + cx = self.attn(cx) + cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx) + mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp)) + cx = gate_mlp.unsqueeze(1) * mlpout + + cx = cxres + cx + + return cx + + class TimestepEmbedder(nn.Module): def __init__(self, hidden_size, frequency_embedding_size=256): @@ -260,11 +340,11 @@ def __init__( patch_size=2, dim=2048, n_layers=8, + n_double_layers=4, n_heads=4, global_conddim=1024, cond_seq_dim=2048, - cond_vector_dim=1024, - max_seq=96 * 96, + max_seq=16 * 16, ): super().__init__() @@ -280,11 +360,20 @@ def __init__( self.positional_encoding = nn.Parameter(torch.randn(1, max_seq, dim) * 0.1) self.register_tokens = nn.Parameter(torch.randn(1, 8, dim) * 0.02) - self.layers = nn.ModuleList([]) - for idx in range(n_layers): - self.layers.append( + self.double_layers = nn.ModuleList([]) + self.single_layers = nn.ModuleList([]) + + + for idx in range(n_double_layers): + self.double_layers.append( MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1)) ) + + for idx in range(n_double_layers, n_layers): + self.single_layers.append( + DiTBlock(dim, n_heads, global_conddim) + ) + self.final_linear = nn.Linear( dim, patch_size * patch_size * out_channels, bias=False @@ -298,10 +387,13 @@ def __init__( self.out_channels = out_channels self.patch_size = patch_size + self.n_double_layers = n_double_layers + self.n_layers = n_layers for pn, p in self.named_parameters(): - if "mod" in pn: + if ".mod" in pn: nn.init.constant_(p, 0) + print("zeroed", pn) # if cond_seq_linear nn.init.constant_(self.cond_seq_linear.weight, 0) @@ -375,9 +467,20 @@ def forward(self, x, t, conds, **kwargs): c = torch.cat([self.register_tokens.repeat(c.size(0), 1, 1), c], dim=1) global_cond = self.t_embedder(t) # B, D + + if len(self.double_layers) > 0: + for layer in self.double_layers: + + c, x = layer(c, x, global_cond, **kwargs) + + if len(self.single_layers) > 0: + c_len = c.size(1) + cx = torch.cat([c, x], dim=1) + for layer in self.single_layers: + + cx = layer(cx, global_cond, **kwargs) - for layer in self.layers: - c, x = layer(c, x, global_cond, **kwargs) + x = cx[:, c_len:] fshift, fscale = self.modF(global_cond).chunk(2, dim=1) @@ -387,72 +490,8 @@ def forward(self, x, t, conds, **kwargs): return x -class MultiTokenEmbedding(nn.Module): - # make two embedding, one for each token, concat and return - def __init__(self, num_tokens, hidden_size): - super().__init__() - self.embedding1 = nn.Embedding(num_tokens, hidden_size) - self.embedding2 = nn.Embedding(num_tokens, hidden_size) - - def forward(self, x): - emb1 = self.embedding1(x).unsqueeze(1) - emb2 = self.embedding2(x).unsqueeze(1) - return 
torch.cat([emb1, emb2], dim=1) # B, 2, D - - -class MMDiT_for_IN1K(MMDiT): - # This will "simulate" having clip. - # it will act as having clip that encodes both clip global vector and clip sequence vector. - # in reality this is just one hot encoding. - def __init__( - self, - in_channels=4, - out_channels=4, - patch_size=2, - dim=1024, - n_layers=8, - n_heads=4, - global_conddim=1024, - cond_seq_dim=2048, - cond_vector_dim=1024, - max_seq=32 * 32, - ): - super(MMDiT_for_IN1K, self).__init__( - in_channels, - out_channels, - patch_size, - dim, - n_layers, - n_heads, - global_conddim, - cond_seq_dim, - cond_vector_dim, - max_seq, - ) - - # replace cond_seq_linear and c_vec_embedder, so they will both take discrete input. - self.c_vec_embedder = nn.Embedding(1024, global_conddim) - self.cond_seq_linear = MultiTokenEmbedding(1024, dim) - - # init embedding with very small 0.03 - nn.init.trunc_normal_( - self.c_vec_embedder.weight, mean=0.0, std=0.1, a=-0.2, b=0.2 - ) - nn.init.trunc_normal_( - self.cond_seq_linear.embedding1.weight, mean=0.0, std=0.1, a=-0.2, b=0.2 - ) - nn.init.trunc_normal_( - self.cond_seq_linear.embedding2.weight, mean=0.0, std=0.1, a=-0.2, b=0.2 - ) - - def forward(self, x, t, conds, **kwargs): - - conds_dict = {"c_seq": conds, "c_vec": conds} - return super(MMDiT_for_IN1K, self).forward(x, t, conds_dict, **kwargs) - - if __name__ == "__main__": - model = MMDiT(max_seq=32 * 32) + model = MMDiT(max_seq=32 * 32, dim = 3072, n_heads=24) model.extend_pe((32, 32), (64, 64)) x = torch.randn(1, 4, 20, 48) t = torch.randn(8) diff --git a/advanced/run.sh b/advanced/run.sh deleted file mode 100644 index f14ef78..0000000 --- a/advanced/run.sh +++ /dev/null @@ -1,20 +0,0 @@ -export WORLD_SIZE=8 #$(nvidia-smi -L | wc -l) -# deepspeed --num_gpus $WORLD_SIZE main.py --learning_rate 1e-3 --save_dir "/home/host/simo/ckpts/{}" -# lrs=(1e-4 2e-4 4e-4 8e-4) -# widths=(64 128 256) -loglr=(-8 -7 -6 -5 -4 -3) -widths=(32 64 128 256) - -for width in "${widths[@]}"; do - for loglr in "${loglr[@]}"; do - lr=$(python -c "print(2**$loglr)") - run_name="mup_lr_${lr}_width_${width}" - echo "Running $run_name" - deepspeed --num_gpus $WORLD_SIZE \ - main.py \ - --learning_rate $lr \ - --hidden_dim $width \ - --run_name $run_name \ - --save_dir "/home/host/simo/ckpts/${run_name}" - done -done \ No newline at end of file diff --git a/advanced/run_multi_node.sh b/advanced/run_multi.sh similarity index 52% rename from advanced/run_multi_node.sh rename to advanced/run_multi.sh index 3322a7e..a3274f3 100644 --- a/advanced/run_multi_node.sh +++ b/advanced/run_multi.sh @@ -18,20 +18,29 @@ for host in `cat hostfiles`; do host_ip=`echo $host | cut -d ' ' -f 1` echo "Running command on $host_ip" ssh -o StrictHostKeyChecking=no $host_ip $COMMAND + # check if pytorch is installed + ssh -o StrictHostKeyChecking=no $host_ip "python -c 'import torch; print(torch.randn(100).cuda())'" + # remove /scratch/simo + ssh -o StrictHostKeyChecking=no $host_ip "sudo rm -rf /scratch/simo && sudo mkdir -p /scratch/simo && sudo chmod 777 /scratch/simo && sudo chown -R nobody:nogroup /scratch/simo" done +bash /home/simo/common_installations.sh +# --train_dir "/jfs/mds_relinked" \ + deepspeed --hostfile=./hostfiles \ main_t2i.py \ - --learning_rate 0.0266 \ - --hidden_dim 2560 \ + --learning_rate 0.005 \ + --hidden_dim 3072 \ --n_layers 36 \ - --run_name node-2-36L-run-6 \ - --save_dir "/home/ubuntu/ckpts_36L_6" \ + --n_heads 12 \ + --run_name 6.5b-dithybrid-36-24-node4-run-stage8 \ + --save_dir "/jfs/stage1_0621_stage8" \ 
diff --git a/advanced/run.sh b/advanced/run.sh
deleted file mode 100644
index f14ef78..0000000
--- a/advanced/run.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-export WORLD_SIZE=8 #$(nvidia-smi -L | wc -l)
-# deepspeed --num_gpus $WORLD_SIZE main.py --learning_rate 1e-3 --save_dir "/home/host/simo/ckpts/{}"
-# lrs=(1e-4 2e-4 4e-4 8e-4)
-# widths=(64 128 256)
-loglr=(-8 -7 -6 -5 -4 -3)
-widths=(32 64 128 256)
-
-for width in "${widths[@]}"; do
-    for loglr in "${loglr[@]}"; do
-        lr=$(python -c "print(2**$loglr)")
-        run_name="mup_lr_${lr}_width_${width}"
-        echo "Running $run_name"
-        deepspeed --num_gpus $WORLD_SIZE \
-            main.py \
-            --learning_rate $lr \
-            --hidden_dim $width \
-            --run_name $run_name \
-            --save_dir "/home/host/simo/ckpts/${run_name}"
-    done
-done
\ No newline at end of file
diff --git a/advanced/run_multi_node.sh b/advanced/run_multi.sh
similarity index 52%
rename from advanced/run_multi_node.sh
rename to advanced/run_multi.sh
index 3322a7e..a3274f3 100644
--- a/advanced/run_multi_node.sh
+++ b/advanced/run_multi.sh
@@ -18,20 +18,29 @@ for host in `cat hostfiles`; do
     host_ip=`echo $host | cut -d ' ' -f 1`
     echo "Running command on $host_ip"
     ssh -o StrictHostKeyChecking=no $host_ip $COMMAND
+    # check if pytorch is installed
+    ssh -o StrictHostKeyChecking=no $host_ip "python -c 'import torch; print(torch.randn(100).cuda())'"
+    # remove /scratch/simo
+    ssh -o StrictHostKeyChecking=no $host_ip "sudo rm -rf /scratch/simo && sudo mkdir -p /scratch/simo && sudo chmod 777 /scratch/simo && sudo chown -R nobody:nogroup /scratch/simo"
 done
 
+bash /home/simo/common_installations.sh
+# --train_dir "/jfs/mds_relinked" \
+
 deepspeed --hostfile=./hostfiles \
     main_t2i.py \
-    --learning_rate 0.0266 \
-    --hidden_dim 2560 \
+    --learning_rate 0.005 \
+    --hidden_dim 3072 \
     --n_layers 36 \
-    --run_name node-2-36L-run-6 \
-    --save_dir "/home/ubuntu/ckpts_36L_6" \
+    --n_heads 12 \
+    --run_name 6.5b-dithybrid-36-24-node4-run-stage8 \
+    --save_dir "/jfs/stage1_0621_stage8" \
     --num_train_epochs 200 \
     --train_batch_size 1024 \
-    --per_device_train_batch_size 16 \
-    --train_dir "/jfs/datacomp-1b-0-10k/2/" \
+    --per_device_train_batch_size 32 \
+    --train_dir "/jfs/datacomp-1b-0-10k/1" \
    --seed 5 \
-    --note "continue training"
\ No newline at end of file
+    --note "continue training" \
+    --init_ckpt_path "/jfs/stage1_0621_stage7/model_57354/ema1.pt"
\ No newline at end of file
diff --git a/advanced/run_multi_node_resize.sh b/advanced/run_multi_node_resize.sh
deleted file mode 100644
index 3462af8..0000000
--- a/advanced/run_multi_node_resize.sh
+++ /dev/null
@@ -1,39 +0,0 @@
-export NCCL_P2P_DISABLE=1
-export NCCL_P2P_LEVEL=NVL
-export NCCL_SHM_DISABLE=1
-
-# put the above env var to ~/.deepspeed_env
-# remove .deepseed_env if it exists
-rm -f ~/.deepspeed_env
-echo "NCCL_P2P_DISABLE=1" > ~/.deepspeed_env
-echo "NCCL_P2P_LEVEL=NVL" >> ~/.deepspeed_env
-echo "NCCL_SHM_DISABLE=1" >> ~/.deepspeed_env
-
-# goes into ssh of each host in hostfile and run the following cmd
-COMMAND="lsof /dev/nvidia* | awk '{print \$2}' | xargs -I {} kill {}"
-
-# run the command on all hosts in hostfile
-for host in `cat hostfiles`; do
-    # host is in form of "xxx.xx.xxx.xx slots=8"
-    host_ip=`echo $host | cut -d ' ' -f 1`
-    echo "Running command on $host_ip"
-    ssh -o StrictHostKeyChecking=no $host_ip $COMMAND
-done
-
-
-
-deepspeed --hostfile=./hostfiles \
-    main_t2i_highres.py \
-    --learning_rate 0.018 \
-    --hidden_dim 2560 \
-    --n_layers 36 \
-    --run_name node-2-highres-0.18 \
-    --save_dir "/home/ubuntu/ckpts_36L_2_highres_lr_0.006" \
-    --num_train_epochs 200 \
-    --train_batch_size 256 \
-    --per_device_train_batch_size 4 \
-    --train_dir "/home/ubuntu/laionpop" \
-    --seed 3 \
-    --note "Laion Pop Aesthetic fine tuning" \
-    --resize_pe_at_initialization True \
-    --vaeres 128
\ No newline at end of file
diff --git a/advanced/run_paral.sh b/advanced/run_paral.sh
deleted file mode 100644
index cf0cf5c..0000000
--- a/advanced/run_paral.sh
+++ /dev/null
@@ -1,30 +0,0 @@
-export WORLD_SIZE=8 #$(nvidia-smi -L | wc -l)
-# deepspeed --num_gpus $WORLD_SIZE main.py --learning_rate 1e-3 --save_dir "/home/host/simo/ckpts/{}"
-# lrs=(1e-4 2e-4 4e-4 8e-4)
-# widths=(64 128 256)
-loglr=(-10 -9 -8 -7 -6 -5 -4 -3)
-widths=(128)
-gpuidx=(0 1 2 3 4 5 6 7)
-masterports=(11600 11601 11602 11603 11604 11605 11606 11607)
-for width in "${widths[@]}"; do
-    for idx in "${gpuidx[@]}"; do
-        loglr_idx=$((idx))
-        loglrv=${loglr[$loglr_idx]}
-        masterport=${masterports[$idx]}
-        lr=$(python -c "print(2**$loglrv)")
-        run_name="layer48_mup_lr_${lr}_width_${width}"
-        echo "Running $run_name"
-        deepspeed --master_port $masterport --include=localhost:$idx \
-            main.py \
-            --learning_rate $lr \
-            --hidden_dim $width \
-            --run_name $run_name \
-            --save_dir "/home/host/simo/ckpts/${run_name}" \
-            --num_train_epochs 2 \
-            --n_layers 48 \
-            --train_batch_size 128 \
-            --per_device_train_batch_size 128 &
-    done
-    % ${#loglr[@]}
-    wait
-done
\ No newline at end of file
diff --git a/advanced/run_single.sh b/advanced/run_single.sh
deleted file mode 100644
index 68218ce..0000000
--- a/advanced/run_single.sh
+++ /dev/null
@@ -1,12 +0,0 @@
-export WORLD_SIZE=8
-
-deepspeed --num_gpus $WORLD_SIZE \
-    main.py \
-    --learning_rate 0.01 \
-    --hidden_dim 2560 \
-    --n_layers 28 \
-    --run_name largerun \
-    --save_dir "/home/ubuntu/ckpts" \
-    --num_train_epochs 200 \
-    --train_batch_size 256 \
-    --per_device_train_batch_size 16 \
diff --git a/advanced/run_single_small.sh b/advanced/run_single_small.sh
deleted file mode 100644
index 3110d87..0000000
--- a/advanced/run_single_small.sh
+++ /dev/null
@@ -1,26 +0,0 @@
-loglr=(-7)
-widths=(128)
-gpuidx=(0)
-masterports=(11201)
-for width in "${widths[@]}"; do
-    for idx in "${gpuidx[@]}"; do
-        loglr_idx=$((idx))
-        loglrv=${loglr[$loglr_idx]}
-        masterport=${masterports[$idx]}
-        lr=$(python -c "print(2**$loglrv)")
-        run_name="layer48_mup_lr_${lr}_width_${width}"
-        echo "Running $run_name"
-        deepspeed --master_port $masterport --include=localhost:$idx \
-            main.py \
-            --learning_rate $lr \
-            --hidden_dim $width \
-            --run_name $run_name \
-            --save_dir "/home/host/simo/ckpts/${run_name}" \
-            --num_train_epochs 2 \
-            --n_layers 48 \
-            --train_batch_size 128 \
-            --per_device_train_batch_size 128 &
-    done
-    % ${#loglr[@]}
-    wait
-done
\ No newline at end of file
diff --git a/advanced/run_single_t2i.sh b/advanced/run_single_t2i.sh
deleted file mode 100644
index f9807c2..0000000
--- a/advanced/run_single_t2i.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-export WORLD_SIZE=8
-
-deepspeed --num_gpus $WORLD_SIZE \
-    main_t2i.py \
-    --learning_rate 0.002 \
-    --hidden_dim 2560 \
-    --n_layers 24 \
-    --run_name largerun_freeze_pd_2 \
-    --save_dir "/home/host/simo/ckpts/5b_cont_2" \
-    --num_train_epochs 200 \
-    --train_batch_size 128 \
-    --per_device_train_batch_size 16 \
-    --note "continue with smaller lr." \
-    --seed 40
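The launch scripts above pass both a global --train_batch_size and a --per_device_train_batch_size. The gradient-accumulation steps implied by those two flags plus the world size can be sanity-checked with the small helper below; this is an assumption about how the training script / DeepSpeed config derives it, not code from this repo, and the 32-GPU world size used in the example is likewise assumed.

def grad_accum_steps(train_batch_size: int, per_device_batch_size: int, world_size: int) -> int:
    # global batch = per-device micro-batch * world size * accumulation steps
    micro_total = per_device_batch_size * world_size
    assert train_batch_size % micro_total == 0, "global batch size must divide evenly"
    return train_batch_size // micro_total

# run_multi.sh: 1024 global, 32 per GPU, assumed 4 nodes x 8 GPUs
print(grad_accum_steps(1024, 32, 32))  # -> 1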
diff --git a/advanced/test.py b/advanced/test.py
deleted file mode 100644
index 93062c8..0000000
--- a/advanced/test.py
+++ /dev/null
@@ -1,178 +0,0 @@
-import pandas as pd
-import plotly.express as px
-import re
-import numpy as np
-import torch
-import wandb
-
-
-def log_dif(model_cur_sd, model_prev_sd):
-    # Initialize a new run
-
-    # Create lists to store data for the plot
-    layer_names = []
-    std_devs = []
-    l1_norms = []
-    param_counts = []
-    colors = []
-    markers = []
-
-    # Iterate over the parameters and compute necessary metrics
-    for name, param in model_cur_sd.items():
-        if name in model_prev_sd:
-            prev_param = model_prev_sd[name]
-            std_dev = param.std().item()
-            l1_norm = torch.abs(param - prev_param).mean().item()
-            param_count = param.numel()
-
-            # Determine color based on the criteria using regex
-            layer_match = re.match(r"layers\.(\d+)(?:\..*)?$", name)
-            if layer_match:
-                layer_num = int(layer_match.group(1))
-                colors.append(layer_num)
-            else:
-                colors.append(-1)
-
-            # Determine marker type
-            if param.ndim == 1:
-                markers.append("x")
-            else:
-                markers.append("circle")
-
-            layer_names.append(name)
-            std_devs.append(std_dev)
-            l1_norms.append(np.log1p(l1_norm))  # log(1 + x) transformation
-            param_counts.append(np.log(param_count))
-
-    # Create a DataFrame for the plot
-    df = pd.DataFrame(
-        {
-            "Layer Name": layer_names,
-            "Standard Deviation": std_devs,
-            "L1 Norm of Changes (log scale)": l1_norms,
-            "Parameter Count (log)": param_counts,
-            "Color": colors,
-            "Marker": markers,
-        }
-    )
-
-    # Determine the number of layers
-    max_layer_num = df[df["Color"] != -1]["Color"].max()
-
-    # Create a color scale for the layers (yellow to red)
-    color_scale = px.colors.sequential.YlOrRd
-    color_discrete_map = {
-        i: color_scale[int(i * (len(color_scale) - 1) / max_layer_num)]
-        for i in range(int(max_layer_num) + 1)
-    }
-    color_discrete_map[-1] = "blue"  # Blue for non-layer parameters
-
-    # Create Plotly figure
-    fig = px.scatter(
-        df,
-        x="Standard Deviation",
-        y="L1 Norm of Changes (log scale)",
-        size="Parameter Count (log)",
-        color="Color",
-        hover_name="Layer Name",
-        title="Model Weight Distribution and Changes",
-        symbol="Marker",
-        color_discrete_map=color_discrete_map,
-        opacity=0.7,
-    )
-
-    #
-
-    table = wandb.Table(columns=["plotly_figure"])
-
-    # Create path for Plotly figure
-    path_to_plotly_html = "./plotly_figure.html"
-
-    # Write Plotly figure to HTML
-    fig.write_html(path_to_plotly_html, auto_play=False)
-
-    # Add Plotly figure as HTML file into Table
-    table.add_data(wandb.Html(path_to_plotly_html))
-
-    # Log Table
-    wandb.log({"weight_distribution_changes": table})
-
-
-# Test script
-import torch
-import torch.nn as nn
-import torch.optim as optim
-
-
-# Define a simple neural network with many layers
-class TestNet(nn.Module):
-    def __init__(self):
-        super(TestNet, self).__init__()
-        self.layers = nn.Sequential(
-            nn.Linear(10, 50),
-            nn.ReLU(),
-            nn.Linear(50, 100),
-            nn.ReLU(),
-            nn.Linear(100, 200),
-            nn.ReLU(),
-            nn.Linear(200, 100),
-            nn.ReLU(),
-            nn.Linear(100, 100),
-            nn.ReLU(),
-            nn.Linear(100, 100),
-            nn.ReLU(),
-            nn.Linear(100, 100),
-            nn.ReLU(),
-            nn.Linear(100, 100),
-            nn.ReLU(),
-            nn.Linear(100, 100),
-            nn.ReLU(),
-            nn.Linear(100, 100),
-            nn.ReLU(),
-            nn.Linear(100, 100),
-            nn.ReLU(),
-            nn.Linear(100, 100),
-            nn.ReLU(),
-            nn.Linear(100, 100),
-            nn.ReLU(),
-            nn.Linear(100, 100),
-            nn.ReLU(),
-            nn.Linear(100, 100),
-            nn.ReLU(),
-            nn.Linear(100, 50),
-            nn.ReLU(),
-            nn.Linear(50, 10),
-        )
-
-    def forward(self, x):
-        return self.layers(x)
-
-
-# Initialize the network and optimizer
-model = TestNet()
-optimizer = optim.SGD(model.parameters(), lr=0.01)
-criterion = nn.MSELoss()
-
-# Dummy input and target
-input = torch.randn(5, 10)
-target = torch.randn(5, 10)
-
-# Forward pass
-output = model(input)
-loss = criterion(output, target)
-
-# Backward pass and optimization
-loss.backward()
-optimizer.step()
-
-# Save current and previous state dicts
-model_prev_sd = model.state_dict()
-model_prev_sd = {k: v.clone() for k, v in model_prev_sd.items()}
-optimizer.zero_grad()
-output = model(input)
-loss = criterion(output, target)
-loss.backward()
-optimizer.step()
-model_cur_sd = model.state_dict()
-
-# Log differences
-wandb.init(project="test")
-log_dif(model_cur_sd, model_prev_sd)
diff --git a/advanced/upload_stuff_hf.py b/advanced/upload_stuff_hf.py
deleted file mode 100644
index fef627f..0000000
--- a/advanced/upload_stuff_hf.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import os
-import time
-from huggingface_hub import HfApi, CommitOperationAdd, create_branch
-
-if False:
-    CKPT_DIRS = "./data"
-    SLEEP_INTERVAL = 10
-    REPO_ID = "cloneofsimo/test_model"
-    REPO_TYPE = "model"
-else:
-    CKPT_DIRS = "/home/host/simo/ckpts/5b_highres"
-    SLEEP_INTERVAL = 60
-    REPO_ID = "cloneofsimo/lavenderflow-5.6B"
-    REPO_TYPE = "model"
-
-# Initialize the API
-api = HfApi()
-
-
-def get_folder_size(folder_path):
-    total_size = 0
-    for dirpath, dirnames, filenames in os.walk(folder_path):
-        for f in filenames:
-            fp = os.path.join(dirpath, f)
-            if os.path.exists(fp):
-                total_size += os.path.getsize(fp)
-    return total_size
-
-
-def upload_if_stable(folder_path, relpath, wait_time=300):
-    """Waits for the folder size to stabilize before uploading."""
-    size1 = get_folder_size(folder_path)
-    time.sleep(wait_time)
-    size2 = get_folder_size(folder_path)
-
-    bname = f"highres-{relpath}"
-
-    if size1 == size2:
-        print(f"Uploading {folder_path} to Hugging Face Hub.")
-        try:
-            create_branch(REPO_ID, repo_type=REPO_TYPE, branch=bname)
-        except:
-            pass
-
-        api.upload_folder(
-            folder_path=folder_path,
-            repo_id=REPO_ID,
-            repo_type=REPO_TYPE,
-            revision=bname,
-        )
-        print(f"Uploaded {folder_path} successfully.")
-
-        # delete the folder
-        os.system(f"rm -rf {folder_path}")
-        return True
-
-    return False
-
-
-def monitor_ckpt_dirs():
-    known_folders = set()
-
-    while True:
-        current_folders = set(os.listdir(CKPT_DIRS))
-        new_folders = current_folders - known_folders
-
-        new_folders = list(new_folders)
-        # sort based on model_xxxx
-        new_folders.sort(key=lambda x: int(x.split("_")[1]), reverse=True)
-
-        for folder in new_folders:
-            folder_path = os.path.join(CKPT_DIRS, folder)
-            if os.path.isdir(folder_path):
-                print(f"Detected new folder: {folder}")
-                relpath = os.path.relpath(folder_path, CKPT_DIRS)
-                if upload_if_stable(folder_path, relpath, SLEEP_INTERVAL):
-                    known_folders.add(folder)
-
-        time.sleep(SLEEP_INTERVAL)
-
-
-if __name__ == "__main__":
-    print("Starting to monitor for new model directories.")
-    monitor_ckpt_dirs()
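The new run_multi.sh resumes from an EMA checkpoint via --init_ckpt_path. A rough sketch of what that resume step amounts to is below; it assumes the .pt file holds a plain state_dict (as produced by the torch.save calls in this repo's checkpointing code) and that the script's --hidden_dim/--n_layers/--n_heads flags map directly onto MMDiT's dim/n_layers/n_heads arguments, so treat both as assumptions rather than the actual loading logic in main_t2i.py.

import torch
from mmdit import MMDiT

# Hyperparameters mirror the run_multi.sh flags (assumption); n_double_layers is left
# at its default because the script does not set it explicitly.
model = MMDiT(dim=3072, n_layers=36, n_heads=12)
state = torch.load("/jfs/stage1_0621_stage7/model_57354/ema1.pt", map_location="cpu")
missing, unexpected = model.load_state_dict(state, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")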