Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding input for feature explore #619

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions gen_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True

#----------------------------------------------------------------------------

def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), **video_kwargs):
def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), feature=None, **video_kwargs):
grid_w = grid_dims[0]
grid_h = grid_dims[1]

Expand All @@ -56,6 +56,12 @@ def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind=
for idx in range(num_keyframes*grid_h*grid_w):
all_seeds[idx] = seeds[idx % len(seeds)]

if len(all_seeds) > 1 and feature is not None:
raise ValueError('Cannot explore a feature for more than a single image')

if len(all_seeds) == 1 and feature is None:
raise ValueError('Must specify a feature if exploring an image')

if shuffle_seed is not None:
rng = np.random.RandomState(seed=shuffle_seed)
rng.shuffle(all_seeds)
Expand All @@ -78,12 +84,16 @@ def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind=

# Render video.
video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs)
modifier = 1
for frame_idx in tqdm(range(num_keyframes * w_frames)):
imgs = []
for yi in range(grid_h):
for xi in range(grid_w):
interp = grid[yi][xi]
w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)
# Optional feature exploration: scale the entry of w selected by `feature`
# by a multiplier that grows by .01 each time this statement runs.
if feature is not None:
    w[feature] = w[feature] * modifier
    # Must be an augmented assignment: the bare expression `modifier + .01`
    # computed a value and discarded it, leaving modifier stuck at 1 so the
    # scaling above never changed anything.
    modifier += .01
img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0]
imgs.append(img)
video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
Expand Down Expand Up @@ -133,6 +143,7 @@ def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE')
@click.option('--feature', type=int, help='Feature to explore', default=None)
def generate_images(
network_pkl: str,
seeds: List[int],
Expand All @@ -141,7 +152,8 @@ def generate_images(
grid: Tuple[int,int],
num_keyframes: Optional[int],
w_frames: int,
output: str
output: str,
feature: int
):
"""Render a latent vector interpolation video.

Expand Down Expand Up @@ -170,7 +182,7 @@ def generate_images(
with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi)
gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi, feature=feature)

#----------------------------------------------------------------------------

Expand Down