safetensors, rm dup imports, homedir portability
Xerxemi committed Oct 14, 2023
1 parent 858f4dc commit a269c5a
Showing 1 changed file with 16 additions and 21 deletions.
main.py: 37 changes (16 additions & 21 deletions)
@@ -20,6 +20,7 @@
from diffusers.loaders import AttnProcsLayers
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
import datetime
+import hpsv2
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
from accelerate.logging import get_logger
from accelerate import Accelerator
@@ -55,7 +56,9 @@ def hps_loss_fn(inference_dtype=None, device=None):

tokenizer = get_tokenizer(model_name)

-checkpoint_path = "/home/mprabhud/.cache/hpsv2/HPS_v2_compressed.pt"
+checkpoint_path = f"{os.path.expanduser('~')}/.cache/hpsv2/HPS_v2_compressed.pt"
+# force download of model via score
+hpsv2.score([], "")

checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
@@ -121,7 +124,6 @@ def evaluate(latent,train_neg_prompt_embeds,prompts, pipeline, accelerator, infe
pipeline.scheduler.alphas_cumprod = pipeline.scheduler.alphas_cumprod.to(accelerator.device)
prompt_embeds = pipeline.text_encoder(prompt_ids)[0]

all_rgbs_t = []
for i, t in tqdm(enumerate(pipeline.scheduler.timesteps), total=len(pipeline.scheduler.timesteps)):
t = torch.tensor([t],
@@ -187,7 +189,7 @@ def main(_):
accelerator.init_trackers(
project_name="align-prop", config=config.to_dict(), init_kwargs={"wandb": wandb_args}
)
-import wandb

accelerator.project_configuration.project_dir = os.path.join(config.logdir, wandb.run.name)
accelerator.project_configuration.logging_dir = os.path.join(config.logdir, wandb.run.name)

@@ -197,10 +199,11 @@
# set seed (device_specific is very important to get different prompts on different devices)
set_seed(config.seed, device_specific=True)

-from diffusers import StableDiffusionPipeline, DDIMScheduler

# load scheduler, tokenizer and models.
-pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)
+if config.pretrained.model.endswith(".safetensors") or config.pretrained.model.endswith(".ckpt"):
+    pipeline = StableDiffusionPipeline.from_single_file(config.pretrained.model)
+else:
+    pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)

# freeze parameters of models to save more memory
pipeline.vae.requires_grad_(False)
@@ -330,7 +333,6 @@ def load_model_hook(models, input_dir):
# Initialize the optimizer
optimizer_cls = torch.optim.AdamW

optimizer = optimizer_cls(
unet.parameters(),
lr=config.train.learning_rate,
@@ -393,7 +395,7 @@ def load_model_hook(models, input_dir):

if config.only_eval:
#################### EVALUATION ONLY ####################
-import wandb

all_eval_images = []
all_eval_rewards = []
if config.same_evaluation:
@@ -412,7 +414,7 @@ def load_model_hook(models, input_dir):
eval_images = torch.cat(all_eval_images)
eval_image_vis = []
if accelerator.is_main_process:
-import wandb

if config.run_name != "":
name_val = config.run_name
else:
@@ -440,7 +442,7 @@ def load_model_hook(models, input_dir):
latent = torch.randn((config.train.batch_size_per_gpu_available, 4, 64, 64), device=accelerator.device, dtype=inference_dtype)

if accelerator.is_main_process:
-import wandb

logger.info(f"{wandb.run.name} Epoch {epoch}.{inner_iters}: training")


@@ -463,6 +465,7 @@ def load_model_hook(models, input_dir):
with accelerator.accumulate(unet):
with autocast():
with torch.enable_grad(): # important b/c don't have on by default in module

keep_input = True
for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
t = torch.tensor([t],
@@ -502,7 +505,6 @@ def load_model_hook(models, input_dir):
loss = loss.sum()
loss = loss/config.train.batch_size_per_gpu_available
loss = loss * config.train.loss_coeff

rewards_mean = rewards.mean()
rewards_std = rewards.std()
@@ -530,7 +532,7 @@ def load_model_hook(models, input_dir):
) % config.train.gradient_accumulation_steps == 0
# log training and evaluation
if config.visualize_eval and (global_step % config.vis_freq ==0):
-import wandb

all_eval_images = []
all_eval_rewards = []
if config.same_evaluation:
@@ -549,7 +551,7 @@ def load_model_hook(models, input_dir):
eval_images = torch.cat(all_eval_images)
eval_image_vis = []
if accelerator.is_main_process:
-import wandb

name_val = wandb.run.name
log_dir = f"logs/{name_val}/eval_vis"
os.makedirs(log_dir, exist_ok=True)
@@ -593,18 +595,11 @@ def load_model_hook(models, input_dir):
global_step += 1
info = defaultdict(list)

# make sure we did an optimization step at the end of the inner epoch
assert accelerator.sync_gradients

if epoch % config.save_freq == 0 and accelerator.is_main_process:
accelerator.save_state()

if __name__ == "__main__":
app.run(main)
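
Note on the checkpoint-path change: the portable-path pattern can be exercised outside the training script. A minimal sketch, assuming hpsv2 is installed and that an empty score() call triggers the checkpoint download (per the "force download of model via score" comment in this commit):

import os
import hpsv2

# Build the cache path from the current user's home directory rather than a
# hardcoded /home/<user> prefix, so the script works for any account.
checkpoint_path = os.path.join(os.path.expanduser("~"), ".cache", "hpsv2", "HPS_v2_compressed.pt")

# Scoring an empty batch makes hpsv2 fetch HPS_v2_compressed.pt into
# ~/.cache/hpsv2/ if it is missing (assumption based on the commit's comment).
hpsv2.score([], "")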

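The new loading branch can likewise be read in isolation. A sketch, assuming a diffusers version that provides StableDiffusionPipeline.from_single_file; the load_pipeline helper name is illustrative, not part of the commit:

from diffusers import StableDiffusionPipeline

def load_pipeline(model, revision=None):
    # Single-file checkpoints (.safetensors / .ckpt) go through from_single_file;
    # anything else is treated as a Hub repo id or a diffusers-format directory.
    if model.endswith(".safetensors") or model.endswith(".ckpt"):
        return StableDiffusionPipeline.from_single_file(model)
    return StableDiffusionPipeline.from_pretrained(model, revision=revision)

# e.g. load_pipeline("path/to/model.safetensors") or
#      load_pipeline("runwayml/stable-diffusion-v1-5")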