Skip to content

Commit

Permalink
feat: add options for sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
junhsss committed Mar 21, 2023
1 parent 7c6ba2f commit 36da537
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 20 deletions.
File renamed without changes.
44 changes: 38 additions & 6 deletions consistency/consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def __init__(
initial_ema_decay: float = 0.9,
optimizer_type: Type[optim.Optimizer] = optim.AdamW,
samples_path: str = "samples/",
save_samples_every_n_epoch: int = 10,
num_samples: int = 16,
sample_steps: int = 1,
sample_seed: int = 0,
) -> None:
super().__init__()

Expand Down Expand Up @@ -97,6 +101,10 @@ def __init__(
Path(samples_path).mkdir(exist_ok=True, parents=True)

self.samples_path = samples_path
self.save_samples_every_n_epoch = save_samples_every_n_epoch
self.num_samples = num_samples
self.sample_steps = sample_steps
self.sample_seed = sample_seed

def forward(
self,
Expand Down Expand Up @@ -229,11 +237,25 @@ def timesteps_to_times(self, timesteps: torch.LongTensor, bins: int):

@rank_zero_only
def on_train_start(self) -> None:
self.save_samples(f"{0:05}")
self.save_samples(
f"{0:05}",
num_samples=self.num_samples,
steps=self.sample_steps,
seed=self.sample_seed,
)

@rank_zero_only
def on_train_epoch_end(self) -> None:
self.save_samples(f"{(self.current_epoch+1):05}")
if (
(self.trainer.current_epoch + 1) % self.save_samples_every_n_epoch
== 0
) or self.trainer.current_epoch == (self.trainer.max_epochs - 1):
self.save_samples(
f"{(self.current_epoch+1):05}",
num_samples=self.num_samples,
steps=self.sample_steps,
seed=self.sample_seed,
)

@torch.no_grad()
def sample(
Expand Down Expand Up @@ -280,13 +302,22 @@ def sample(
device=self.device,
generator=generator,
)
images = images + math.sqrt(time**2 - self.time_min**2) * noise
images = self(images, torch.tensor([time], device=self.device))
images = (
images
+ math.sqrt(time.item() ** 2 - self.time_min**2) * noise
)
images = self(images, time[None])

return images

def save_samples(self, filename: str):
samples = self.sample()
def save_samples(
self,
filename: str,
num_samples: int = 16,
steps: int = 1,
seed: int = 0,
):
samples = self.sample(num_samples=num_samples, steps=steps, seed=seed)
samples.mul_(0.5).add_(0.5)
grid = make_grid(
samples,
Expand All @@ -308,6 +339,7 @@ def save_samples(self, filename: str):
),
},
commit=False,
step=self.trainer.global_step,
)

@staticmethod
Expand Down
48 changes: 35 additions & 13 deletions examples/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ def parse_args():
default=16,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--num-samples",
type=int,
default=16,
help="The number of images to generate for evaluation.",
)
parser.add_argument(
"--dataloader-num-workers",
type=int,
Expand Down Expand Up @@ -94,14 +88,35 @@ def parse_args():
default=0.9,
)
parser.add_argument(
"--log_every_n_steps",
"--log-every-n-steps",
type=int,
default=30,
default=100,
)
parser.add_argument(
"--resume-from-checkpoint",
"--sample-path",
type=str,
default=None,
default="samples/",
)
parser.add_argument(
"--save-samples-every-n-epoch",
type=int,
default=10,
)
parser.add_argument(
"--num-samples",
type=int,
default=16,
help="The number of images to generate for evaluation.",
)
parser.add_argument(
"--sample-steps",
type=int,
default=5,
)
parser.add_argument(
"--sample-seed",
type=int,
default=0,
)
args = parser.parse_args()
return args
Expand Down Expand Up @@ -136,7 +151,9 @@ def __len__(self):
return len(self.dataset)

def __getitem__(self, index: int) -> torch.Tensor:
return augmentations(self.dataset[index][self.image_key].convert("RGB"))
return augmentations(
self.dataset[index][self.image_key].convert("RGB")
)

dataloader = DataLoader(
Dataset(
Expand Down Expand Up @@ -180,11 +197,16 @@ def __getitem__(self, index: int) -> torch.Tensor:
bins_max=args.bins_max,
bins_rho=args.bins_rho,
initial_ema_decay=args.initial_ema_decay,
samples_path=args.sample_path,
save_samples_every_n_epoch=args.save_samples_every_n_epoch,
num_samples=args.num_samples,
sample_steps=args.sample_steps,
sample_seed=args.sample_seed,
)

trainer = Trainer(
accelerator="auto",
logger=WandbLogger(project="test", log_model=True),
logger=WandbLogger(project="consistency", log_model=True),
callbacks=[
ModelCheckpoint(
dirpath="ckpt",
Expand All @@ -193,7 +215,7 @@ def __getitem__(self, index: int) -> torch.Tensor:
)
],
max_epochs=args.max_epochs,
precision="16-mixed",
precision=16,
log_every_n_steps=args.log_every_n_steps,
gradient_clip_algorithm="norm",
gradient_clip_val=1.0,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

__version__ = "0.1.0"
__version__ = "0.1.1"

setup(
name="consistency",
Expand Down

0 comments on commit 36da537

Please sign in to comment.