diff --git a/.github/workflows/release.yml b/.github/workflows/publish.yml
similarity index 100%
rename from .github/workflows/release.yml
rename to .github/workflows/publish.yml
diff --git a/consistency/consistency.py b/consistency/consistency.py
index b360baf..ba7758e 100644
--- a/consistency/consistency.py
+++ b/consistency/consistency.py
@@ -49,6 +49,10 @@ def __init__(
         initial_ema_decay: float = 0.9,
         optimizer_type: Type[optim.Optimizer] = optim.AdamW,
         samples_path: str = "samples/",
+        save_samples_every_n_epoch: int = 10,
+        num_samples: int = 16,
+        sample_steps: int = 1,
+        sample_seed: int = 0,
     ) -> None:
         super().__init__()
 
@@ -97,6 +101,10 @@ def __init__(
 
         Path(samples_path).mkdir(exist_ok=True, parents=True)
         self.samples_path = samples_path
+        self.save_samples_every_n_epoch = save_samples_every_n_epoch
+        self.num_samples = num_samples
+        self.sample_steps = sample_steps
+        self.sample_seed = sample_seed
 
     def forward(
         self,
@@ -229,11 +237,25 @@ def timesteps_to_times(self, timesteps: torch.LongTensor, bins: int):
 
     @rank_zero_only
     def on_train_start(self) -> None:
-        self.save_samples(f"{0:05}")
+        self.save_samples(
+            f"{0:05}",
+            num_samples=self.num_samples,
+            steps=self.sample_steps,
+            seed=self.sample_seed,
+        )
 
     @rank_zero_only
     def on_train_epoch_end(self) -> None:
-        self.save_samples(f"{(self.current_epoch+1):05}")
+        if (
+            (self.trainer.current_epoch + 1) % self.save_samples_every_n_epoch
+            == 0
+        ) or self.trainer.current_epoch == (self.trainer.max_epochs - 1):
+            self.save_samples(
+                f"{(self.current_epoch+1):05}",
+                num_samples=self.num_samples,
+                steps=self.sample_steps,
+                seed=self.sample_seed,
+            )
 
     @torch.no_grad()
     def sample(
@@ -280,13 +302,22 @@ def sample(
                 device=self.device,
                 generator=generator,
             )
-            images = images + math.sqrt(time**2 - self.time_min**2) * noise
-            images = self(images, torch.tensor([time], device=self.device))
+            images = (
+                images
+                + math.sqrt(time.item() ** 2 - self.time_min**2) * noise
+            )
+            images = self(images, time[None])
 
         return images
 
-    def save_samples(self, filename: str):
-        samples = self.sample()
+    def save_samples(
+        self,
+        filename: str,
+        num_samples: int = 16,
+        steps: int = 1,
+        seed: int = 0,
+    ):
+        samples = self.sample(num_samples=num_samples, steps=steps, seed=seed)
         samples.mul_(0.5).add_(0.5)
         grid = make_grid(
             samples,
@@ -308,6 +339,7 @@ def save_samples(self, filename: str):
                 ),
             },
             commit=False,
+            step=self.trainer.global_step,
         )
 
     @staticmethod
diff --git a/examples/train.py b/examples/train.py
index 220c970..13d54e8 100644
--- a/examples/train.py
+++ b/examples/train.py
@@ -39,12 +39,6 @@ def parse_args():
         default=16,
         help="Batch size (per device) for the training dataloader.",
     )
-    parser.add_argument(
-        "--num-samples",
-        type=int,
-        default=16,
-        help="The number of images to generate for evaluation.",
-    )
     parser.add_argument(
         "--dataloader-num-workers",
         type=int,
@@ -94,14 +88,35 @@ def parse_args():
         default=0.9,
     )
     parser.add_argument(
-        "--log_every_n_steps",
+        "--log-every-n-steps",
         type=int,
-        default=30,
+        default=100,
     )
     parser.add_argument(
-        "--resume-from-checkpoint",
+        "--sample-path",
         type=str,
-        default=None,
+        default="samples/",
+    )
+    parser.add_argument(
+        "--save-samples-every-n-epoch",
+        type=int,
+        default=10,
+    )
+    parser.add_argument(
+        "--num-samples",
+        type=int,
+        default=16,
+        help="The number of images to generate for evaluation.",
+    )
+    parser.add_argument(
+        "--sample-steps",
+        type=int,
+        default=5,
+    )
+    parser.add_argument(
+        "--sample-seed",
+        type=int,
+        default=0,
     )
     args = parser.parse_args()
     return args
@@ -136,7 +151,9 @@ def __len__(self):
             return len(self.dataset)
 
         def __getitem__(self, index: int) -> torch.Tensor:
-            return augmentations(self.dataset[index][self.image_key].convert("RGB"))
+            return augmentations(
+                self.dataset[index][self.image_key].convert("RGB")
+            )
 
     dataloader = DataLoader(
         Dataset(
@@ -180,11 +197,16 @@ def __getitem__(self, index: int) -> torch.Tensor:
         bins_max=args.bins_max,
         bins_rho=args.bins_rho,
         initial_ema_decay=args.initial_ema_decay,
+        samples_path=args.sample_path,
+        save_samples_every_n_epoch=args.save_samples_every_n_epoch,
+        num_samples=args.num_samples,
+        sample_steps=args.sample_steps,
+        sample_seed=args.sample_seed,
     )
 
     trainer = Trainer(
         accelerator="auto",
-        logger=WandbLogger(project="test", log_model=True),
+        logger=WandbLogger(project="consistency", log_model=True),
         callbacks=[
             ModelCheckpoint(
                 dirpath="ckpt",
@@ -193,7 +215,7 @@ def __getitem__(self, index: int) -> torch.Tensor:
             )
         ],
         max_epochs=args.max_epochs,
-        precision="16-mixed",
+        precision=16,
         log_every_n_steps=args.log_every_n_steps,
         gradient_clip_algorithm="norm",
         gradient_clip_val=1.0,
diff --git a/setup.py b/setup.py
index 18e00a3..bcc32cb 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 from setuptools import setup, find_packages
 
-__version__ = "0.1.0"
+__version__ = "0.1.1"
 
 setup(
     name="consistency",
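
Usage sketch for the sampling controls this patch introduces. The keyword arguments below are the ones added to consistency/consistency.py above; `model` is a placeholder for whatever UNet-style module the training script already constructs, not something defined in this patch:

    from consistency import Consistency

    consistency = Consistency(
        model=model,                    # placeholder backbone, assumed to exist
        samples_path="samples/",
        save_samples_every_n_epoch=10,  # grid saved at epoch 0, every 10th epoch, and the final epoch
        num_samples=16,                 # images per saved grid
        sample_steps=1,                 # more sampling steps trade speed for quality
        sample_seed=0,                  # fixed generator seed keeps grids comparable across epochs
    )

With these defaults, `save_samples` writes a 16-image grid under samples/ on the chosen epochs and, when a WandbLogger is attached, logs it against the trainer's current global step.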