From a269c5af788792509c0184b0828abcc90f0038ec Mon Sep 17 00:00:00 2001 From: Xerxemi <123516574+Xerxemi@users.noreply.github.com> Date: Sat, 14 Oct 2023 05:18:22 -0700 Subject: [PATCH] safetensors, rm dup imports, homedir portability --- main.py | 37 ++++++++++++++++--------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/main.py b/main.py index 696c181..c009d4d 100644 --- a/main.py +++ b/main.py @@ -20,6 +20,7 @@ from diffusers.loaders import AttnProcsLayers from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel import datetime +import hpsv2 from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer from accelerate.logging import get_logger from accelerate import Accelerator @@ -55,7 +56,9 @@ def hps_loss_fn(inference_dtype=None, device=None): tokenizer = get_tokenizer(model_name) - checkpoint_path = "/home/mprabhud/.cache/hpsv2/HPS_v2_compressed.pt" + checkpoint_path = f"{os.path.expanduser('~')}/.cache/hpsv2/HPS_v2_compressed.pt" + # force download of model via score + hpsv2.score([], "") checkpoint = torch.load(checkpoint_path, map_location=device) model.load_state_dict(checkpoint['state_dict']) @@ -121,7 +124,6 @@ def evaluate(latent,train_neg_prompt_embeds,prompts, pipeline, accelerator, infe pipeline.scheduler.alphas_cumprod = pipeline.scheduler.alphas_cumprod.to(accelerator.device) prompt_embeds = pipeline.text_encoder(prompt_ids)[0] - all_rgbs_t = [] for i, t in tqdm(enumerate(pipeline.scheduler.timesteps), total=len(pipeline.scheduler.timesteps)): t = torch.tensor([t], @@ -187,7 +189,7 @@ def main(_): accelerator.init_trackers( project_name="align-prop", config=config.to_dict(), init_kwargs={"wandb": wandb_args} ) - import wandb + accelerator.project_configuration.project_dir = os.path.join(config.logdir, wandb.run.name) accelerator.project_configuration.logging_dir = os.path.join(config.logdir, wandb.run.name) @@ -197,10 +199,11 @@ def main(_): # set seed (device_specific is very important to get different prompts on different devices) set_seed(config.seed, device_specific=True) - from diffusers import StableDiffusionPipeline, DDIMScheduler - # load scheduler, tokenizer and models. - pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision) + if config.pretrained.model.endswith(".safetensors") or config.pretrained.model.endswith(".ckpt"): + pipeline = StableDiffusionPipeline.from_single_file(config.pretrained.model) + else: + pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision) # freeze parameters of models to save more memory pipeline.vae.requires_grad_(False) @@ -330,7 +333,6 @@ def load_model_hook(models, input_dir): # Initialize the optimizer optimizer_cls = torch.optim.AdamW - optimizer = optimizer_cls( unet.parameters(), lr=config.train.learning_rate, @@ -393,7 +395,7 @@ def load_model_hook(models, input_dir): if config.only_eval: #################### EVALUATION ONLY #################### - import wandb + all_eval_images = [] all_eval_rewards = [] if config.same_evaluation: @@ -412,7 +414,7 @@ def load_model_hook(models, input_dir): eval_images = torch.cat(all_eval_images) eval_image_vis = [] if accelerator.is_main_process: - import wandb + if config.run_name != "": name_val = config.run_name else: @@ -440,7 +442,7 @@ def load_model_hook(models, input_dir): latent = torch.randn((config.train.batch_size_per_gpu_available, 4, 64, 64), device=accelerator.device, dtype=inference_dtype) if accelerator.is_main_process: - import wandb + logger.info(f"{wandb.run.name} Epoch {epoch}.{inner_iters}: training") @@ -463,6 +465,7 @@ def load_model_hook(models, input_dir): with accelerator.accumulate(unet): with autocast(): with torch.enable_grad(): # important b/c don't have on by default in module + keep_input = True for i, t in tqdm(enumerate(timesteps), total=len(timesteps)): t = torch.tensor([t], @@ -502,7 +505,6 @@ def load_model_hook(models, input_dir): loss = loss.sum() loss = loss/config.train.batch_size_per_gpu_available loss = loss * config.train.loss_coeff - rewards_mean = rewards.mean() rewards_std = rewards.std() @@ -530,7 +532,7 @@ def load_model_hook(models, input_dir): ) % config.train.gradient_accumulation_steps == 0 # log training and evaluation if config.visualize_eval and (global_step % config.vis_freq ==0): - import wandb + all_eval_images = [] all_eval_rewards = [] if config.same_evaluation: @@ -549,7 +551,7 @@ def load_model_hook(models, input_dir): eval_images = torch.cat(all_eval_images) eval_image_vis = [] if accelerator.is_main_process: - import wandb + name_val = wandb.run.name log_dir = f"logs/{name_val}/eval_vis" os.makedirs(log_dir, exist_ok=True) @@ -593,18 +595,11 @@ def load_model_hook(models, input_dir): global_step += 1 info = defaultdict(list) - # make sure we did an optimization step at the end of the inner epoch assert accelerator.sync_gradients if epoch % config.save_freq == 0 and accelerator.is_main_process: accelerator.save_state() - - - - - - if __name__ == "__main__": - app.run(main) \ No newline at end of file + app.run(main)