From 3831ee40296daba4bd17cb0a059dd67d37c7c308 Mon Sep 17 00:00:00 2001 From: Josh Date: Tue, 12 Sep 2023 12:40:00 -0400 Subject: [PATCH] adding input for feature explore --- gen_video.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/gen_video.py b/gen_video.py index 253360d2..502faefd 100644 --- a/gen_video.py +++ b/gen_video.py @@ -43,7 +43,7 @@ def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True #---------------------------------------------------------------------------- -def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), **video_kwargs): +def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), feature=None, **video_kwargs): grid_w = grid_dims[0] grid_h = grid_dims[1] @@ -56,6 +56,12 @@ def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind= for idx in range(num_keyframes*grid_h*grid_w): all_seeds[idx] = seeds[idx % len(seeds)] + if len(all_seeds) > 1 and feature is not None: + raise ValueError('Cannot explore a feature for more than a single image') + + if len(all_seeds) == 1 and feature is None: + raise ValueError('Must specify a feature if exploring an image') + if shuffle_seed is not None: rng = np.random.RandomState(seed=shuffle_seed) rng.shuffle(all_seeds) @@ -78,12 +84,16 @@ def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind= # Render video. video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs) + modifier = 1 for frame_idx in tqdm(range(num_keyframes * w_frames)): imgs = [] for yi in range(grid_h): for xi in range(grid_w): interp = grid[yi][xi] w = torch.from_numpy(interp(frame_idx / w_frames)).to(device) + if feature is not None: + w[feature] = w[feature] * modifier + modifier + .01 img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0] imgs.append(img) video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h)) @@ -133,6 +143,7 @@ def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]: @click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) @click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE') +@click.option('--feature', type=int, help='Feature to explore', default=None) def generate_images( network_pkl: str, seeds: List[int], @@ -141,7 +152,8 @@ def generate_images( grid: Tuple[int,int], num_keyframes: Optional[int], w_frames: int, - output: str + output: str, + feature: int ): """Render a latent vector interpolation video. @@ -170,7 +182,7 @@ def generate_images( with dnnlib.util.open_url(network_pkl) as f: G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore - gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi) + gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi, feature=feature) #----------------------------------------------------------------------------