diff --git a/consistency/loss.py b/consistency/loss.py
new file mode 100644
index 0000000..8c7de08
--- /dev/null
+++ b/consistency/loss.py
@@ -0,0 +1,19 @@
+from typing import Literal
+
+from torch import nn
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+
+
+class LPIPSLoss(nn.Module):
+    def __init__(self, net_type: Literal["vgg", "alex", "squeeze"] = "vgg"):
+        super().__init__()
+        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type=net_type)
+        self.lpips.requires_grad_(False)
+
+    @staticmethod
+    def clamp(x):
+        return x.clamp(-1, 1)
+
+    def forward(self, input, target):
+        lpips_loss = self.lpips(self.clamp(input), self.clamp(target))
+        return lpips_loss
diff --git a/examples/train.py b/examples/train.py
index 16ac50b..2d7e034 100644
--- a/examples/train.py
+++ b/examples/train.py
@@ -10,6 +10,7 @@
 from torchvision import transforms
 
 from consistency import Consistency
+from consistency.loss import LPIPSLoss
 
 
 def parse_args():
@@ -119,6 +120,8 @@ def parse_args():
         type=int,
         default=0,
     )
+    parser.add_argument("--ckpt-path", type=str)
+    parser.add_argument("--wandb-id", type=str)
 
     args = parser.parse_args()
     return args
@@ -164,64 +167,137 @@ def __getitem__(self, index: int) -> torch.Tensor:
         num_workers=args.dataloader_num_workers,
     )
 
-    consistency = Consistency(
-        model=UNet2DModel(
-            sample_size=args.resolution,
-            in_channels=3,
-            out_channels=3,
-            layers_per_block=2,
-            block_out_channels=(128, 128, 256, 256, 512, 512),
-            down_block_types=(
-                "DownBlock2D",
-                "DownBlock2D",
-                "DownBlock2D",
-                "DownBlock2D",
-                "AttnDownBlock2D",
-                "DownBlock2D",
-            ),
-            up_block_types=(
-                "UpBlock2D",
-                "AttnUpBlock2D",
-                "UpBlock2D",
-                "UpBlock2D",
-                "UpBlock2D",
-                "UpBlock2D",
-            ),
-        ),
-        learning_rate=args.learning_rate,
-        data_std=args.data_std,
-        time_min=args.time_min,
-        time_max=args.time_max,
-        bins_min=args.bins_min,
-        bins_max=args.bins_max,
-        bins_rho=args.bins_rho,
-        initial_ema_decay=args.initial_ema_decay,
-        samples_path=args.sample_path,
-        save_samples_every_n_epoch=args.save_samples_every_n_epoch,
-        num_samples=args.num_samples,
-        sample_steps=args.sample_steps,
-        sample_ema=args.sample_ema,
-        sample_seed=args.sample_seed,
-    )
-
-    trainer = Trainer(
-        accelerator="auto",
-        logger=WandbLogger(project="consistency", log_model=True),
-        callbacks=[
-            ModelCheckpoint(
-                dirpath="ckpt",
-                save_top_k=3,
-                monitor="loss",
-            )
-        ],
-        max_epochs=args.max_epochs,
-        precision=16,
-        log_every_n_steps=args.log_every_n_steps,
-        gradient_clip_algorithm="norm",
-        gradient_clip_val=1.0,
-    )
-
-    trainer.fit(consistency, dataloader)
+    if args.ckpt_path:
+        consistency = Consistency.load_from_checkpoint(
+            checkpoint_path=args.ckpt_path,
+            model=UNet2DModel(
+                sample_size=args.resolution,
+                in_channels=3,
+                out_channels=3,
+                layers_per_block=2,
+                block_out_channels=(128, 128, 256, 256, 512, 512),
+                down_block_types=(
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "AttnDownBlock2D",
+                    "DownBlock2D",
+                ),
+                up_block_types=(
+                    "UpBlock2D",
+                    "AttnUpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                ),
+            ),
+            loss_fn=LPIPSLoss(),
+            learning_rate=args.learning_rate,
+            data_std=args.data_std,
+            time_min=args.time_min,
+            time_max=args.time_max,
+            bins_min=args.bins_min,
+            bins_max=args.bins_max,
+            bins_rho=args.bins_rho,
+            initial_ema_decay=args.initial_ema_decay,
+            samples_path=args.sample_path,
+            save_samples_every_n_epoch=args.save_samples_every_n_epoch,
+            num_samples=args.num_samples,
+            sample_steps=args.sample_steps,
+            sample_ema=args.sample_ema,
+            sample_seed=args.sample_seed,
+        )
+
+        trainer = Trainer(
+            accelerator="auto",
+            logger=WandbLogger(
+                project="consistency",
+                log_model=True,
+                id=args.wandb_id,
+                resume="must",
+            )
+            if args.wandb_id
+            else WandbLogger(
+                project="consistency",
+                log_model=True,
+            ),
+            callbacks=[
+                ModelCheckpoint(
+                    dirpath="ckpt",
+                    save_top_k=3,
+                    monitor="loss",
+                )
+            ],
+            max_epochs=args.max_epochs,
+            precision=16,
+            log_every_n_steps=args.log_every_n_steps,
+            gradient_clip_algorithm="norm",
+            gradient_clip_val=1.0,
+        )
+        trainer.fit(consistency, dataloader, ckpt_path=args.ckpt_path)
+
+    else:
+        consistency = Consistency(
+            model=UNet2DModel(
+                sample_size=args.resolution,
+                in_channels=3,
+                out_channels=3,
+                layers_per_block=2,
+                block_out_channels=(128, 128, 256, 256, 512, 512),
+                down_block_types=(
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "AttnDownBlock2D",
+                    "DownBlock2D",
+                ),
+                up_block_types=(
+                    "UpBlock2D",
+                    "AttnUpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                ),
+            ),
+            loss_fn=LPIPSLoss(),
+            learning_rate=args.learning_rate,
+            data_std=args.data_std,
+            time_min=args.time_min,
+            time_max=args.time_max,
+            bins_min=args.bins_min,
+            bins_max=args.bins_max,
+            bins_rho=args.bins_rho,
+            initial_ema_decay=args.initial_ema_decay,
+            samples_path=args.sample_path,
+            save_samples_every_n_epoch=args.save_samples_every_n_epoch,
+            num_samples=args.num_samples,
+            sample_steps=args.sample_steps,
+            sample_ema=args.sample_ema,
+            sample_seed=args.sample_seed,
+        )
+
+        trainer = Trainer(
+            accelerator="auto",
+            logger=WandbLogger(project="consistency", log_model=True),
+            callbacks=[
+                ModelCheckpoint(
+                    dirpath="ckpt",
+                    save_top_k=3,
+                    monitor="loss",
+                )
+            ],
+            max_epochs=args.max_epochs,
+            precision=16,
+            log_every_n_steps=args.log_every_n_steps,
+            gradient_clip_algorithm="norm",
+            gradient_clip_val=1.0,
+        )
+
+        trainer.fit(consistency, dataloader)
 
 
 if __name__ == "__main__":
diff --git a/setup.py b/setup.py
index d5cbd07..0cb834a 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 from setuptools import find_packages, setup
 
-__version__ = "0.1.2"
+__version__ = "0.2.0"
 
 setup(
     name="consistency",
@@ -19,6 +19,8 @@
         "torchvision",
         "pytorch-lightning",
         "diffusers",
+        "torchmetrics",
+        "lpips",
     ],
     classifiers=[
         "Development Status :: 4 - Beta",