From 8b5049e6de062931c443d5580de3b81d89239fb6 Mon Sep 17 00:00:00 2001 From: ehsan Date: Sun, 28 Sep 2025 09:31:49 +0330 Subject: [PATCH 1/2] feat(utils): add text conditioning and memory diagnostics --- tests/test_utils.py | 125 +++++++++++++++++++++ wan/utils/memory_diag.py | 173 +++++++++++++++++++++++++++++ wan/utils/text_conditioning.py | 113 +++++++++++++++++++ wan/utils/utils.py | 192 +++++++++++++++++++++++++-------- 4 files changed, 557 insertions(+), 46 deletions(-) create mode 100644 tests/test_utils.py create mode 100644 wan/utils/memory_diag.py create mode 100644 wan/utils/text_conditioning.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..9459e000 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,125 @@ +import importlib.util +import os +from pathlib import Path +from unittest import mock + +import pytest +import torch + + +def _load_module(name: str, relative_path: str): + module_path = Path(__file__).resolve().parents[1] / relative_path + spec = importlib.util.spec_from_file_location(name, module_path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) # type: ignore[attr-defined] + return module + + +memory_diag = _load_module("memory_diag", "wan/utils/memory_diag.py") +text_conditioning = _load_module("text_conditioning", "wan/utils/text_conditioning.py") +utils_mod = _load_module("wan_utils", "wan/utils/utils.py") + + +class _DummyT5Model(torch.nn.Module): + def forward(self, *args, **kwargs): + raise NotImplementedError + + +class _DummyTextEncoder: + def __init__(self): + super().__init__() + self.model = _DummyT5Model() + self.tokenizer = object() + + def __call__(self, prompts, device): + return [ + torch.ones(1, 2, device=device) * (idx + 1) for idx in range(len(prompts)) + ] + + +class _DummyPipe: + def __init__(self): + self.text_encoder = _DummyTextEncoder() + self.device = torch.device("cpu") + + +def test_prepare_text_conditioning_offloads_encoder(): + pipe = _DummyPipe() + + prompt_embeds, negative_embeds = text_conditioning.prepare_text_conditioning( + pipe, + prompts=["hello"], + negative_prompts=["world"], + precision="fp16", + device=pipe.device, + offload_strategy="set_none", + print_memory=False, + ) + + assert pipe.text_encoder is None + assert len(prompt_embeds) == 1 + assert prompt_embeds[0].dtype == torch.float16 + assert prompt_embeds[0].device == pipe.device + assert negative_embeds is not None + assert negative_embeds[0].dtype == torch.float16 + + +def test_assert_text_encoder_off_gpu_accepts_cpu(): + pipe = _DummyPipe() + pipe.text_encoder.model.to("cpu") + memory_diag.assert_text_encoder_off_gpu(pipe) + + +def test_track_cuda_memory_no_cuda(): + with memory_diag.track_cuda_memory("unit-test"): + pass + + +def test_download_cosyvoice_repo_skips_clone_when_exists(tmp_path): + repo_dir = tmp_path / "CosyVoice" + repo_dir.mkdir() + + with mock.patch.object(utils_mod.subprocess, "check_call") as check_call: + path = utils_mod.download_cosyvoice_repo(str(repo_dir)) + + assert path == repo_dir + check_call.assert_not_called() + + +def test_download_cosyvoice_repo_clones_when_missing(tmp_path): + repo_dir = tmp_path / "CosyVoice" + + def _fake_clone(cmd): + repo_dir.mkdir() + return 0 + + with mock.patch.object( + utils_mod.subprocess, "check_call", side_effect=_fake_clone + ) as check_call: + path = utils_mod.download_cosyvoice_repo(str(repo_dir)) + + assert path == repo_dir + assert repo_dir.exists() + check_call.assert_called_once() + + +def test_download_cosyvoice_model_uses_snapshot(tmp_path): + pytest.importorskip("huggingface_hub") + target = tmp_path / "CosyVoiceModel" + + def _fake_snapshot(*_, **__): + os.makedirs(target, exist_ok=True) + (target / "config.json").write_text("{}", encoding="utf-8") + return str(target) + + with mock.patch( + "huggingface_hub.snapshot_download", side_effect=_fake_snapshot + ) as snap: + path = utils_mod.download_cosyvoice_model( + "CosyVoice2-0.5B", str(target), repo_id="test/model" + ) + + assert path == target + assert (target / "config.json").exists() + snap.assert_called_once() diff --git a/wan/utils/memory_diag.py b/wan/utils/memory_diag.py new file mode 100644 index 00000000..e67414dd --- /dev/null +++ b/wan/utils/memory_diag.py @@ -0,0 +1,173 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import contextlib +import gc +import logging +from typing import Iterator, List, Tuple + +import torch + + +def _bytes(n: int) -> str: + """Format bytes into a compact human-readable string.""" + for unit in ["B", "KB", "MB", "GB", "TB"]: + if n < 1024: + return f"{n:.0f} {unit}" + n /= 1024 + return f"{n:.0f} PB" + + +def _collect_cuda_tensors() -> List[Tuple[int, str, str, str]]: + """Return a list of CUDA tensors with their sizes. + + Each item: (num_bytes, shape_str, dtype_str, device_str) + """ + tensors: List[Tuple[int, str, str, str]] = [] + try: + for obj in gc.get_objects(): + try: + if isinstance(obj, torch.Tensor) and obj.is_cuda: + nbytes = obj.element_size() * obj.nelement() + tensors.append( + (nbytes, str(tuple(obj.shape)), str(obj.dtype), str(obj.device)) + ) + except Exception: + # Be defensive: ignore objects we cannot inspect + continue + except Exception: + # Fallback if gc.get_objects() is unavailable or restricted + pass + # Sort largest first + tensors.sort(key=lambda x: x[0], reverse=True) + return tensors + + +@contextlib.contextmanager +def track_cuda_memory( + tag: str, show_processes: bool = False, top_k: int = 10 +) -> Iterator[None]: + """ + Context manager that reports CUDA memory before and after the block. + + - Prints allocated and reserved bytes. + - Prints the top-K CUDA tensors by size (best-effort via gc). + - Prints torch.cuda.memory_summary() for detailed allocator info. + - Optionally prints torch.cuda.list_gpu_processes() if available. + """ + if not torch.cuda.is_available(): + logging.info( + f"[track_cuda_memory:{tag}] CUDA not available; skipping diagnostics." + ) + yield + return + + device = torch.device("cuda") + torch.cuda.synchronize() + alloc_before = torch.cuda.memory_allocated(device) + reserv_before = torch.cuda.memory_reserved(device) + + logging.info( + f"[track_cuda_memory:{tag}] BEFORE | allocated={_bytes(alloc_before)}, reserved={_bytes(reserv_before)}" + ) + + tensors_before = _collect_cuda_tensors() + if tensors_before: + logging.info( + f"[track_cuda_memory:{tag}] Top CUDA tensors before (n={min(len(tensors_before), top_k)}):" + ) + for nbytes, shape, dtype, dev in tensors_before[:top_k]: + logging.info(f" - {shape} {dtype} on {dev} | {_bytes(nbytes)}") + + try: + logging.info( + f"[track_cuda_memory:{tag}] memory_summary BEFORE:\n{torch.cuda.memory_summary(device=device)}" + ) + except Exception: + pass + if show_processes and hasattr(torch.cuda, "list_gpu_processes"): + try: + logging.info( + f"[track_cuda_memory:{tag}] GPU processes BEFORE:\n{torch.cuda.list_gpu_processes()} " + ) + except Exception: + pass + + try: + yield + finally: + torch.cuda.synchronize() + alloc_after = torch.cuda.memory_allocated(device) + reserv_after = torch.cuda.memory_reserved(device) + logging.info( + f"[track_cuda_memory:{tag}] AFTER | allocated={_bytes(alloc_after)}, reserved={_bytes(reserv_after)} (Δalloc={_bytes(alloc_after-alloc_before)}, Δres={_bytes(reserv_after-reserv_before)})" + ) + + tensors_after = _collect_cuda_tensors() + if tensors_after: + logging.info( + f"[track_cuda_memory:{tag}] Top CUDA tensors after (n={min(len(tensors_after), top_k)}):" + ) + for nbytes, shape, dtype, dev in tensors_after[:top_k]: + logging.info(f" - {shape} {dtype} on {dev} | {_bytes(nbytes)}") + + try: + logging.info( + f"[track_cuda_memory:{tag}] memory_summary AFTER:\n{torch.cuda.memory_summary(device=device)}" + ) + except Exception: + pass + if show_processes and hasattr(torch.cuda, "list_gpu_processes"): + try: + logging.info( + f"[track_cuda_memory:{tag}] GPU processes AFTER:\n{torch.cuda.list_gpu_processes()} " + ) + except Exception: + pass + + +def assert_text_encoder_off_gpu(pipe) -> None: + """Assert that the pipeline's text encoder is not resident on any CUDA device. + + Conditions: + - pipe.text_encoder is None OR + - All parameters/buffers of pipe.text_encoder (and its submodules / .model) are on CPU. + Raises with a helpful message if violated. + """ + te = getattr(pipe, "text_encoder", None) + if te is None: + return + + # Helper to iterate parameters/buffers from various wrappers + def _iter_tensors(obj): + try: + if hasattr(obj, "parameters"): + for p in obj.parameters(recurse=True): + yield p + if hasattr(obj, "buffers"): + for b in obj.buffers(recurse=True): + yield b + except Exception: + return + + # Check the encoder itself and a nested `.model` attribute commonly used in wrappers + suspects = [te] + if hasattr(te, "model"): + suspects.append(te.model) + + on_cuda = [] + for s in suspects: + for t in _iter_tensors(s): + try: + if t.is_cuda: + on_cuda.append(t) + except Exception: + continue + + if on_cuda: + # Build a concise and actionable error message + n = len(on_cuda) + example = on_cuda[0] + raise RuntimeError( + "text_encoder tensors detected on CUDA: " + f"found {n} tensor(s), example device={example.device}, shape={tuple(example.shape)}, dtype={example.dtype}.\n" + "Ensure you compute embeddings once, then offload the encoder (e.g., move to CPU and set pipe.text_encoder=None)." + ) diff --git a/wan/utils/text_conditioning.py b/wan/utils/text_conditioning.py new file mode 100644 index 00000000..5797a44b --- /dev/null +++ b/wan/utils/text_conditioning.py @@ -0,0 +1,113 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import gc +import logging +from typing import List, Optional, Tuple + +import torch + + +def print_cuda_mem_summary(device: torch.device) -> None: + """Prints CUDA memory summary if available, guarded for safety.""" + if torch.cuda.is_available(): + try: + logging.info( + "CUDA memory summary:\n" + torch.cuda.memory_summary(device=device) + ) + except Exception: + # Avoid crashing on some driver/toolkit combos + pass + + +def prepare_text_conditioning( + pipe, + prompts: List[str], + negative_prompts: Optional[List[str]] = None, + precision: str = "fp16", + device: Optional[torch.device] = None, + offload_strategy: str = "set_none", + print_memory: bool = True, +) -> Tuple[List[torch.Tensor], Optional[List[torch.Tensor]]]: + """ + Compute text embeddings once and safely offload the T5 encoder from GPU. + + - Enforces eval + torch.inference_mode() to avoid grads. + - Casts to fp16/bf16 as requested. + - Moves embeddings to the UNet/video model device. + - Immediately offloads or removes `pipe.text_encoder` and clears caches. + + Args: + pipe: Diffusers-style pipeline object. Must have `text_encoder` and `tokenizer` attrs. + prompts: List of prompts for conditioning. + negative_prompts: List of negative prompts for CFG, optional. + precision: "fp16" or "bf16". + device: Target device for embeddings. Defaults to the DiT device if None. + offload_strategy: One of {"cpu", "set_none"}. "set_none" breaks references. + print_memory: Print torch.cuda.memory_summary() before/after offload. + + Returns: + (prompt_embeds, negative_prompt_embeds) + """ + assert hasattr(pipe, "text_encoder"), "pipe must expose .text_encoder" + + # Resolve target device to match the UNet/DiT device + if device is None: + # Heuristic: Wan pipelines expose .device on the class; fallback to current CUDA device + device = getattr( + pipe, "device", torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) + + # No global grad graph for encoder + target_dtype = torch.float16 if precision == "fp16" else torch.bfloat16 + text_encoder = getattr(pipe, "text_encoder", None) + if text_encoder is None: + raise RuntimeError( + "pipe.text_encoder is None; cannot compute embeddings. Provide them externally." + ) + + if print_memory: + print_cuda_mem_summary(device) + + with torch.inference_mode(): + # Ensure the encoder is in eval() and placed appropriately for compute + if hasattr(text_encoder, "model"): + text_encoder.model.eval() + # Place model on compute device for speed + if hasattr(text_encoder, "model") and device.type == "cuda": + text_encoder.model.to(device) + + # The Wan T5 wrapper accepts a python list of strings + prompt_embeds = text_encoder(prompts, device) + negative_prompt_embeds = None + if negative_prompts is not None: + negative_prompt_embeds = text_encoder(negative_prompts, device) + + # Cast + move to target device and dtype + prompt_embeds = [t.to(dtype=target_dtype, device=device) for t in prompt_embeds] + if negative_prompt_embeds is not None: + negative_prompt_embeds = [ + t.to(dtype=target_dtype, device=device) for t in negative_prompt_embeds + ] + + # Immediately offload encoder to free VRAM and break references + try: + if hasattr(text_encoder, "model"): + text_encoder.model.to("cpu") + except Exception: + pass + + # Break references if requested to allow GC to reclaim memory + if offload_strategy == "set_none": + try: + pipe.text_encoder = None + # Keep tokenizer for later tokenization (it is CPU resident) + except Exception: + pass + + # Clear caches and print memory summaries + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if print_memory: + print_cuda_mem_summary(device) + + return prompt_embeds, negative_prompt_embeds diff --git a/wan/utils/utils.py b/wan/utils/utils.py index c563c698..ba165f04 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -4,76 +4,85 @@ import logging import os import os.path as osp +import subprocess +from pathlib import Path +from typing import Optional import imageio import torch import torchvision -__all__ = ['save_video', 'save_image', 'str2bool'] +__all__ = [ + "save_video", + "save_image", + "str2bool", + "download_cosyvoice_repo", + "download_cosyvoice_model", +] -def rand_name(length=8, suffix=''): - name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') +def rand_name(length=8, suffix=""): + name = binascii.b2a_hex(os.urandom(length)).decode("utf-8") if suffix: - if not suffix.startswith('.'): - suffix = '.' + suffix + if not suffix.startswith("."): + suffix = "." + suffix name += suffix return name -def save_video(tensor, - save_file=None, - fps=30, - suffix='.mp4', - nrow=8, - normalize=True, - value_range=(-1, 1)): +def save_video( + tensor, + save_file=None, + fps=30, + suffix=".mp4", + nrow=8, + normalize=True, + value_range=(-1, 1), +): # cache file - cache_file = osp.join('/tmp', rand_name( - suffix=suffix)) if save_file is None else save_file + cache_file = ( + osp.join("/tmp", rand_name(suffix=suffix)) if save_file is None else save_file + ) # save to cache try: # preprocess tensor = tensor.clamp(min(value_range), max(value_range)) - tensor = torch.stack([ - torchvision.utils.make_grid( - u, nrow=nrow, normalize=normalize, value_range=value_range) - for u in tensor.unbind(2) - ], - dim=1).permute(1, 2, 3, 0) + tensor = torch.stack( + [ + torchvision.utils.make_grid( + u, nrow=nrow, normalize=normalize, value_range=value_range + ) + for u in tensor.unbind(2) + ], + dim=1, + ).permute(1, 2, 3, 0) tensor = (tensor * 255).type(torch.uint8).cpu() # write video - writer = imageio.get_writer( - cache_file, fps=fps, codec='libx264', quality=8) + writer = imageio.get_writer(cache_file, fps=fps, codec="libx264", quality=8) for frame in tensor.numpy(): writer.append_data(frame) writer.close() except Exception as e: - logging.info(f'save_video failed, error: {e}') + logging.info(f"save_video failed, error: {e}") def save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1)): # cache file suffix = osp.splitext(save_file)[1] - if suffix.lower() not in [ - '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' - ]: - suffix = '.png' + if suffix.lower() not in [".jpg", ".jpeg", ".png", ".tiff", ".gif", ".webp"]: + suffix = ".png" # save to cache try: tensor = tensor.clamp(min(value_range), max(value_range)) torchvision.utils.save_image( - tensor, - save_file, - nrow=nrow, - normalize=normalize, - value_range=value_range) + tensor, save_file, nrow=nrow, normalize=normalize, value_range=value_range + ) return save_file except Exception as e: - logging.info(f'save_image failed, error: {e}') + logging.info(f"save_image failed, error: {e}") def str2bool(v): @@ -95,12 +104,12 @@ def str2bool(v): if isinstance(v, bool): return v v_lower = v.lower() - if v_lower in ('yes', 'true', 't', 'y', '1'): + if v_lower in ("yes", "true", "t", "y", "1"): return True - elif v_lower in ('no', 'false', 'f', 'n', '0'): + elif v_lower in ("no", "false", "f", "n", "0"): return False else: - raise argparse.ArgumentTypeError('Boolean value expected (True/False)') + raise argparse.ArgumentTypeError("Boolean value expected (True/False)") def masks_like(tensor, zero=False, generator=None, p=0.2): @@ -113,14 +122,20 @@ def masks_like(tensor, zero=False, generator=None, p=0.2): if generator is not None: for u, v in zip(out1, out2): random_num = torch.rand( - 1, generator=generator, device=generator.device).item() + 1, generator=generator, device=generator.device + ).item() if random_num < p: - u[:, 0] = torch.normal( - mean=-3.5, - std=0.5, - size=(1,), - device=u.device, - generator=generator).expand_as(u[:, 0]).exp() + u[:, 0] = ( + torch.normal( + mean=-3.5, + std=0.5, + size=(1,), + device=u.device, + generator=generator, + ) + .expand_as(u[:, 0]) + .exp() + ) v[:, 0] = torch.zeros_like(v[:, 0]) else: u[:, 0] = u[:, 0] @@ -136,7 +151,7 @@ def masks_like(tensor, zero=False, generator=None, p=0.2): def best_output_size(w, h, dw, dh, expected_area): # float output size ratio = w / h - ow = (expected_area * ratio)**0.5 + ow = (expected_area * ratio) ** 0.5 oh = expected_area / ow # process width first @@ -152,8 +167,93 @@ def best_output_size(w, h, dw, dh, expected_area): ratio2 = ow2 / oh2 # compare ratios - if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, - ratio2 / ratio): + if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio): return ow1, oh1 else: return ow2, oh2 + + +def download_cosyvoice_repo(target_dir: str, repo_url: Optional[str] = None) -> Path: + """Clone the CosyVoice repository if it is not already present. + + Args: + target_dir: Directory where the repository should live. + repo_url: Optional override for the repository URL. Defaults to the + official upstream but can be customised via the + ``WAN_COSYVOICE_REPO_URL`` environment variable. + + Returns: + Path to the checked-out repository. + """ + + destination = Path(target_dir) + if destination.exists(): + return destination + + resolved_url = repo_url or os.getenv( + "WAN_COSYVOICE_REPO_URL", + "https://github.com/alibaba-damo-academy/CosyVoice.git", + ) + + destination.parent.mkdir(parents=True, exist_ok=True) + try: + subprocess.check_call( + ["git", "clone", "--depth", "1", resolved_url, str(destination)] + ) + except (FileNotFoundError, subprocess.CalledProcessError) as exc: + raise RuntimeError( + f"Failed to clone CosyVoice repository from {resolved_url}. " + "Install git or set WAN_COSYVOICE_REPO_URL to a valid mirror, " + "or download the repository manually." + ) from exc + + return destination + + +def download_cosyvoice_model( + model_name: str, target_dir: str, repo_id: Optional[str] = None +) -> Path: + """Download CosyVoice model weights from Hugging Face if missing. + + Args: + model_name: Name of the model folder to materialise locally. + target_dir: Destination directory for the weights. + repo_id: Optional Hugging Face repository id. Defaults to + ``iic/`` and can be overridden with the + ``WAN_COSYVOICE_MODEL_REPO`` environment variable. + + Returns: + Path to the directory containing the weights. + """ + + destination = Path(target_dir) + if destination.exists() and any(destination.iterdir()): + return destination + + resolved_repo = repo_id or os.getenv( + "WAN_COSYVOICE_MODEL_REPO", f"iic/{model_name}" + ) + + try: + from huggingface_hub import snapshot_download + except ImportError as exc: # pragma: no cover - optional dependency + raise RuntimeError( + "huggingface-hub is required to download CosyVoice weights. " + "Install it or provide the weights manually." + ) from exc + + destination.parent.mkdir(parents=True, exist_ok=True) + try: + snapshot_download( + repo_id=resolved_repo, + local_dir=str(destination), + local_dir_use_symlinks=False, + ) + except Exception as exc: # pragma: no cover - network dependent + raise RuntimeError( + f"Failed to download CosyVoice weights from {resolved_repo}. " + "Set WAN_COSYVOICE_MODEL_REPO to a valid repository or place " + "the weights manually." + ) from exc + + return destination From f34b8fc980c1fd27b5434a67f4ef9da03aa55e99 Mon Sep 17 00:00:00 2001 From: ehsan Date: Sun, 28 Sep 2025 09:33:09 +0330 Subject: [PATCH 2/2] feat(pipelines): add speech-to-video and animation tasks --- INSTALL.md | 4 + README.md | 47 +- generate.py | 626 +++++-- requirements.txt | 15 +- requirements_animate.txt | 2 + requirements_s2v.txt | 12 + wan/__init__.py | 13 + wan/animate.py | 733 ++++++++ wan/configs/__init__.py | 44 +- wan/configs/wan_animate_14B.py | 40 + wan/configs/wan_s2v_14B.py | 58 + wan/image2video.py | 295 ++-- wan/modules/animate/__init__.py | 5 + wan/modules/animate/animate_utils.py | 143 ++ wan/modules/animate/clip.py | 586 +++++++ wan/modules/animate/face_blocks.py | 419 +++++ wan/modules/animate/model_animate.py | 530 ++++++ wan/modules/animate/motion_encoder.py | 351 ++++ wan/modules/animate/preprocess/UserGuider.md | 70 + wan/modules/animate/preprocess/__init__.py | 3 + .../animate/preprocess/human_visualization.py | 1485 +++++++++++++++++ wan/modules/animate/preprocess/pose2d.py | 505 ++++++ .../animate/preprocess/pose2d_utils.py | 1202 +++++++++++++ .../animate/preprocess/preprocess_data.py | 140 ++ .../animate/preprocess/process_pipepline.py | 486 ++++++ .../animate/preprocess/retarget_pose.py | 1166 +++++++++++++ wan/modules/animate/preprocess/sam_utils.py | 157 ++ wan/modules/animate/preprocess/utils.py | 250 +++ .../animate/preprocess/video_predictor.py | 158 ++ wan/modules/animate/xlm_roberta.py | 175 ++ wan/modules/s2v/__init__.py | 5 + wan/modules/s2v/audio_encoder.py | 193 +++ wan/modules/s2v/audio_utils.py | 119 ++ wan/modules/s2v/auxi_blocks.py | 239 +++ wan/modules/s2v/model_s2v.py | 964 +++++++++++ wan/modules/s2v/motioner.py | 865 ++++++++++ wan/modules/s2v/s2v_utils.py | 86 + wan/speech2video.py | 730 ++++++++ wan/text2video.py | 269 ++- wan/textimage2video.py | 438 +++-- 40 files changed, 13204 insertions(+), 424 deletions(-) create mode 100644 requirements_animate.txt create mode 100644 requirements_s2v.txt create mode 100644 wan/animate.py create mode 100644 wan/configs/wan_animate_14B.py create mode 100644 wan/configs/wan_s2v_14B.py create mode 100644 wan/modules/animate/__init__.py create mode 100644 wan/modules/animate/animate_utils.py create mode 100644 wan/modules/animate/clip.py create mode 100644 wan/modules/animate/face_blocks.py create mode 100644 wan/modules/animate/model_animate.py create mode 100644 wan/modules/animate/motion_encoder.py create mode 100644 wan/modules/animate/preprocess/UserGuider.md create mode 100644 wan/modules/animate/preprocess/__init__.py create mode 100644 wan/modules/animate/preprocess/human_visualization.py create mode 100644 wan/modules/animate/preprocess/pose2d.py create mode 100644 wan/modules/animate/preprocess/pose2d_utils.py create mode 100644 wan/modules/animate/preprocess/preprocess_data.py create mode 100644 wan/modules/animate/preprocess/process_pipepline.py create mode 100644 wan/modules/animate/preprocess/retarget_pose.py create mode 100644 wan/modules/animate/preprocess/sam_utils.py create mode 100644 wan/modules/animate/preprocess/utils.py create mode 100644 wan/modules/animate/preprocess/video_predictor.py create mode 100644 wan/modules/animate/xlm_roberta.py create mode 100644 wan/modules/s2v/__init__.py create mode 100644 wan/modules/s2v/audio_encoder.py create mode 100644 wan/modules/s2v/audio_utils.py create mode 100644 wan/modules/s2v/auxi_blocks.py create mode 100644 wan/modules/s2v/model_s2v.py create mode 100644 wan/modules/s2v/motioner.py create mode 100644 wan/modules/s2v/s2v_utils.py create mode 100644 wan/speech2video.py diff --git a/INSTALL.md b/INSTALL.md index 14c62958..2002f048 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -5,6 +5,10 @@ ```bash pip install . pip install .[dev] # Installe aussi les outils de dev + +# Optional extras +pip install -r requirements_s2v.txt # Speech-to-video audio/TTS stack +pip install -r requirements_animate.txt # Animation preprocessing stack ``` ## Install with Poetry diff --git a/README.md b/README.md index 96fa8ccb..02c7a624 100644 --- a/README.md +++ b/README.md @@ -230,6 +230,52 @@ torchrun --nproc_per_node=8 generate.py --task ti2v-5B --size 1280*704 --ckpt_di +#### Run Speech-to-Video Generation + +The repository also provides the `Wan2.2-S2V-14B` speech-to-video pipeline. It animates a reference portrait using either a driving audio clip or the integrated CosyVoice text-to-speech backend. + +- Speech-driven generation with an audio file + +```sh +python generate.py \ + --task s2v-14B \ + --size 1280*720 \ + --ckpt_dir ./Wan2.2-S2V-14B \ + --image examples/pose.png \ + --audio examples/talk.wav \ + --pose_video examples/pose.mp4 \ + --prompt "A charismatic presenter greeting the audience with confident gestures." \ + --offload_model True --convert_model_dtype +``` + +> Optional flags: +> - `--enable_tts` synthesises driving audio via CosyVoice when no `--audio` is provided. Pair it with `--tts_prompt_audio` (speaker reference) and `--tts_text` (target transcript). +> - `--num_repeat` controls how many clips to render for long speeches. The default is chosen automatically from the audio length. + +Install the extra audio/TTS dependencies with `pip install -r requirements_s2v.txt` before running CosyVoice-based workflows. + + +#### Run Character Animation Generation + +`Wan2.2-Animate-14B` produces pose-driven character animation. It expects a preprocessing folder containing the extracted pose (`src_pose.mp4`), facial reference (`src_face.mp4`), clean background (`src_bg.mp4`), masks (`src_mask.mp4`), and a reference key frame (`src_ref.png`). Example assets can be found under `examples/wan_animate/`. + +```sh +python generate.py \ + --task animate-14B \ + --ckpt_dir ./Wan2.2-Animate-14B \ + --animate_src_root examples/wan_animate/animate \ + --prompt "视频中的人在做动作" \ + --animate_refer_frames 5 \ + --offload_model True --convert_model_dtype +``` + +> Useful options: +> - `--animate_replace` enables background replacement when the preprocessing folder contains foreground/background sequences. +> - `--animate_clip_len` adjusts the temporal window per inference chunk (default 77 frames). + +Dependencies for the preprocessing toolkit reside in `requirements_animate.txt` (including SAM2). Install them when you need to run the pose extraction pipeline. + + ## Computational Efficiency on Different GPUs We test the computational efficiency of different **Wan2.2** models on different GPUs in the following table. The results are presented in the format: **Total time (s) / peak GPU memory (GB)**. @@ -312,4 +358,3 @@ We would like to thank the contributors to the [SD3](https://huggingface.co/stab ## Contact Us If you would like to leave a message to our research or product teams, feel free to join our [Discord](https://discord.gg/AKNgpMK4Yj) or [WeChat groups](https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg)! - diff --git a/generate.py b/generate.py index c3e58167..1d55806e 100644 --- a/generate.py +++ b/generate.py @@ -6,7 +6,7 @@ import warnings from datetime import datetime -warnings.filterwarnings('ignore') +warnings.filterwarnings("ignore") import random @@ -22,18 +22,26 @@ EXAMPLE_PROMPT = { "t2v-A14B": { - "prompt": - "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", }, "i2v-A14B": { - "prompt": - "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", - "image": - "examples/i2v_input.JPG", + "prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", + "image": "examples/i2v_input.JPG", }, "ti2v-5B": { - "prompt": - "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "s2v-14B": { + "prompt": "A charismatic presenter greeting the audience with confident gestures.", + "image": "examples/pose.png", + "audio": "examples/talk.wav", + "pose_video": "examples/pose.mp4", + "tts_prompt_audio": "examples/zero_shot_prompt.wav", + "tts_text": "你好,很高兴见到你。", + }, + "animate-14B": { + "prompt": "视频中的人在做动作", + "animate_src_root": "examples/wan_animate/animate", }, } @@ -44,34 +52,96 @@ def _validate_args(args): assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" + example = EXAMPLE_PROMPT[args.task] if args.prompt is None: - args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] - if args.image is None and "image" in EXAMPLE_PROMPT[args.task]: - args.image = EXAMPLE_PROMPT[args.task]["image"] - - if args.task == "i2v-A14B": - assert args.image is not None, "Please specify the image path for i2v." + args.prompt = example.get("prompt", args.prompt) + + # Populate optional inputs from examples when omitted + for field in ( + "image", + "audio", + "pose_video", + "tts_prompt_audio", + "tts_prompt_text", + "tts_text", + "animate_src_root", + ): + if getattr(args, field, None) is None and field in example: + setattr(args, field, example[field]) + + task_lower = args.task.lower() + + # Task specific validations + if "i2v" in task_lower or "ti2v" in task_lower or "s2v" in task_lower: + assert ( + args.image is not None + ), f"Task {args.task} requires --image. Provide a reference image path." + + if "s2v" in task_lower: + if not args.enable_tts: + assert ( + args.audio is not None + ), "Speech-to-video requires --audio when TTS is disabled." + else: + assert ( + args.tts_text + ), "Provide --tts_text when enabling text-to-speech synthesis." + assert ( + args.tts_prompt_audio is not None + ), "Provide --tts_prompt_audio with a reference speaker clip for TTS." + if args.num_repeat is not None: + assert args.num_repeat > 0, "--num_repeat should be a positive integer." + + if "animate" in task_lower: + assert ( + args.animate_src_root is not None + ), "Wan Animate requires --animate_src_root pointing to the preprocessed folder." + if args.animate_refer_frames is not None: + assert args.animate_refer_frames in ( + 1, + 5, + ), "--animate_refer_frames must be 1 or 5 for Wan Animate." + if args.animate_clip_len is not None: + assert args.animate_clip_len > 0, "--animate_clip_len must be positive." cfg = WAN_CONFIGS[args.task] if args.sample_steps is None: args.sample_steps = cfg.sample_steps - if args.sample_shift is None: args.sample_shift = cfg.sample_shift - if args.sample_guide_scale is None: args.sample_guide_scale = cfg.sample_guide_scale - if args.frame_num is None: args.frame_num = cfg.frame_num - args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( - 0, sys.maxsize) + args.base_seed = ( + args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize) + ) + + if getattr(args, "auto_safe_size_5b", False) and ("5b" in task_lower): + safe_size = "832*480" + if args.size != safe_size: + logging.info( + f"auto_safe_size_5b enabled: overriding size {args.size} -> {safe_size}" + ) + args.size = safe_size + if args.frame_num > 81: + logging.info( + f"auto_safe_size_5b enabled: capping frame_num {args.frame_num} -> 81" + ) + args.frame_num = min(args.frame_num, 81) + + if "s2v" in task_lower and args.frame_num % 4 != 0: + raise AssertionError( + "Speech-to-video expects --frame_num to be a multiple of 4." + ) + # Size check - assert args.size in SUPPORTED_SIZES[ - args. - task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}" + assert args.size in SUPPORTED_SIZES[args.task], ( + f"Unsupport size {args.size} for task {args.task}, supported sizes are: " + f"{', '.join(SUPPORTED_SIZES[args.task])}" + ) def _parse_args(): @@ -83,116 +153,237 @@ def _parse_args(): type=str, default="t2v-A14B", choices=list(WAN_CONFIGS.keys()), - help="The task to run.") + help="The task to run.", + ) parser.add_argument( "--size", type=str, default="1280*720", choices=list(SIZE_CONFIGS.keys()), - help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image." + help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image.", ) parser.add_argument( "--frame_num", type=int, default=None, - help="How many frames of video are generated. The number should be 4n+1" + help="How many frames of video are generated. Differs per task (e.g. 4n+1 for T2V/I2V/TI2V, multiples of 4 for S2V).", ) parser.add_argument( "--ckpt_dir", type=str, default=None, - help="The path to the checkpoint directory.") + help="The path to the checkpoint directory.", + ) parser.add_argument( "--offload_model", type=str2bool, default=None, - help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.", ) parser.add_argument( "--ulysses_size", type=int, default=1, - help="The size of the ulysses parallelism in DiT.") + help="The size of the ulysses parallelism in DiT.", + ) parser.add_argument( "--t5_fsdp", action="store_true", default=False, - help="Whether to use FSDP for T5.") + help="Whether to use FSDP for T5.", + ) parser.add_argument( "--t5_cpu", action="store_true", default=False, - help="Whether to place T5 model on CPU.") + help="Whether to place T5 model on CPU.", + ) parser.add_argument( "--dit_fsdp", action="store_true", default=False, - help="Whether to use FSDP for DiT.") + help="Whether to use FSDP for DiT.", + ) parser.add_argument( "--save_file", type=str, default=None, - help="The file to save the generated video to.") + help="The file to save the generated video to.", + ) parser.add_argument( "--prompt", type=str, default=None, - help="The prompt to generate the video from.") + help="The prompt to generate the video from.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default=None, + help="Optional negative prompt for classifier-free guidance.", + ) parser.add_argument( "--use_prompt_extend", action="store_true", default=False, - help="Whether to use prompt extend.") + help="Whether to use prompt extend.", + ) parser.add_argument( "--prompt_extend_method", type=str, default="local_qwen", choices=["dashscope", "local_qwen"], - help="The prompt extend method to use.") + help="The prompt extend method to use.", + ) parser.add_argument( "--prompt_extend_model", type=str, default=None, - help="The prompt extend model to use.") + help="The prompt extend model to use.", + ) parser.add_argument( "--prompt_extend_target_lang", type=str, default="zh", choices=["zh", "en"], - help="The target language of prompt extend.") + help="The target language of prompt extend.", + ) parser.add_argument( "--base_seed", type=int, default=-1, - help="The seed to use for generating the video.") + help="The seed to use for generating the video.", + ) parser.add_argument( - "--image", + "--image", type=str, default=None, help="The image to generate the video from." + ) + parser.add_argument( + "--audio", type=str, default=None, - help="The image to generate the video from.") + help="Input audio track for speech-to-video tasks.", + ) + parser.add_argument( + "--pose_video", + type=str, + default=None, + help="Optional pose driving video for speech-to-video motion guidance.", + ) parser.add_argument( "--sample_solver", type=str, - default='unipc', - choices=['unipc', 'dpm++'], - help="The solver used to sample.") + default="unipc", + choices=["unipc", "dpm++"], + help="The solver used to sample.", + ) parser.add_argument( - "--sample_steps", type=int, default=None, help="The sampling steps.") + "--sample_steps", type=int, default=None, help="The sampling steps." + ) parser.add_argument( "--sample_shift", type=float, default=None, - help="Sampling shift factor for flow matching schedulers.") + help="Sampling shift factor for flow matching schedulers.", + ) parser.add_argument( "--sample_guide_scale", type=float, default=None, - help="Classifier free guidance scale.") + help="Classifier free guidance scale.", + ) parser.add_argument( "--convert_model_dtype", action="store_true", default=False, - help="Whether to convert model paramerters dtype.") + help="Whether to convert model paramerters dtype.", + ) + parser.add_argument( + "--num_repeat", + type=int, + default=None, + help="Number of video clips to generate for speech-to-video (auto if omitted).", + ) + parser.add_argument( + "--enable_tts", + action="store_true", + default=False, + help="Use CosyVoice TTS when no audio is provided for speech-to-video.", + ) + parser.add_argument( + "--tts_prompt_audio", + type=str, + default=None, + help="Reference speaker audio for TTS when --enable_tts is set.", + ) + parser.add_argument( + "--tts_prompt_text", + type=str, + default=None, + help="Optional reference transcript for the TTS prompt speaker.", + ) + parser.add_argument( + "--tts_text", + type=str, + default=None, + help="Target text to synthesise when --enable_tts is set.", + ) + parser.add_argument( + "--init_first_frame", + action="store_true", + default=False, + help="Initialise speech-to-video generation with the reference image as the first frame.", + ) + parser.add_argument( + "--animate_src_root", + type=str, + default=None, + help="Directory containing the Wan Animate preprocessed assets (src_pose.mp4, src_ref.png, etc.).", + ) + parser.add_argument( + "--animate_replace", + action="store_true", + default=False, + help="Enable background replacement for Wan Animate.", + ) + parser.add_argument( + "--animate_refer_frames", + type=int, + default=None, + help="Reference frame count (1 or 5) used by Wan Animate for temporal guidance.", + ) + parser.add_argument( + "--animate_clip_len", + type=int, + default=None, + help="Clip length processed per iteration in Wan Animate (defaults to model config).", + ) + + # Diagnostics and memory tracking flags + parser.add_argument( + "--diag_memory", + action="store_true", + default=False, + help="Enable CUDA memory diagnostics and precompute text embeddings with encoder offload.", + ) + parser.add_argument( + "--diag_show_processes", + action="store_true", + default=False, + help="Also print torch.cuda.list_gpu_processes() in diagnostics if available.", + ) + parser.add_argument( + "--diag_precision", + type=str, + choices=["fp16", "bf16"], + default="fp16", + help="Precision for text embeddings in diagnostics path (fp16 or bf16).", + ) + parser.add_argument( + "--auto_safe_size_5b", + action="store_true", + default=False, + help="Auto-pick a safer size/frame count for 5B models (e.g., ti2v-5B).", + ) args = parser.parse_args() @@ -208,7 +399,8 @@ def _init_logging(rank): logging.basicConfig( level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s", - handlers=[logging.StreamHandler(stream=sys.stdout)]) + handlers=[logging.StreamHandler(stream=sys.stdout)], + ) else: logging.basicConfig(level=logging.ERROR) @@ -222,15 +414,12 @@ def generate(args): if args.offload_model is None: args.offload_model = False if world_size > 1 else True - logging.info( - f"offload_model is not specified, set to {args.offload_model}.") + logging.info(f"offload_model is not specified, set to {args.offload_model}.") if world_size > 1: torch.cuda.set_device(local_rank) dist.init_process_group( - backend="nccl", - init_method="env://", - rank=rank, - world_size=world_size) + backend="nccl", init_method="env://", rank=rank, world_size=world_size + ) else: assert not ( args.t5_fsdp or args.dit_fsdp @@ -240,7 +429,9 @@ def generate(args): ), f"sequence parallel are not supported in non-distributed environments." if args.ulysses_size > 1: - assert args.ulysses_size == world_size, f"The number of ulysses_size should be equal to the world size." + assert ( + args.ulysses_size == world_size + ), f"The number of ulysses_size should be equal to the world size." init_distributed_group() if args.use_prompt_extend: @@ -248,24 +439,54 @@ def generate(args): prompt_expander = DashScopePromptExpander( model_name=args.prompt_extend_model, task=args.task, - is_vl=args.image is not None) + is_vl=args.image is not None, + ) elif args.prompt_extend_method == "local_qwen": prompt_expander = QwenPromptExpander( model_name=args.prompt_extend_model, task=args.task, is_vl=args.image is not None, - device=rank) + device=rank, + ) else: raise NotImplementedError( - f"Unsupport prompt_extend_method: {args.prompt_extend_method}") + f"Unsupport prompt_extend_method: {args.prompt_extend_method}" + ) cfg = WAN_CONFIGS[args.task] if args.ulysses_size > 1: - assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`." + assert ( + cfg.num_heads % args.ulysses_size == 0 + ), f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`." + + # Apply convenience safer-size for 5B models if requested + if args.auto_safe_size_5b and "5B" in args.task: + safe_size = "832*480" + if args.size != safe_size: + logging.info( + f"auto_safe_size_5b enabled: overriding size {args.size} -> {safe_size}" + ) + args.size = safe_size + # Cap frames to 81 if higher, to reduce VRAM usage on 5B + if args.frame_num is None: + # _validate_args may have already set default; set conservatively to 81 + args.frame_num = 81 + elif args.frame_num > 81: + logging.info( + f"auto_safe_size_5b enabled: capping frame_num {args.frame_num} -> 81" + ) + args.frame_num = 81 logging.info(f"Generation job args: {args}") logging.info(f"Generation model config: {cfg}") + if args.auto_safe_size_5b and "5b" in args.task.lower(): + logging.info( + "auto_safe_size_5b active -> size=%s, frame_num=%s", + args.size, + args.frame_num, + ) + if dist.is_initialized(): base_seed = [args.base_seed] if rank == 0 else [None] dist.broadcast_object_list(base_seed, src=0) @@ -285,10 +506,10 @@ def generate(args): args.prompt, image=img, tar_lang=args.prompt_extend_target_lang, - seed=args.base_seed) + seed=args.base_seed, + ) if prompt_output.status == False: - logging.info( - f"Extending prompt failed: {prompt_output.message}") + logging.info(f"Extending prompt failed: {prompt_output.message}") logging.info("Falling back to original prompt.") input_prompt = args.prompt else: @@ -301,6 +522,8 @@ def generate(args): args.prompt = input_prompt[0] logging.info(f"Extended prompt: {args.prompt}") + neg_prompt = args.negative_prompt or "" + if "t2v" in args.task: logging.info("Creating WanT2V pipeline.") wan_t2v = wan.WanT2V( @@ -316,16 +539,66 @@ def generate(args): ) logging.info(f"Generating video ...") - video = wan_t2v.generate( - args.prompt, - size=SIZE_CONFIGS[args.size], - frame_num=args.frame_num, - shift=args.sample_shift, - sample_solver=args.sample_solver, - sampling_steps=args.sample_steps, - guide_scale=args.sample_guide_scale, - seed=args.base_seed, - offload_model=args.offload_model) + if args.diag_memory: + from wan.utils.memory_diag import ( + assert_text_encoder_off_gpu, + track_cuda_memory, + ) + from wan.utils.text_conditioning import prepare_text_conditioning as _prep + + # Precompute embeddings once and offload encoder + neg_text = ( + args.negative_prompt + if args.negative_prompt is not None + else wan_t2v.sample_neg_prompt + ) + neg = [neg_text] if neg_text else None + prompt_embeds, negative_prompt_embeds = _prep( + pipe=wan_t2v, + prompts=[args.prompt], + negative_prompts=neg, + precision=args.diag_precision, + device=wan_t2v.device, + offload_strategy="set_none", + print_memory=True, + ) + # Assert encoder is fully off GPU + assert_text_encoder_off_gpu(wan_t2v) + # Run generation under memory tracking + from torch import inference_mode + + with track_cuda_memory( + "generation", show_processes=args.diag_show_processes + ): + with inference_mode(): + video = wan_t2v.generate( + args.prompt, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + n_prompt=neg_prompt, + seed=args.base_seed, + offload_model=args.offload_model, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + precision=args.diag_precision, + ) + else: + video = wan_t2v.generate( + args.prompt, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + n_prompt=neg_prompt, + seed=args.base_seed, + offload_model=args.offload_model, + ) elif "ti2v" in args.task: logging.info("Creating WanTI2V pipeline.") wan_ti2v = wan.WanTI2V( @@ -341,21 +614,106 @@ def generate(args): ) logging.info(f"Generating video ...") - video = wan_ti2v.generate( - args.prompt, - img=img, - size=SIZE_CONFIGS[args.size], + if args.diag_memory: + from wan.utils.memory_diag import ( + assert_text_encoder_off_gpu, + track_cuda_memory, + ) + from wan.utils.text_conditioning import prepare_text_conditioning as _prep + + neg_text = ( + args.negative_prompt + if args.negative_prompt is not None + else wan_ti2v.sample_neg_prompt + ) + neg = [neg_text] if neg_text else None + prompt_embeds, negative_prompt_embeds = _prep( + pipe=wan_ti2v, + prompts=[args.prompt], + negative_prompts=neg, + precision=args.diag_precision, + device=wan_ti2v.device, + offload_strategy="set_none", + print_memory=True, + ) + assert_text_encoder_off_gpu(wan_ti2v) + from torch import inference_mode + + with track_cuda_memory( + "generation", show_processes=args.diag_show_processes + ): + with inference_mode(): + video = wan_ti2v.generate( + args.prompt, + img=img, + size=SIZE_CONFIGS[args.size], + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + n_prompt=neg_prompt, + seed=args.base_seed, + offload_model=args.offload_model, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + precision=args.diag_precision, + ) + else: + video = wan_ti2v.generate( + args.prompt, + img=img, + size=SIZE_CONFIGS[args.size], + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + n_prompt=neg_prompt, + seed=args.base_seed, + offload_model=args.offload_model, + ) + elif "s2v" in args.task: + logging.info("Creating WanS2V pipeline.") + wan_s2v = wan.WanS2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_sp=(args.ulysses_size > 1), + t5_cpu=args.t5_cpu, + convert_model_dtype=args.convert_model_dtype, + ) + + logging.info("Generating video ...") + video = wan_s2v.generate( + input_prompt=args.prompt, + ref_image_path=args.image, + audio_path=args.audio, + enable_tts=args.enable_tts, + tts_prompt_audio=args.tts_prompt_audio, + tts_prompt_text=args.tts_prompt_text, + tts_text=args.tts_text, + num_repeat=args.num_repeat if args.num_repeat else 1, + pose_video=args.pose_video, max_area=MAX_AREA_CONFIGS[args.size], - frame_num=args.frame_num, + infer_frames=args.frame_num, shift=args.sample_shift, sample_solver=args.sample_solver, sampling_steps=args.sample_steps, guide_scale=args.sample_guide_scale, + n_prompt=neg_prompt, seed=args.base_seed, - offload_model=args.offload_model) - else: - logging.info("Creating WanI2V pipeline.") - wan_i2v = wan.WanI2V( + offload_model=args.offload_model, + init_first_frame=args.init_first_frame, + ) + elif "animate" in args.task: + logging.info("Creating WanAnimate pipeline.") + wan_animate = wan.WanAnimate( config=cfg, checkpoint_dir=args.ckpt_dir, device_id=device, @@ -368,25 +726,106 @@ def generate(args): ) logging.info("Generating video ...") - video = wan_i2v.generate( - args.prompt, - img, - max_area=MAX_AREA_CONFIGS[args.size], - frame_num=args.frame_num, + clip_len = args.animate_clip_len or args.frame_num + refer_frames = args.animate_refer_frames or 1 + video = wan_animate.generate( + src_root_path=args.animate_src_root, + replace_flag=args.animate_replace, + clip_len=clip_len, + refert_num=refer_frames, shift=args.sample_shift, sample_solver=args.sample_solver, sampling_steps=args.sample_steps, guide_scale=args.sample_guide_scale, + input_prompt=args.prompt, + n_prompt=neg_prompt, seed=args.base_seed, - offload_model=args.offload_model) + offload_model=args.offload_model, + ) + else: + logging.info("Creating WanI2V pipeline.") + wan_i2v = wan.WanI2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_sp=(args.ulysses_size > 1), + t5_cpu=args.t5_cpu, + convert_model_dtype=args.convert_model_dtype, + ) + + logging.info("Generating video ...") + if args.diag_memory: + from wan.utils.memory_diag import ( + assert_text_encoder_off_gpu, + track_cuda_memory, + ) + from wan.utils.text_conditioning import prepare_text_conditioning as _prep + + neg_text = ( + args.negative_prompt + if args.negative_prompt is not None + else wan_i2v.sample_neg_prompt + ) + neg = [neg_text] if neg_text else None + prompt_embeds, negative_prompt_embeds = _prep( + pipe=wan_i2v, + prompts=[args.prompt], + negative_prompts=neg, + precision=args.diag_precision, + device=wan_i2v.device, + offload_strategy="set_none", + print_memory=True, + ) + assert_text_encoder_off_gpu(wan_i2v) + from torch import inference_mode + + with track_cuda_memory( + "generation", show_processes=args.diag_show_processes + ): + with inference_mode(): + video = wan_i2v.generate( + args.prompt, + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + n_prompt=neg_prompt, + seed=args.base_seed, + offload_model=args.offload_model, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + precision=args.diag_precision, + ) + else: + video = wan_i2v.generate( + args.prompt, + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + n_prompt=neg_prompt, + seed=args.base_seed, + offload_model=args.offload_model, + ) if rank == 0: if args.save_file is None: formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") - formatted_prompt = args.prompt.replace(" ", "_").replace("/", - "_")[:50] - suffix = '.mp4' - args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix + formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50] + suffix = ".mp4" + args.save_file = ( + f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + + suffix + ) logging.info(f"Saving generated video to {args.save_file}") save_video( @@ -395,7 +834,8 @@ def generate(args): fps=cfg.sample_fps, nrow=1, normalize=True, - value_range=(-1, 1)) + value_range=(-1, 1), + ) del video torch.cuda.synchronize() diff --git a/requirements.txt b/requirements.txt index 77c1e6d5..266166fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ torch>=2.4.0 torchvision>=0.19.0 +torchaudio opencv-python>=4.9.0.80 diffusers>=0.31.0 -transformers>=4.49.0 +transformers>=4.49.0,<=4.51.3 tokenizers>=0.20.3 accelerate>=1.1.1 tqdm @@ -13,3 +14,15 @@ dashscope imageio-ffmpeg flash_attn numpy>=1.23.5,<2 +einops>=0.7.0 +decord>=0.6.0 +peft>=0.12.0 +safetensors>=0.4.2 +matplotlib +onnxruntime +hydra-core +omegaconf +loguru +moviepy +pandas +sentencepiece diff --git a/requirements_animate.txt b/requirements_animate.txt new file mode 100644 index 00000000..7aa2d669 --- /dev/null +++ b/requirements_animate.txt @@ -0,0 +1,2 @@ +# Optional dependencies for Wan Animate preprocessing utilities +-e git+https://github.com/facebookresearch/sam2.git@0e78a118995e66bb27d78518c4bd9a3e95b4e266#egg=SAM-2 diff --git a/requirements_s2v.txt b/requirements_s2v.txt new file mode 100644 index 00000000..4556849a --- /dev/null +++ b/requirements_s2v.txt @@ -0,0 +1,12 @@ +# Optional dependencies for Wan Speech-to-Video pipelines (audio / TTS) +HyperPyYAML +inflect +librosa +lightning +rich +gdown +wget +pyarrow +pyworld +modelscope +GitPython diff --git a/wan/__init__.py b/wan/__init__.py index 0861d669..80539ed1 100644 --- a/wan/__init__.py +++ b/wan/__init__.py @@ -1,5 +1,18 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. from . import configs, distributed, modules +from .animate import WanAnimate from .image2video import WanI2V +from .speech2video import WanS2V from .text2video import WanT2V from .textimage2video import WanTI2V + +__all__ = [ + "WanT2V", + "WanI2V", + "WanTI2V", + "WanS2V", + "WanAnimate", + "configs", + "distributed", + "modules", +] diff --git a/wan/animate.py b/wan/animate.py new file mode 100644 index 00000000..c173ee85 --- /dev/null +++ b/wan/animate.py @@ -0,0 +1,733 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import math +import os +import types +from copy import deepcopy +from functools import partial + +import cv2 +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from decord import VideoReader +from einops import rearrange +from peft import set_peft_model_state_dict +from tqdm import tqdm + +from .distributed.fsdp import shard_model +from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward +from .distributed.util import get_world_size +from .modules.animate import CLIPModel, WanAnimateModel +from .modules.animate.animate_utils import TensorList, get_loraconfig +from .modules.t5 import T5EncoderModel +from .modules.vae2_1 import Wan2_1_VAE +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + + +class WanAnimate: + + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_sp=False, + t5_cpu=False, + init_on_cpu=True, + convert_model_dtype=False, + use_relighting_lora=False, + ): + r""" + Initializes the generation model components. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for DiT model + use_sp (`bool`, *optional*, defaults to False): + Enable distribution strategy of sequence parallel. + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + init_on_cpu (`bool`, *optional*, defaults to True): + Enable initializing Transformer Model on CPU. Only works without FSDP or USP. + convert_model_dtype (`bool`, *optional*, defaults to False): + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. + use_relighting_lora (`bool`, *optional*, defaults to False): + Whether to use relighting lora for character replacement. + """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.init_on_cpu = init_on_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + if t5_fsdp or dit_fsdp or use_sp: + self.init_on_cpu = False + + shard_fn = partial(shard_model, device_id=device_id) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device("cpu"), + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_fsdp else None, + ) + + self.clip = CLIPModel( + dtype=torch.float16, + device=self.device, + checkpoint_path=os.path.join(checkpoint_dir, config.clip_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer), + ) + + self.vae = Wan2_1_VAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device, + ) + + logging.info(f"Creating WanAnimate from {checkpoint_dir}") + + if not dit_fsdp: + self.noise_model = WanAnimateModel.from_pretrained( + checkpoint_dir, torch_dtype=self.param_dtype, device_map=self.device + ) + else: + self.noise_model = WanAnimateModel.from_pretrained( + checkpoint_dir, torch_dtype=self.param_dtype + ) + + self.noise_model = self._configure_model( + model=self.noise_model, + use_sp=use_sp, + dit_fsdp=dit_fsdp, + shard_fn=shard_fn, + convert_model_dtype=convert_model_dtype, + use_lora=use_relighting_lora, + checkpoint_dir=checkpoint_dir, + config=config, + ) + + if use_sp: + self.sp_size = get_world_size() + else: + self.sp_size = 1 + + self.sample_neg_prompt = config.sample_neg_prompt + self.sample_prompt = config.prompt + + def _configure_model( + self, + model, + use_sp, + dit_fsdp, + shard_fn, + convert_model_dtype, + use_lora, + checkpoint_dir, + config, + ): + """ + Configures a model object. This includes setting evaluation modes, + applying distributed parallel strategy, and handling device placement. + + Args: + model (torch.nn.Module): + The model instance to configure. + use_sp (`bool`): + Enable distribution strategy of sequence parallel. + dit_fsdp (`bool`): + Enable FSDP sharding for DiT model. + shard_fn (callable): + The function to apply FSDP sharding. + convert_model_dtype (`bool`): + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. + + Returns: + torch.nn.Module: + The configured model. + """ + model.eval().requires_grad_(False) + + if use_sp: + for block in model.blocks: + block.self_attn.forward = types.MethodType( + sp_attn_forward, block.self_attn + ) + + model.use_context_parallel = True + + if dist.is_initialized(): + dist.barrier() + + if use_lora: + logging.info("Loading Relighting Lora. ") + lora_config = get_loraconfig(transformer=model, rank=128, alpha=128) + model.add_adapter(lora_config) + lora_path = os.path.join(checkpoint_dir, config.lora_checkpoint) + peft_state_dict = torch.load(lora_path)["state_dict"] + set_peft_model_state_dict(model, peft_state_dict) + + if dit_fsdp: + model = shard_fn(model, use_lora=use_lora) + else: + if convert_model_dtype: + model.to(self.param_dtype) + if not self.init_on_cpu: + model.to(self.device) + + return model + + def inputs_padding(self, array, target_len): + idx = 0 + flip = False + target_array = [] + while len(target_array) < target_len: + target_array.append(deepcopy(array[idx])) + if flip: + idx -= 1 + else: + idx += 1 + if idx == 0 or idx == len(array) - 1: + flip = not flip + return target_array[:target_len] + + def get_valid_len(self, real_len, clip_len=81, overlap=1): + real_clip_len = clip_len - overlap + last_clip_num = (real_len - overlap) % real_clip_len + if last_clip_num == 0: + extra = 0 + else: + extra = real_clip_len - last_clip_num + target_len = real_len + extra + return target_len + + def get_i2v_mask( + self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda" + ): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat( + [torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1 + ) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + return msk + + def padding_resize( + self, + img_ori, + height=512, + width=512, + padding_color=(0, 0, 0), + interpolation=cv2.INTER_LINEAR, + ): + ori_height = img_ori.shape[0] + ori_width = img_ori.shape[1] + channel = img_ori.shape[2] + + img_pad = np.zeros((height, width, channel)) + if channel == 1: + img_pad[:, :, 0] = padding_color[0] + else: + img_pad[:, :, 0] = padding_color[0] + img_pad[:, :, 1] = padding_color[1] + img_pad[:, :, 2] = padding_color[2] + + if (ori_height / ori_width) > (height / width): + new_width = int(height / ori_height * ori_width) + img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation) + padding = int((width - new_width) / 2) + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + img_pad[:, padding : padding + new_width, :] = img + else: + new_height = int(width / ori_width * ori_height) + img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation) + padding = int((height - new_height) / 2) + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + img_pad[padding : padding + new_height, :, :] = img + + img_pad = np.uint8(img_pad) + + return img_pad + + def prepare_source(self, src_pose_path, src_face_path, src_ref_path): + pose_video_reader = VideoReader(src_pose_path) + pose_len = len(pose_video_reader) + pose_idxs = list(range(pose_len)) + cond_images = pose_video_reader.get_batch(pose_idxs).asnumpy() + + face_video_reader = VideoReader(src_face_path) + face_len = len(face_video_reader) + face_idxs = list(range(face_len)) + face_images = face_video_reader.get_batch(face_idxs).asnumpy() + height, width = cond_images[0].shape[:2] + refer_images = cv2.imread(src_ref_path)[..., ::-1] + refer_images = self.padding_resize(refer_images, height=height, width=width) + return cond_images, face_images, refer_images + + def prepare_source_for_replace(self, src_bg_path, src_mask_path): + bg_video_reader = VideoReader(src_bg_path) + bg_len = len(bg_video_reader) + bg_idxs = list(range(bg_len)) + bg_images = bg_video_reader.get_batch(bg_idxs).asnumpy() + + mask_video_reader = VideoReader(src_mask_path) + mask_len = len(mask_video_reader) + mask_idxs = list(range(mask_len)) + mask_images = mask_video_reader.get_batch(mask_idxs).asnumpy() + mask_images = mask_images[:, :, :, 0] / 255 + return bg_images, mask_images + + def generate( + self, + src_root_path, + replace_flag=False, + clip_len=77, + refert_num=1, + shift=5.0, + sample_solver="dpm++", + sampling_steps=20, + guide_scale=1, + input_prompt="", + n_prompt="", + seed=-1, + offload_model=True, + ): + r""" + Generates video frames from input image using diffusion process. + + Args: + src_root_path ('str'): + Process output path + replace_flag (`bool`, *optional*, defaults to False): + Whether to use character replace. + clip_len (`int`, *optional*, defaults to 77): + How many frames to generate per clips. The number should be 4n+1 + refert_num (`int`, *optional*, defaults to 1): + How many frames used for temporal guidance. Recommended to be 1 or 5. + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. + sample_solver (`str`, *optional*, defaults to 'dpm++'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 20): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float` or tuple[`float`], *optional*, defaults 1.0): + Classifier-free guidance scale. We only use it for expression control. + In most cases, it's not necessary and faster generation can be achieved without it. + When expression adjustments are needed, you may consider using this feature. + input_prompt (`str`): + Text prompt for content generation. We don't recommend custom prompts (although they work) + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N, H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames + - H: Frame height + - W: Frame width + """ + assert refert_num == 1 or refert_num == 5, "refert_num should be 1 or 5." + + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + + if input_prompt == "": + input_prompt = self.sample_prompt + + src_pose_path = os.path.join(src_root_path, "src_pose.mp4") + src_face_path = os.path.join(src_root_path, "src_face.mp4") + src_ref_path = os.path.join(src_root_path, "src_ref.png") + + cond_images, face_images, refer_images = self.prepare_source( + src_pose_path=src_pose_path, + src_face_path=src_face_path, + src_ref_path=src_ref_path, + ) + + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device("cpu")) + context_null = self.text_encoder([n_prompt], torch.device("cpu")) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + real_frame_len = len(cond_images) + target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num) + logging.info( + "real frames: {} target frames: {}".format(real_frame_len, target_len) + ) + cond_images = self.inputs_padding(cond_images, target_len) + face_images = self.inputs_padding(face_images, target_len) + + if replace_flag: + src_bg_path = os.path.join(src_root_path, "src_bg.mp4") + src_mask_path = os.path.join(src_root_path, "src_mask.mp4") + bg_images, mask_images = self.prepare_source_for_replace( + src_bg_path, src_mask_path + ) + bg_images = self.inputs_padding(bg_images, target_len) + mask_images = self.inputs_padding(mask_images, target_len) + + height, width = refer_images.shape[:2] + start = 0 + end = clip_len + all_out_frames = [] + while True: + if start + refert_num >= len(cond_images): + break + + if start == 0: + mask_reft_len = 0 + else: + mask_reft_len = refert_num + + batch = { + "conditioning_pixel_values": torch.zeros(1, 3, clip_len, height, width), + "bg_pixel_values": torch.zeros(1, 3, clip_len, height, width), + "mask_pixel_values": torch.zeros(1, 1, clip_len, height, width), + "face_pixel_values": torch.zeros(1, 3, clip_len, 512, 512), + "refer_pixel_values": torch.zeros(1, 3, height, width), + "refer_t_pixel_values": torch.zeros(refert_num, 3, height, width), + } + + batch["conditioning_pixel_values"] = rearrange( + torch.tensor(np.stack(cond_images[start:end]) / 127.5 - 1), + "t h w c -> 1 c t h w", + ) + batch["face_pixel_values"] = rearrange( + torch.tensor(np.stack(face_images[start:end]) / 127.5 - 1), + "t h w c -> 1 c t h w", + ) + + batch["refer_pixel_values"] = rearrange( + torch.tensor(refer_images / 127.5 - 1), "h w c -> 1 c h w" + ) + + if start > 0: + batch["refer_t_pixel_values"] = rearrange( + out_frames[0, :, -refert_num:].clone().detach(), + "c t h w -> t c h w", + ) + + batch["refer_t_pixel_values"] = rearrange( + batch["refer_t_pixel_values"], + "t c h w -> 1 c t h w", + ) + + if replace_flag: + batch["bg_pixel_values"] = rearrange( + torch.tensor(np.stack(bg_images[start:end]) / 127.5 - 1), + "t h w c -> 1 c t h w", + ) + + batch["mask_pixel_values"] = rearrange( + torch.tensor(np.stack(mask_images[start:end])[:, :, :, None]), + "t h w c -> 1 t c h w", + ) + + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(device=self.device, dtype=torch.bfloat16) + + ref_pixel_values = batch["refer_pixel_values"] + refer_t_pixel_values = batch["refer_t_pixel_values"] + conditioning_pixel_values = batch["conditioning_pixel_values"] + face_pixel_values = batch["face_pixel_values"] + + B, _, H, W = ref_pixel_values.shape + T = clip_len + lat_h = H // 8 + lat_w = W // 8 + lat_t = T // 4 + 1 + target_shape = [lat_t + 1, lat_h, lat_w] + noise = [ + torch.randn( + 16, + target_shape[0], + target_shape[1], + target_shape[2], + dtype=torch.float32, + device=self.device, + generator=seed_g, + ) + ] + + max_seq_len = ( + int(math.ceil(np.prod(target_shape) // 4 / self.sp_size)) * self.sp_size + ) + if max_seq_len % self.sp_size != 0: + raise ValueError( + f"max_seq_len {max_seq_len} is not divisible by sp_size {self.sp_size}" + ) + + with ( + torch.autocast( + device_type=str(self.device), dtype=torch.bfloat16, enabled=True + ), + torch.no_grad(), + ): + if sample_solver == "unipc": + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False, + ) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift + ) + timesteps = sample_scheduler.timesteps + elif sample_solver == "dpm++": + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False, + ) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, device=self.device, sigmas=sampling_sigmas + ) + else: + raise NotImplementedError("Unsupported solver.") + + latents = noise + + pose_latents_no_ref = self.vae.encode( + conditioning_pixel_values.to(torch.bfloat16) + ) + pose_latents_no_ref = torch.stack(pose_latents_no_ref) + pose_latents = torch.cat([pose_latents_no_ref], dim=2) + + ref_pixel_values = rearrange(ref_pixel_values, "t c h w -> 1 c t h w") + ref_latents = self.vae.encode(ref_pixel_values.to(torch.bfloat16)) + ref_latents = torch.stack(ref_latents) + + mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=self.device) + y_ref = torch.concat([mask_ref, ref_latents[0]]).to( + dtype=torch.bfloat16, device=self.device + ) + + img = ref_pixel_values[0, :, 0] + clip_context = self.clip.visual([img[:, None, :, :]]).to( + dtype=torch.bfloat16, device=self.device + ) + + if mask_reft_len > 0: + if replace_flag: + bg_pixel_values = batch["bg_pixel_values"] + y_reft = self.vae.encode( + [ + torch.concat( + [ + refer_t_pixel_values[0, :, :mask_reft_len], + bg_pixel_values[0, :, mask_reft_len:], + ], + dim=1, + ).to(self.device) + ] + )[0] + mask_pixel_values = 1 - batch["mask_pixel_values"] + mask_pixel_values = rearrange( + mask_pixel_values, "b t c h w -> (b t) c h w" + ) + mask_pixel_values = F.interpolate( + mask_pixel_values, size=(H // 8, W // 8), mode="nearest" + ) + mask_pixel_values = rearrange( + mask_pixel_values, "(b t) c h w -> b t c h w", b=1 + )[:, :, 0] + msk_reft = self.get_i2v_mask( + lat_t, + lat_h, + lat_w, + mask_reft_len, + mask_pixel_values=mask_pixel_values, + device=self.device, + ) + else: + y_reft = self.vae.encode( + [ + torch.concat( + [ + torch.nn.functional.interpolate( + refer_t_pixel_values[ + 0, :, :mask_reft_len + ].cpu(), + size=(H, W), + mode="bicubic", + ), + torch.zeros(3, T - mask_reft_len, H, W), + ], + dim=1, + ).to(self.device) + ] + )[0] + msk_reft = self.get_i2v_mask( + lat_t, lat_h, lat_w, mask_reft_len, device=self.device + ) + else: + if replace_flag: + bg_pixel_values = batch["bg_pixel_values"] + mask_pixel_values = 1 - batch["mask_pixel_values"] + mask_pixel_values = rearrange( + mask_pixel_values, "b t c h w -> (b t) c h w" + ) + mask_pixel_values = F.interpolate( + mask_pixel_values, size=(H // 8, W // 8), mode="nearest" + ) + mask_pixel_values = rearrange( + mask_pixel_values, "(b t) c h w -> b t c h w", b=1 + )[:, :, 0] + y_reft = self.vae.encode( + [ + torch.concat( + [ + bg_pixel_values[0], + ], + dim=1, + ).to(self.device) + ] + )[0] + msk_reft = self.get_i2v_mask( + lat_t, + lat_h, + lat_w, + mask_reft_len, + mask_pixel_values=mask_pixel_values, + device=self.device, + ) + else: + y_reft = self.vae.encode( + [ + torch.concat( + [ + torch.zeros(3, T - mask_reft_len, H, W), + ], + dim=1, + ).to(self.device) + ] + )[0] + msk_reft = self.get_i2v_mask( + lat_t, lat_h, lat_w, mask_reft_len, device=self.device + ) + + y_reft = torch.concat([msk_reft, y_reft]).to( + dtype=torch.bfloat16, device=self.device + ) + y = torch.concat([y_ref, y_reft], dim=1) + + arg_c = { + "context": context, + "seq_len": max_seq_len, + "clip_fea": clip_context.to( + dtype=torch.bfloat16, device=self.device + ), + "y": [y], + "pose_latents": pose_latents, + "face_pixel_values": face_pixel_values, + } + + if guide_scale > 1: + face_pixel_values_uncond = face_pixel_values * 0 - 1 + arg_null = { + "context": context_null, + "seq_len": max_seq_len, + "clip_fea": clip_context.to( + dtype=torch.bfloat16, device=self.device + ), + "y": [y], + "pose_latents": pose_latents, + "face_pixel_values": face_pixel_values_uncond, + } + + for i, t in enumerate(tqdm(timesteps)): + latent_model_input = latents + timestep = [t] + + timestep = torch.stack(timestep) + + noise_pred_cond = TensorList( + self.noise_model( + TensorList(latent_model_input), t=timestep, **arg_c + ) + ) + + if guide_scale > 1: + noise_pred_uncond = TensorList( + self.noise_model( + TensorList(latent_model_input), t=timestep, **arg_null + ) + ) + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond + ) + else: + noise_pred = noise_pred_cond + + temp_x0 = sample_scheduler.step( + noise_pred[0].unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g, + )[0] + latents[0] = temp_x0.squeeze(0) + + x0 = latents + + x0 = [x.to(dtype=torch.float32) for x in x0] + out_frames = torch.stack(self.vae.decode([x0[0][:, 1:]])) + + if start != 0: + out_frames = out_frames[:, :, refert_num:] + + all_out_frames.append(out_frames.cpu()) + + start += clip_len - refert_num + end += clip_len - refert_num + + videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len] + return videos[0] if self.rank == 0 else None diff --git a/wan/configs/__init__.py b/wan/configs/__init__.py index 875afe7e..923d17c1 100644 --- a/wan/configs/__init__.py +++ b/wan/configs/__init__.py @@ -2,38 +2,44 @@ import copy import os -os.environ['TOKENIZERS_PARALLELISM'] = 'false' +os.environ["TOKENIZERS_PARALLELISM"] = "false" +from .wan_animate_14B import animate_14B from .wan_i2v_A14B import i2v_A14B +from .wan_s2v_14B import s2v_14B from .wan_t2v_A14B import t2v_A14B from .wan_ti2v_5B import ti2v_5B WAN_CONFIGS = { - 't2v-A14B': t2v_A14B, - 'i2v-A14B': i2v_A14B, - 'ti2v-5B': ti2v_5B, + "t2v-A14B": t2v_A14B, + "i2v-A14B": i2v_A14B, + "ti2v-5B": ti2v_5B, + "s2v-14B": s2v_14B, + "animate-14B": animate_14B, } SIZE_CONFIGS = { - '720*1280': (720, 1280), - '1280*720': (1280, 720), - '480*832': (480, 832), - '832*480': (832, 480), - '704*1280': (704, 1280), - '1280*704': (1280, 704) + "720*1280": (720, 1280), + "1280*720": (1280, 720), + "480*832": (480, 832), + "832*480": (832, 480), + "704*1280": (704, 1280), + "1280*704": (1280, 704), } MAX_AREA_CONFIGS = { - '720*1280': 720 * 1280, - '1280*720': 1280 * 720, - '480*832': 480 * 832, - '832*480': 832 * 480, - '704*1280': 704 * 1280, - '1280*704': 1280 * 704, + "720*1280": 720 * 1280, + "1280*720": 1280 * 720, + "480*832": 480 * 832, + "832*480": 832 * 480, + "704*1280": 704 * 1280, + "1280*704": 1280 * 704, } SUPPORTED_SIZES = { - 't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'), - 'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'), - 'ti2v-5B': ('704*1280', '1280*704'), + "t2v-A14B": ("720*1280", "1280*720", "480*832", "832*480"), + "i2v-A14B": ("720*1280", "1280*720", "480*832", "832*480"), + "ti2v-5B": ("704*1280", "1280*704", "832*480"), + "s2v-14B": ("720*1280", "1280*720", "480*832", "832*480"), + "animate-14B": ("720*1280", "1280*720", "480*832", "832*480"), } diff --git a/wan/configs/wan_animate_14B.py b/wan/configs/wan_animate_14B.py new file mode 100644 index 00000000..c935f9f2 --- /dev/null +++ b/wan/configs/wan_animate_14B.py @@ -0,0 +1,40 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +# ------------------------ Wan animate 14B ------------------------# +animate_14B = EasyDict(__name__="Config: Wan animate 14B") +animate_14B.update(wan_shared_cfg) + +animate_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" +animate_14B.t5_tokenizer = "google/umt5-xxl" + +animate_14B.clip_checkpoint = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" +animate_14B.clip_tokenizer = "xlm-roberta-large" +animate_14B.lora_checkpoint = "relighting_lora.ckpt" +# vae +animate_14B.vae_checkpoint = "Wan2.1_VAE.pth" +animate_14B.vae_stride = (4, 8, 8) + +# transformer +animate_14B.patch_size = (1, 2, 2) +animate_14B.dim = 5120 +animate_14B.ffn_dim = 13824 +animate_14B.freq_dim = 256 +animate_14B.num_heads = 40 +animate_14B.num_layers = 40 +animate_14B.window_size = (-1, -1) +animate_14B.qk_norm = True +animate_14B.cross_attn_norm = True +animate_14B.eps = 1e-6 +animate_14B.use_face_encoder = True +animate_14B.motion_encoder_dim = 512 + +# inference +animate_14B.sample_shift = 5.0 +animate_14B.sample_steps = 20 +animate_14B.sample_guide_scale = 1.0 +animate_14B.frame_num = 77 +animate_14B.sample_fps = 30 +animate_14B.prompt = "视频中的人在做动作" diff --git a/wan/configs/wan_s2v_14B.py b/wan/configs/wan_s2v_14B.py new file mode 100644 index 00000000..832f9e9c --- /dev/null +++ b/wan/configs/wan_s2v_14B.py @@ -0,0 +1,58 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +# ------------------------ Wan S2V 14B ------------------------# + +s2v_14B = EasyDict(__name__="Config: Wan S2V 14B") +s2v_14B.update(wan_shared_cfg) + +# t5 +s2v_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" +s2v_14B.t5_tokenizer = "google/umt5-xxl" + +# vae +s2v_14B.vae_checkpoint = "Wan2.1_VAE.pth" +s2v_14B.vae_stride = (4, 8, 8) + +# wav2vec +s2v_14B.wav2vec = "wav2vec2-large-xlsr-53-english" + +s2v_14B.num_heads = 40 +# transformer +s2v_14B.transformer = EasyDict(__name__="Config: Transformer config for WanModel_S2V") +s2v_14B.transformer.patch_size = (1, 2, 2) +s2v_14B.transformer.dim = 5120 +s2v_14B.transformer.ffn_dim = 13824 +s2v_14B.transformer.freq_dim = 256 +s2v_14B.transformer.num_heads = 40 +s2v_14B.transformer.num_layers = 40 +s2v_14B.transformer.window_size = (-1, -1) +s2v_14B.transformer.qk_norm = True +s2v_14B.transformer.cross_attn_norm = True +s2v_14B.transformer.eps = 1e-6 +s2v_14B.transformer.enable_adain = True +s2v_14B.transformer.adain_mode = "attn_norm" +s2v_14B.transformer.audio_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39] +s2v_14B.transformer.zero_init = True +s2v_14B.transformer.zero_timestep = True +s2v_14B.transformer.enable_motioner = False +s2v_14B.transformer.add_last_motion = True +s2v_14B.transformer.trainable_token = False +s2v_14B.transformer.enable_tsm = False +s2v_14B.transformer.enable_framepack = True +s2v_14B.transformer.framepack_drop_mode = "padd" +s2v_14B.transformer.audio_dim = 1024 + +s2v_14B.transformer.motion_frames = 73 +s2v_14B.transformer.cond_dim = 16 + +# inference +s2v_14B.sample_neg_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +s2v_14B.drop_first_motion = True +s2v_14B.sample_shift = 3 +s2v_14B.sample_steps = 40 +s2v_14B.sample_guide_scale = 4.5 +s2v_14B.sample_fps = 30 +s2v_14B.frame_num = 80 diff --git a/wan/image2video.py b/wan/image2video.py index 659564c2..14ccd19d 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -88,36 +88,44 @@ def __init__( self.text_encoder = T5EncoderModel( text_len=config.text_len, dtype=config.t5_dtype, - device=torch.device('cpu'), + device=torch.device("cpu"), checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), shard_fn=shard_fn if t5_fsdp else None, ) + # Diffusers-style handle to tokenizer + self.tokenizer = self.text_encoder.tokenizer + self._text_encoder_offloaded = False self.vae_stride = config.vae_stride self.patch_size = config.patch_size self.vae = Wan2_1_VAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) + device=self.device, + ) logging.info(f"Creating WanModel from {checkpoint_dir}") self.low_noise_model = WanModel.from_pretrained( - checkpoint_dir, subfolder=config.low_noise_checkpoint) + checkpoint_dir, subfolder=config.low_noise_checkpoint + ) self.low_noise_model = self._configure_model( model=self.low_noise_model, use_sp=use_sp, dit_fsdp=dit_fsdp, shard_fn=shard_fn, - convert_model_dtype=convert_model_dtype) + convert_model_dtype=convert_model_dtype, + ) self.high_noise_model = WanModel.from_pretrained( - checkpoint_dir, subfolder=config.high_noise_checkpoint) + checkpoint_dir, subfolder=config.high_noise_checkpoint + ) self.high_noise_model = self._configure_model( model=self.high_noise_model, use_sp=use_sp, dit_fsdp=dit_fsdp, shard_fn=shard_fn, - convert_model_dtype=convert_model_dtype) + convert_model_dtype=convert_model_dtype, + ) if use_sp: self.sp_size = get_world_size() else: @@ -125,8 +133,20 @@ def __init__( self.sample_neg_prompt = config.sample_neg_prompt - def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, - convert_model_dtype): + # Optional diffusers-style API for memory optimization + def enable_model_cpu_offload(self): + self.init_on_cpu = True + + def enable_sequential_cpu_offload(self): + self.enable_model_cpu_offload() + + def enable_attention_slicing(self, *args, **kwargs): + logging.info("attention slicing not applicable; using custom attention.") + + def enable_xformers_memory_efficient_attention(self): + logging.info("xFormers toggle ignored; using built-in optimized attention.") + + def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, convert_model_dtype): """ Configures a model object. This includes setting evaluation modes, applying distributed parallel strategy, and handling device placement. @@ -153,7 +173,8 @@ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, if use_sp: for block in model.blocks: block.self_attn.forward = types.MethodType( - sp_attn_forward, block.self_attn) + sp_attn_forward, block.self_attn + ) model.forward = types.MethodType(sp_dit_forward, model) if dist.is_initialized(): @@ -187,34 +208,41 @@ def _prepare_model_for_timestep(self, t, boundary, offload_model): The active model on the target device for the current timestep. """ if t.item() >= boundary: - required_model_name = 'high_noise_model' - offload_model_name = 'low_noise_model' + required_model_name = "high_noise_model" + offload_model_name = "low_noise_model" else: - required_model_name = 'low_noise_model' - offload_model_name = 'high_noise_model' + required_model_name = "low_noise_model" + offload_model_name = "high_noise_model" if offload_model or self.init_on_cpu: - if next(getattr( - self, - offload_model_name).parameters()).device.type == 'cuda': - getattr(self, offload_model_name).to('cpu') - if next(getattr( - self, - required_model_name).parameters()).device.type == 'cpu': + if ( + next(getattr(self, offload_model_name).parameters()).device.type + == "cuda" + ): + getattr(self, offload_model_name).to("cpu") + if ( + next(getattr(self, required_model_name).parameters()).device.type + == "cpu" + ): getattr(self, required_model_name).to(self.device) return getattr(self, required_model_name) - def generate(self, - input_prompt, - img, - max_area=720 * 1280, - frame_num=81, - shift=5.0, - sample_solver='unipc', - sampling_steps=40, - guide_scale=5.0, - n_prompt="", - seed=-1, - offload_model=True): + def generate( + self, + input_prompt, + img, + max_area=720 * 1280, + frame_num=81, + shift=5.0, + sample_solver="unipc", + sampling_steps=40, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True, + prompt_embeds=None, + negative_prompt_embeds=None, + precision: str = "fp16", + ): r""" Generates video frames from input image and text prompt using diffusion process. @@ -254,24 +282,37 @@ def generate(self, - W: Frame width from max_area) """ # preprocess - guide_scale = (guide_scale, guide_scale) if isinstance( - guide_scale, float) else guide_scale + guide_scale = ( + (guide_scale, guide_scale) + if isinstance(guide_scale, float) + else guide_scale + ) img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) F = frame_num h, w = img.shape[1:] aspect_ratio = h / w lat_h = round( - np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // - self.patch_size[1] * self.patch_size[1]) + np.sqrt(max_area * aspect_ratio) + // self.vae_stride[1] + // self.patch_size[1] + * self.patch_size[1] + ) lat_w = round( - np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // - self.patch_size[2] * self.patch_size[2]) + np.sqrt(max_area / aspect_ratio) + // self.vae_stride[2] + // self.patch_size[2] + * self.patch_size[2] + ) h = lat_h * self.vae_stride[1] w = lat_w * self.vae_stride[2] - max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( - self.patch_size[1] * self.patch_size[2]) + max_seq_len = ( + ((F - 1) // self.vae_stride[0] + 1) + * lat_h + * lat_w + // (self.patch_size[1] * self.patch_size[2]) + ) max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size seed = seed if seed >= 0 else random.randint(0, sys.maxsize) @@ -284,80 +325,138 @@ def generate(self, lat_w, dtype=torch.float32, generator=seed_g, - device=self.device) + device=self.device, + ) msk = torch.ones(1, F, lat_h, lat_w, device=self.device) msk[:, 1:] = 0 - msk = torch.concat([ - torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] - ], - dim=1) + msk = torch.concat( + [torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1 + ) msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) msk = msk.transpose(1, 2)[0] if n_prompt == "": n_prompt = self.sample_neg_prompt - # preprocess - if not self.t5_cpu: - self.text_encoder.model.to(self.device) - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) - if offload_model: - self.text_encoder.model.cpu() + # Text conditioning + if prompt_embeds is not None or negative_prompt_embeds is not None: + if isinstance(guide_scale, tuple) and (negative_prompt_embeds is None): + raise ValueError( + "negative_prompt_embeds must be provided when using guidance." + ) + context = prompt_embeds + context_null = ( + negative_prompt_embeds if negative_prompt_embeds is not None else [] + ) + if self.text_encoder is not None: + logging.warning( + "prompt_embeds provided; preventing redundant text encoding." + ) + if hasattr(self.text_encoder, "model"): + self.text_encoder.model.cpu() + self._text_encoder_offloaded = True + target_dtype = torch.float16 if precision == "fp16" else torch.bfloat16 + context = [t.to(dtype=target_dtype, device=self.device) for t in context] + context_null = [ + t.to(dtype=target_dtype, device=self.device) for t in context_null + ] else: - context = self.text_encoder([input_prompt], torch.device('cpu')) - context_null = self.text_encoder([n_prompt], torch.device('cpu')) - context = [t.to(self.device) for t in context] - context_null = [t.to(self.device) for t in context_null] - - y = self.vae.encode([ - torch.concat([ - torch.nn.functional.interpolate( - img[None].cpu(), size=(h, w), mode='bicubic').transpose( - 0, 1), - torch.zeros(3, F - 1, h, w) - ], - dim=1).to(self.device) - ])[0] + if self.text_encoder is None: + raise RuntimeError( + "text_encoder is not available. Provide prompt_embeds to generate()." + ) + with torch.inference_mode(): + self.text_encoder.model.eval() + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + else: + context = self.text_encoder([input_prompt], torch.device("cpu")) + context_null = self.text_encoder([n_prompt], torch.device("cpu")) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + target_dtype = torch.float16 if precision == "fp16" else torch.bfloat16 + context = [ + t.to(dtype=target_dtype, device=self.device) for t in context + ] + context_null = [ + t.to(dtype=target_dtype, device=self.device) for t in context_null + ] + try: + if hasattr(torch.cuda, "memory_summary") and torch.cuda.is_available(): + logging.info( + "CUDA memory before T5 offload:\n" + + torch.cuda.memory_summary(device=self.device) + ) + except Exception: + pass + if hasattr(self.text_encoder, "model"): + self.text_encoder.model.cpu() + self._text_encoder_offloaded = True + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + try: + logging.info( + "CUDA memory after T5 offload:\n" + + torch.cuda.memory_summary(device=self.device) + ) + except Exception: + pass + + y = self.vae.encode( + [ + torch.concat( + [ + torch.nn.functional.interpolate( + img[None].cpu(), size=(h, w), mode="bicubic" + ).transpose(0, 1), + torch.zeros(3, F - 1, h, w), + ], + dim=1, + ).to(self.device) + ] + )[0] y = torch.concat([msk, y]) @contextmanager def noop_no_sync(): yield - no_sync_low_noise = getattr(self.low_noise_model, 'no_sync', - noop_no_sync) - no_sync_high_noise = getattr(self.high_noise_model, 'no_sync', - noop_no_sync) + no_sync_low_noise = getattr(self.low_noise_model, "no_sync", noop_no_sync) + no_sync_high_noise = getattr(self.high_noise_model, "no_sync", noop_no_sync) # evaluation mode with ( - torch.amp.autocast('cuda', dtype=self.param_dtype), - torch.no_grad(), - no_sync_low_noise(), - no_sync_high_noise(), + torch.amp.autocast("cuda", dtype=self.param_dtype), + torch.no_grad(), + no_sync_low_noise(), + no_sync_high_noise(), ): boundary = self.boundary * self.num_train_timesteps - if sample_solver == 'unipc': + if sample_solver == "unipc": sample_scheduler = FlowUniPCMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, - use_dynamic_shifting=False) + use_dynamic_shifting=False, + ) sample_scheduler.set_timesteps( - sampling_steps, device=self.device, shift=shift) + sampling_steps, device=self.device, shift=shift + ) timesteps = sample_scheduler.timesteps - elif sample_solver == 'dpm++': + elif sample_solver == "dpm++": sample_scheduler = FlowDPMSolverMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, - use_dynamic_shifting=False) + use_dynamic_shifting=False, + ) sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) timesteps, _ = retrieve_timesteps( - sample_scheduler, - device=self.device, - sigmas=sampling_sigmas) + sample_scheduler, device=self.device, sigmas=sampling_sigmas + ) else: raise NotImplementedError("Unsupported solver.") @@ -365,15 +464,15 @@ def noop_no_sync(): latent = noise arg_c = { - 'context': [context[0]], - 'seq_len': max_seq_len, - 'y': [y], + "context": [context[0]], + "seq_len": max_seq_len, + "y": [y], } arg_null = { - 'context': context_null, - 'seq_len': max_seq_len, - 'y': [y], + "context": context_null, + "seq_len": max_seq_len, + "y": [y], } if offload_model: @@ -385,28 +484,28 @@ def noop_no_sync(): timestep = torch.stack(timestep).to(self.device) - model = self._prepare_model_for_timestep( - t, boundary, offload_model) - sample_guide_scale = guide_scale[1] if t.item( - ) >= boundary else guide_scale[0] + model = self._prepare_model_for_timestep(t, boundary, offload_model) + sample_guide_scale = ( + guide_scale[1] if t.item() >= boundary else guide_scale[0] + ) - noise_pred_cond = model( - latent_model_input, t=timestep, **arg_c)[0] + noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0] if offload_model: torch.cuda.empty_cache() - noise_pred_uncond = model( - latent_model_input, t=timestep, **arg_null)[0] + noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0] if offload_model: torch.cuda.empty_cache() noise_pred = noise_pred_uncond + sample_guide_scale * ( - noise_pred_cond - noise_pred_uncond) + noise_pred_cond - noise_pred_uncond + ) temp_x0 = sample_scheduler.step( noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, - generator=seed_g)[0] + generator=seed_g, + )[0] latent = temp_x0.squeeze(0) x0 = [latent] diff --git a/wan/modules/animate/__init__.py b/wan/modules/animate/__init__.py new file mode 100644 index 00000000..af426277 --- /dev/null +++ b/wan/modules/animate/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from .model_animate import WanAnimateModel +from .clip import CLIPModel + +__all__ = ["WanAnimateModel", "CLIPModel"] diff --git a/wan/modules/animate/animate_utils.py b/wan/modules/animate/animate_utils.py new file mode 100644 index 00000000..e228483b --- /dev/null +++ b/wan/modules/animate/animate_utils.py @@ -0,0 +1,143 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import numbers +from peft import LoraConfig + + +def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"): + target_modules = [] + for name, module in transformer.named_modules(): + if ( + "blocks" in name + and "face" not in name + and "modulation" not in name + and isinstance(module, torch.nn.Linear) + ): + target_modules.append(name) + + transformer_lora_config = LoraConfig( + r=rank, + lora_alpha=alpha, + init_lora_weights=init_lora_weights, + target_modules=target_modules, + ) + return transformer_lora_config + + +class TensorList(object): + + def __init__(self, tensors): + """ + tensors: a list of torch.Tensor objects. No need to have uniform shape. + """ + assert isinstance(tensors, (list, tuple)) + assert all(isinstance(u, torch.Tensor) for u in tensors) + assert len(set([u.ndim for u in tensors])) == 1 + assert len(set([u.dtype for u in tensors])) == 1 + assert len(set([u.device for u in tensors])) == 1 + self.tensors = tensors + + def to(self, *args, **kwargs): + return TensorList([u.to(*args, **kwargs) for u in self.tensors]) + + def size(self, dim): + assert dim == 0, "only support get the 0th size" + return len(self.tensors) + + def pow(self, *args, **kwargs): + return TensorList([u.pow(*args, **kwargs) for u in self.tensors]) + + def squeeze(self, dim): + assert dim != 0 + if dim > 0: + dim -= 1 + return TensorList([u.squeeze(dim) for u in self.tensors]) + + def type(self, *args, **kwargs): + return TensorList([u.type(*args, **kwargs) for u in self.tensors]) + + def type_as(self, other): + assert isinstance(other, (torch.Tensor, TensorList)) + if isinstance(other, torch.Tensor): + return TensorList([u.type_as(other) for u in self.tensors]) + else: + return TensorList([u.type(other.dtype) for u in self.tensors]) + + @property + def dtype(self): + return self.tensors[0].dtype + + @property + def device(self): + return self.tensors[0].device + + @property + def ndim(self): + return 1 + self.tensors[0].ndim + + def __getitem__(self, index): + return self.tensors[index] + + def __len__(self): + return len(self.tensors) + + def __add__(self, other): + return self._apply(other, lambda u, v: u + v) + + def __radd__(self, other): + return self._apply(other, lambda u, v: v + u) + + def __sub__(self, other): + return self._apply(other, lambda u, v: u - v) + + def __rsub__(self, other): + return self._apply(other, lambda u, v: v - u) + + def __mul__(self, other): + return self._apply(other, lambda u, v: u * v) + + def __rmul__(self, other): + return self._apply(other, lambda u, v: v * u) + + def __floordiv__(self, other): + return self._apply(other, lambda u, v: u // v) + + def __truediv__(self, other): + return self._apply(other, lambda u, v: u / v) + + def __rfloordiv__(self, other): + return self._apply(other, lambda u, v: v // u) + + def __rtruediv__(self, other): + return self._apply(other, lambda u, v: v / u) + + def __pow__(self, other): + return self._apply(other, lambda u, v: u**v) + + def __rpow__(self, other): + return self._apply(other, lambda u, v: v**u) + + def __neg__(self): + return TensorList([-u for u in self.tensors]) + + def __iter__(self): + for tensor in self.tensors: + yield tensor + + def __repr__(self): + return "TensorList: \n" + repr(self.tensors) + + def _apply(self, other, op): + if isinstance(other, (list, tuple, TensorList)) or ( + isinstance(other, torch.Tensor) and (other.numel() > 1 or other.ndim > 1) + ): + assert len(other) == len(self.tensors) + return TensorList([op(u, v) for u, v in zip(self.tensors, other)]) + elif isinstance(other, numbers.Number) or ( + isinstance(other, torch.Tensor) and (other.numel() == 1 and other.ndim <= 1) + ): + return TensorList([op(u, other) for u in self.tensors]) + else: + raise TypeError( + f'unsupported operand for *: "TensorList" and "{type(other)}"' + ) diff --git a/wan/modules/animate/clip.py b/wan/modules/animate/clip.py new file mode 100644 index 00000000..09e908a8 --- /dev/null +++ b/wan/modules/animate/clip.py @@ -0,0 +1,586 @@ +# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T + +from ..attention import flash_attention +from ..tokenizers import HuggingfaceTokenizer +from .xlm_roberta import XLMRoberta + +__all__ = [ + "XLMRobertaCLIP", + "clip_xlm_roberta_vit_h_14", + "CLIPModel", +] + + +def pos_interpolate(pos, seq_len): + if pos.size(1) == seq_len: + return pos + else: + src_grid = int(math.sqrt(pos.size(1))) + tar_grid = int(math.sqrt(seq_len)) + n = pos.size(1) - src_grid * src_grid + return torch.cat( + [ + pos[:, :n], + F.interpolate( + pos[:, n:] + .float() + .reshape(1, src_grid, src_grid, -1) + .permute(0, 3, 1, 2), + size=(tar_grid, tar_grid), + mode="bicubic", + align_corners=False, + ) + .flatten(2) + .transpose(1, 2), + ], + dim=1, + ) + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__( + self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0 + ): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.causal = causal + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) + + # compute attention + p = self.attn_dropout if self.training else 0.0 + x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) + x = x.reshape(b, s, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + return x + + +class SwiGLU(nn.Module): + + def __init__(self, dim, mid_dim): + super().__init__() + self.dim = dim + self.mid_dim = mid_dim + + # layers + self.fc1 = nn.Linear(dim, mid_dim) + self.fc2 = nn.Linear(dim, mid_dim) + self.fc3 = nn.Linear(mid_dim, dim) + + def forward(self, x): + x = F.silu(self.fc1(x)) * self.fc2(x) + x = self.fc3(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__( + self, + dim, + mlp_ratio, + num_heads, + post_norm=False, + causal=False, + activation="quick_gelu", + attn_dropout=0.0, + proj_dropout=0.0, + norm_eps=1e-5, + ): + assert activation in ["quick_gelu", "gelu", "swi_glu"] + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.post_norm = post_norm + self.causal = causal + self.norm_eps = norm_eps + + # layers + self.norm1 = LayerNorm(dim, eps=norm_eps) + self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout) + self.norm2 = LayerNorm(dim, eps=norm_eps) + if activation == "swi_glu": + self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) + else: + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == "quick_gelu" else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), + nn.Dropout(proj_dropout), + ) + + def forward(self, x): + if self.post_norm: + x = x + self.norm1(self.attn(x)) + x = x + self.norm2(self.mlp(x)) + else: + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class AttentionPool(nn.Module): + + def __init__( + self, + dim, + mlp_ratio, + num_heads, + activation="gelu", + proj_dropout=0.0, + norm_eps=1e-5, + ): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.proj_dropout = proj_dropout + self.norm_eps = norm_eps + + # layers + gain = 1.0 / math.sqrt(dim) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.to_q = nn.Linear(dim, dim) + self.to_kv = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + self.norm = LayerNorm(dim, eps=norm_eps) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == "quick_gelu" else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), + nn.Dropout(proj_dropout), + ) + + def forward(self, x): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) + k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) + + # compute attention + x = flash_attention(q, k, v, version=2) + x = x.reshape(b, 1, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + + # mlp + x = x + self.mlp(self.norm(x)) + return x[:, 0] + + +class VisionTransformer(nn.Module): + + def __init__( + self, + image_size=224, + patch_size=16, + dim=768, + mlp_ratio=4, + out_dim=512, + num_heads=12, + num_layers=12, + pool_type="token", + pre_norm=True, + post_norm=False, + activation="quick_gelu", + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5, + ): + if image_size % patch_size != 0: + print("[WARNING] image_size is not divisible by patch_size", flush=True) + assert pool_type in ("token", "token_fc", "attn_pool") + out_dim = out_dim or dim + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size) ** 2 + self.dim = dim + self.mlp_ratio = mlp_ratio + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pool_type = pool_type + self.post_norm = post_norm + self.norm_eps = norm_eps + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm + ) + if pool_type in ("token", "token_fc"): + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter( + gain + * torch.randn( + 1, + self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), + dim, + ) + ) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + self.transformer = nn.Sequential( + *[ + AttentionBlock( + dim, + mlp_ratio, + num_heads, + post_norm, + False, + activation, + attn_dropout, + proj_dropout, + norm_eps, + ) + for _ in range(num_layers) + ] + ) + self.post_norm = LayerNorm(dim, eps=norm_eps) + + # head + if pool_type == "token": + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + elif pool_type == "token_fc": + self.head = nn.Linear(dim, out_dim) + elif pool_type == "attn_pool": + self.head = AttentionPool( + dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps + ) + + def forward(self, x, interpolation=False, use_31_block=False): + b = x.size(0) + + # embeddings + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) + if self.pool_type in ("token", "token_fc"): + x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1) + if interpolation: + e = pos_interpolate(self.pos_embedding, x.size(1)) + else: + e = self.pos_embedding + x = self.dropout(x + e) + if self.pre_norm is not None: + x = self.pre_norm(x) + + # transformer + if use_31_block: + x = self.transformer[:-1](x) + return x + else: + x = self.transformer(x) + return x + + +class XLMRobertaWithHead(XLMRoberta): + + def __init__(self, **kwargs): + self.out_dim = kwargs.pop("out_dim") + super().__init__(**kwargs) + + # head + mid_dim = (self.dim + self.out_dim) // 2 + self.head = nn.Sequential( + nn.Linear(self.dim, mid_dim, bias=False), + nn.GELU(), + nn.Linear(mid_dim, self.out_dim, bias=False), + ) + + def forward(self, ids): + # xlm-roberta + x = super().forward(ids) + + # average pooling + mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) + x = (x * mask).sum(dim=1) / mask.sum(dim=1) + + # head + x = self.head(x) + return x + + +class XLMRobertaCLIP(nn.Module): + + def __init__( + self, + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool="token", + vision_pre_norm=True, + vision_post_norm=False, + activation="gelu", + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5, + ): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + self.text_post_norm = text_post_norm + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps, + ) + self.textual = XLMRobertaWithHead( + vocab_size=vocab_size, + max_seq_len=max_text_len, + type_size=type_size, + pad_id=pad_id, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + post_norm=text_post_norm, + dropout=text_dropout, + ) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. + - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. + Encoded by data.CLIPTokenizer. + """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def param_groups(self): + groups = [ + { + "params": [ + p + for n, p in self.named_parameters() + if "norm" in n or n.endswith("bias") + ], + "weight_decay": 0.0, + }, + { + "params": [ + p + for n, p in self.named_parameters() + if not ("norm" in n or n.endswith("bias")) + ] + }, + ] + return groups + + +def _clip( + pretrained=False, + pretrained_name=None, + model_cls=XLMRobertaCLIP, + return_transforms=False, + return_tokenizer=False, + tokenizer_padding="eos", + dtype=torch.float32, + device="cpu", + **kwargs, +): + # init a model on device + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + output = (model,) + + # init transforms + if return_transforms: + # mean and std + if "siglip" in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = T.Compose( + [ + T.Resize( + (model.image_size, model.image_size), + interpolation=T.InterpolationMode.BICUBIC, + ), + T.ToTensor(), + T.Normalize(mean=mean, std=std), + ] + ) + output += (transforms,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14( + pretrained=False, + pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", + **kwargs, +): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool="token", + activation="gelu", + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + ) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) + + +class CLIPModel: + + def __init__(self, dtype, device, checkpoint_path, tokenizer_path): + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + self.model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, + return_transforms=True, + return_tokenizer=False, + dtype=dtype, + device=device, + ) + self.model = self.model.eval().requires_grad_(False) + logging.info(f"loading {checkpoint_path}") + self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + + # init tokenizer + self.tokenizer = HuggingfaceTokenizer( + name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace" + ) + + def visual(self, videos): + # preprocess + size = (self.model.image_size,) * 2 + videos = torch.cat( + [ + F.interpolate( + u.transpose(0, 1), size=size, mode="bicubic", align_corners=False + ) + for u in videos + ] + ) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) + + # forward + with torch.cuda.amp.autocast(dtype=self.dtype): + out = self.model.visual(videos, use_31_block=True) + return out diff --git a/wan/modules/animate/face_blocks.py b/wan/modules/animate/face_blocks.py new file mode 100644 index 00000000..76e75495 --- /dev/null +++ b/wan/modules/animate/face_blocks.py @@ -0,0 +1,419 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from torch import nn +import torch +from typing import Tuple, Optional +from einops import rearrange +import torch.nn.functional as F +import math +from ...distributed.util import gather_forward, get_rank, get_world_size + + +try: + from flash_attn import flash_attn_qkvpacked_func, flash_attn_func +except ImportError: + flash_attn_func = None + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="flash", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. + + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal + ) + + elif mode == "flash": + x = flash_attn_func( + q, + k, + v, + ) + x = x.view( + batch_size, max_seqlen_q, x.shape[-2], x.shape[-1] + ) # reshape x to [b, s, a, d] + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert ( + attn_mask is None + ), "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril( + diagonal=0 + ) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__( + self, + chan_in, + chan_out, + kernel_size=3, + stride=1, + dilation=1, + pad_mode="replicate", + **kwargs, + ): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d( + chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs + ) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class FaceEncoder(nn.Module): + def __init__( + self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm( + hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm( + 1024, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.norm2 = nn.LayerNorm( + 1024, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.norm3 = nn.LayerNorm( + 1024, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, + device=None, + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + + self.pre_norm_feat = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.pre_norm_motion = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) N H D") + v = rearrange(v, "B L N H D -> (B L) N H D") + + if use_context_parallel: + q = gather_forward(q, dim=1) + + q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp) + # Compute attention. + attn = attention( + q, + k, + v, + max_seqlen_q=q.shape[1], + batch_size=q.shape[0], + ) + + attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) + if use_context_parallel: + attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()] + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze( + -1 + ) + + return output diff --git a/wan/modules/animate/model_animate.py b/wan/modules/animate/model_animate.py new file mode 100644 index 00000000..f6078cf6 --- /dev/null +++ b/wan/modules/animate/model_animate.py @@ -0,0 +1,530 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +import types +from copy import deepcopy +from einops import rearrange +from typing import List +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.loaders import PeftAdapterMixin + +from ...distributed.sequence_parallel import ( + distributed_attention, + gather_forward, + get_rank, + get_world_size, +) + + +from ..model import ( + Head, + WanAttentionBlock, + WanLayerNorm, + WanRMSNorm, + WanModel, + WanSelfAttention, + flash_attention, + rope_params, + sinusoidal_embedding_1d, + rope_apply, +) + +from .face_blocks import FaceEncoder, FaceAdapter +from .motion_encoder import Generator + + +class HeadAnimate(Head): + + def forward(self, x, e): + """ + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, L1, C] + """ + assert e.dtype == torch.float32 + with amp.autocast(dtype=torch.float32): + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) + return x + + +class WanAnimateSelfAttention(WanSelfAttention): + + def forward(self, x, seq_lens, grid_sizes, freqs): + """ + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + x = flash_attention( + q=rope_apply(q, grid_sizes, freqs), + k=rope_apply(k, grid_sizes, freqs), + v=v, + k_lens=seq_lens, + window_size=self.window_size, + ) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanAnimateCrossAttention(WanSelfAttention): + def __init__( + self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6, + use_img_emb=True, + ): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + self.use_img_emb = use_img_emb + + if use_img_emb: + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context, context_lens): + """ + x: [B, L1, C]. + context: [B, L2, C]. + context_lens: [B]. + """ + if self.use_img_emb: + context_img = context[:, :257] + context = context[:, 257:] + else: + context = context + + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + if self.use_img_emb: + k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) + v_img = self.v_img(context_img).view(b, -1, n, d) + img_x = flash_attention(q, k_img, v_img, k_lens=None) + # compute attention + x = flash_attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + + if self.use_img_emb: + img_x = img_x.flatten(2) + x = x + img_x + + x = self.o(x) + return x + + +class WanAnimateAttentionBlock(nn.Module): + def __init__( + self, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + use_img_emb=True, + ): + + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanAnimateSelfAttention( + dim, num_heads, window_size, qk_norm, eps + ) + + self.norm3 = ( + WanLayerNorm(dim, eps, elementwise_affine=True) + if cross_attn_norm + else nn.Identity() + ) + + self.cross_attn = WanAnimateCrossAttention( + dim, num_heads, (-1, -1), qk_norm, eps, use_img_emb=use_img_emb + ) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), + nn.GELU(approximate="tanh"), + nn.Linear(ffn_dim, dim), + ) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + ): + """ + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, L1, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + with amp.autocast(dtype=torch.float32): + e = (self.modulation + e).chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn( + self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs + ) + with amp.autocast(dtype=torch.float32): + x = x + y * e[2] + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) + with amp.autocast(dtype=torch.float32): + x = x + y * e[5] + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class MLPProj(torch.nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), + torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), + torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim), + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _no_split_modules = ["WanAttentionBlock"] + + @register_to_config + def __init__( + self, + patch_size=(1, 2, 2), + text_len=512, + in_dim=36, + dim=5120, + ffn_dim=13824, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=40, + num_layers=40, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + motion_encoder_dim=512, + use_context_parallel=False, + use_img_emb=True, + ): + + super().__init__() + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + self.motion_encoder_dim = motion_encoder_dim + self.use_context_parallel = use_context_parallel + self.use_img_emb = use_img_emb + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size + ) + + self.pose_patch_embedding = nn.Conv3d( + 16, dim, kernel_size=patch_size, stride=patch_size + ) + + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim) + ) + + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + self.blocks = nn.ModuleList( + [ + WanAnimateAttentionBlock( + dim, + ffn_dim, + num_heads, + window_size, + qk_norm, + cross_attn_norm, + eps, + use_img_emb, + ) + for _ in range(num_layers) + ] + ) + + # head + self.head = HeadAnimate(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + dim=1, + ) + + self.img_emb = MLPProj(1280, dim) + + # initialize weights + self.init_weights() + + self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20) + self.face_adapter = FaceAdapter( + heads_num=self.num_heads, + hidden_dim=self.dim, + num_adapter_layers=self.num_layers // 5, + ) + + self.face_encoder = FaceEncoder( + in_dim=motion_encoder_dim, + hidden_dim=self.dim, + num_heads=4, + ) + + def after_patch_embedding( + self, x: List[torch.Tensor], pose_latents, face_pixel_values + ): + pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents] + for x_, pose_latents_ in zip(x, pose_latents): + x_[:, :, 1:] += pose_latents_ + + b, c, T, h, w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)): + face_pixel_values_tmp.append( + self.motion_encoder.get_motion( + face_pixel_values[i * encode_bs : (i + 1) * encode_bs] + ) + ) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + return x, motion_vec + + def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None): + if block_idx % 5 == 0: + adapter_args = [x, motion_vec, motion_masks, self.use_context_parallel] + residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args) + x = residual_out + x + return x + + def forward( + self, + x, + t, + clip_fea, + context, + seq_len, + y=None, + pose_latents=None, + face_pixel_values=None, + ): + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values) + + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x] + ) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat( + [ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) + for u in x + ] + ) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack( + [ + torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ] + ) + ) + + if self.use_img_emb: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + ) + + if self.use_context_parallel: + x = torch.chunk(x, get_world_size(), dim=1)[get_rank()] + + for idx, block in enumerate(self.blocks): + x = block(x, **kwargs) + x = self.after_transformer_block(idx, x, motion_vec) + + # head + x = self.head(x, e) + + if self.use_context_parallel: + x = gather_forward(x, dim=1) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[: math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum("fhwpqrc->cfphqwr", u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) diff --git a/wan/modules/animate/motion_encoder.py b/wan/modules/animate/motion_encoder.py new file mode 100644 index 00000000..a5b5cbd4 --- /dev/null +++ b/wan/modules/animate/motion_encoder.py @@ -0,0 +1,351 @@ +# Modified from ``https://github.com/wyhsirius/LIA`` +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import torch.nn as nn +from torch.nn import functional as F +import math + + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype == torch.bfloat16: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[ + :, + :, + max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0), + ] + + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor**2) + + self.register_buffer("kernel", kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size**2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," + f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" + ) + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16, + } + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def forward(self, x): + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + + weight = self.weight + 1e-8 + Q, R = custom_qr(weight) + if input is None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + # motion_feat = self.enc.enc_motion(img) + motion_feat = torch.utils.checkpoint.checkpoint( + (self.enc.enc_motion), img, use_reentrant=True + ) + with torch.cuda.amp.autocast(dtype=torch.float32): + motion = self.dec.direction(motion_feat) + return motion diff --git a/wan/modules/animate/preprocess/UserGuider.md b/wan/modules/animate/preprocess/UserGuider.md new file mode 100644 index 00000000..b40f7f3d --- /dev/null +++ b/wan/modules/animate/preprocess/UserGuider.md @@ -0,0 +1,70 @@ +# Wan-animate Preprocessing User Guider + +## 1. Introductions + + +Wan-animate offers two generation modes: `animation` and `replacement`. While both modes extract the skeleton from the reference video, they each have a distinct preprocessing pipeline. + +### 1.1 Animation Mode + +In this mode, it is highly recommended to enable pose retargeting, especially if the body proportions of the reference and driving characters are dissimilar. + + - A simplified version of pose retargeting pipeline is provided to help developers quickly implement this functionality. + + - **NOTE:** Due to the potential complexity of input data, the results from this simplified retargeting version are NOT guaranteed to be perfect. It is strongly advised to verify the preprocessing results before proceeding. + + - Community contributions to improve on this feature are welcome. + +### 1.2 Replacement Mode + + - Pose retargeting is DISABLED by default in this mode. This is a deliberate choice to account for potential spatial interactions between the character and the environment. + + - **WARNING**: If there is a significant mismatch in body proportions between the reference and driving characters, artifacts or deformations may appear in the final output. + + - A simplified version for extracting the character's mask is also provided. + - **WARNING:** This mask extraction process is designed for **single-person videos ONLY** and may produce incorrect results or fail in multi-person videos (incorrect pose tracking). For multi-person video, users are required to either develop their own solution or integrate a suitable open-source tool. + +--- + +## 2. Preprocessing Instructions and Recommendations + +### 2.1 Basic Usage + +- The preprocessing process requires some additional models, including pose detection (mandatory), and mask extraction and image editing models (optional, as needed). Place them according to the following directory structure: +``` + /path/to/your/ckpt_path/ + ├── det/ + │ └── yolov10m.onnx + ├── pose2d/ + │ └── vitpose_h_wholebody.onnx + ├── sam2/ + │ └── sam2_hiera_large.pt + └── FLUX.1-Kontext-dev/ +``` +- `video_path`, `refer_path`, and `save_path` correspond to the paths for the input driving video, the character image, and the preprocessed results. + +- When using `animation` mode, two videos, `src_face.mp4` and `src_pose.mp4`, will be generated in `save_path`. When using `replacement` mode, two additional videos, `src_bg.mp4` and `src_mask.mp4`, will also be generated. + +- The `resolution_area` parameter determines the resolution for both preprocessing and the generation model. Its size is determined by pixel area. + +- The `fps` parameter can specify the frame rate for video processing. A lower frame rate can improve generation efficiency, but may cause stuttering or choppiness. + +--- + +### 2.2 Animation Mode + +- We support three forms: not using pose retargeting, using basic pose retargeting, and using enhanced pose retargeting based on the `FLUX.1-Kontext-dev` image editing model. These are specified via the `retarget_flag` and `use_flux` parameters. + +- Specifying `retarget_flag` to use basic pose retargeting requires ensuring that both the reference character and the character in the first frame of the driving video are in a front-facing, stretched pose. + +- Other than that, we recommend using enhanced pose retargeting by specifying both `retarget_flag` and `use_flux`. **NOTE:** Due to the limited capabilities of `FLUX.1-Kontext-dev`, it is NOT guaranteed to produce the expected results (e.g., consistency is not maintained, the pose is incorrect, etc.). It is recommended to check the intermediate results as well as the finally generated pose video; both are stored in `save_path`. Of course, users can also use a better image editing model, or explore the prompts for Flux on their own. + +--- + +### 2.3 Replacement Mode + +- Specifying `replace_flag` to enable data preprocessing for this mode. The preprocessing will additionally process a mask for the character in the video, and its size and shape can be adjusted by specifying some parameters. +- `iterations` and `k` can make the mask larger, covering more area. +- `w_len` and `h_len` can adjust the mask's shape. Smaller values will make the outline coarser, while larger values will make it finer. + +- A smaller, finer-contoured mask can allow for more of the original background to be preserved, but may potentially limit the character's generation area (considering potential appearance differences, this can lead to some shape leakage). A larger, coarser mask can allow the character generation to be more flexible and consistent, but because it includes more of the background, it might affect the background's consistency. We recommend users to adjust the relevant parameters based on their specific input data. \ No newline at end of file diff --git a/wan/modules/animate/preprocess/__init__.py b/wan/modules/animate/preprocess/__init__.py new file mode 100644 index 00000000..7b76b624 --- /dev/null +++ b/wan/modules/animate/preprocess/__init__.py @@ -0,0 +1,3 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from .process_pipepline import ProcessPipeline +from .video_predictor import SAM2VideoPredictor diff --git a/wan/modules/animate/preprocess/human_visualization.py b/wan/modules/animate/preprocess/human_visualization.py new file mode 100644 index 00000000..49b0ee57 --- /dev/null +++ b/wan/modules/animate/preprocess/human_visualization.py @@ -0,0 +1,1485 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import os +import cv2 +import time +import math +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from typing import Dict, List +import random +from .pose2d_utils import AAPoseMeta + + +def draw_handpose(canvas, keypoints, hand_score_th=0.6): + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + eps = 0.01 + + H, W, C = canvas.shape + stickwidth = max(int(min(H, W) / 200), 1) + + edges = [ + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [0, 5], + [5, 6], + [6, 7], + [7, 8], + [0, 9], + [9, 10], + [10, 11], + [11, 12], + [0, 13], + [13, 14], + [14, 15], + [15, 16], + [0, 17], + [17, 18], + [18, 19], + [19, 20], + ] + + for ie, (e1, e2) in enumerate(edges): + k1 = keypoints[e1] + k2 = keypoints[e2] + if k1 is None or k2 is None: + continue + if k1[2] < hand_score_th or k2[2] < hand_score_th: + continue + + x1 = int(k1[0]) + y1 = int(k1[1]) + x2 = int(k2[0]) + y2 = int(k2[1]) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line( + canvas, + (x1, y1), + (x2, y2), + matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, + thickness=stickwidth, + ) + + for keypoint in keypoints: + + if keypoint is None: + continue + if keypoint[2] < hand_score_th: + continue + + x, y = keypoint[0], keypoint[1] + x = int(x) + y = int(y) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1) + return canvas + + +def draw_handpose_new(canvas, keypoints, stickwidth_type="v2", hand_score_th=0.6): + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + eps = 0.01 + + H, W, C = canvas.shape + if stickwidth_type == "v1": + stickwidth = max(int(min(H, W) / 200), 1) + elif stickwidth_type == "v2": + stickwidth = max(max(int(min(H, W) / 200) - 1, 1) // 2, 1) + + edges = [ + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [0, 5], + [5, 6], + [6, 7], + [7, 8], + [0, 9], + [9, 10], + [10, 11], + [11, 12], + [0, 13], + [13, 14], + [14, 15], + [15, 16], + [0, 17], + [17, 18], + [18, 19], + [19, 20], + ] + + for ie, (e1, e2) in enumerate(edges): + k1 = keypoints[e1] + k2 = keypoints[e2] + if k1 is None or k2 is None: + continue + if k1[2] < hand_score_th or k2[2] < hand_score_th: + continue + + x1 = int(k1[0]) + y1 = int(k1[1]) + x2 = int(k2[0]) + y2 = int(k2[1]) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line( + canvas, + (x1, y1), + (x2, y2), + matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, + thickness=stickwidth, + ) + + for keypoint in keypoints: + + if keypoint is None: + continue + if keypoint[2] < hand_score_th: + continue + + x, y = keypoint[0], keypoint[1] + x = int(x) + y = int(y) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1) + return canvas + + +def draw_ellipse_by_2kp(img, keypoint1, keypoint2, color, threshold=0.6): + H, W, C = img.shape + stickwidth = max(int(min(H, W) / 200), 1) + + if keypoint1[-1] < threshold or keypoint2[-1] < threshold: + return img + + Y = np.array([keypoint1[0], keypoint2[0]]) + X = np.array([keypoint1[1], keypoint2[1]]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly( + (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1 + ) + cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) + return img + + +def split_pose2d_kps_to_aa(kp2ds: np.ndarray) -> List[np.ndarray]: + """Convert the 133 keypoints from pose2d to body and hands keypoints. + + Args: + kp2ds (np.ndarray): [133, 2] + + Returns: + List[np.ndarray]: _description_ + """ + kp2ds_body = ( + kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]] + ) / 2 + kp2ds_lhand = kp2ds[91:112] + kp2ds_rhand = kp2ds[112:133] + return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy() + + +def draw_aapose_by_meta( + img, + meta: AAPoseMeta, + threshold=0.5, + stick_width_norm=200, + draw_hand=True, + draw_head=True, +): + kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) + kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) + kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) + pose_img = draw_aapose( + img, + kp2ds, + threshold, + kp2ds_lhand=kp2ds_lhand, + kp2ds_rhand=kp2ds_rhand, + stick_width_norm=stick_width_norm, + draw_hand=draw_hand, + draw_head=draw_head, + ) + return pose_img + + +def draw_aapose_by_meta_new( + img, + meta: AAPoseMeta, + threshold=0.5, + stickwidth_type="v2", + draw_hand=True, + draw_head=True, +): + kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) + kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) + kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) + pose_img = draw_aapose_new( + img, + kp2ds, + threshold, + kp2ds_lhand=kp2ds_lhand, + kp2ds_rhand=kp2ds_rhand, + stickwidth_type=stickwidth_type, + draw_hand=draw_hand, + draw_head=draw_head, + ) + return pose_img + + +def draw_hand_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200): + kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None] * 0], axis=1) + kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) + kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) + pose_img = draw_aapose( + img, + kp2ds, + threshold, + kp2ds_lhand=kp2ds_lhand, + kp2ds_rhand=kp2ds_rhand, + stick_width_norm=stick_width_norm, + draw_hand=True, + draw_head=False, + ) + return pose_img + + +def draw_aaface_by_meta( + img, + meta: AAPoseMeta, + threshold=0.5, + stick_width_norm=200, + draw_hand=False, + draw_head=True, +): + kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) + # kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) + # kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) + pose_img = draw_M( + img, + kp2ds, + threshold, + kp2ds_lhand=None, + kp2ds_rhand=None, + stick_width_norm=stick_width_norm, + draw_hand=draw_hand, + draw_head=draw_head, + ) + return pose_img + + +def draw_aanose_by_meta( + img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=100, draw_hand=False +): + kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) + # kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) + # kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) + pose_img = draw_nose( + img, + kp2ds, + threshold, + kp2ds_lhand=None, + kp2ds_rhand=None, + stick_width_norm=stick_width_norm, + draw_hand=draw_hand, + ) + return pose_img + + +def gen_face_motion_seq( + img, metas: List[AAPoseMeta], threshold=0.5, stick_width_norm=200 +): + + return + + +def draw_M( + img, + kp2ds, + threshold=0.6, + data_to_json=None, + idx=-1, + kp2ds_lhand=None, + kp2ds_rhand=None, + draw_hand=False, + stick_width_norm=200, + draw_head=True, +): + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + + new_kep_list = [ + "Nose", + "Neck", + "RShoulder", + "RElbow", + "RWrist", # No.4 + "LShoulder", + "LElbow", + "LWrist", # No.7 + "RHip", + "RKnee", + "RAnkle", # No.10 + "LHip", + "LKnee", + "LAnkle", # No.13 + "REye", + "LEye", + "REar", + "LEar", + "LToe", + "RToe", + ] + # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ + # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 + kp2ds = kp2ds.copy() + # import ipdb; ipdb.set_trace() + kp2ds[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 18, 19], 2] = 0 + if not draw_head: + kp2ds[[0, 14, 15, 16, 17], 2] = 0 + kp2ds_body = kp2ds + # kp2ds_body = kp2ds_body[:18] + + # kp2ds_lhand = kp2ds.copy()[91:112] + # kp2ds_rhand = kp2ds.copy()[112:133] + + limbSeq = [ + # [2, 3], + # [2, 6], # shoulders + # [3, 4], + # [4, 5], # left arm + # [6, 7], + # [7, 8], # right arm + # [2, 9], + # [9, 10], + # [10, 11], # right leg + # [2, 12], + # [12, 13], + # [13, 14], # left leg + # [2, 1], + [1, 15], + [15, 17], + [1, 16], + [16, 18], # face (nose, eyes, ears) + # [14, 19], + # [11, 20], # foot + ] + + colors = [ + # [255, 0, 0], + # [255, 85, 0], + # [255, 170, 0], + # [255, 255, 0], + # [170, 255, 0], + # [85, 255, 0], + # [0, 255, 0], + # [0, 255, 85], + # [0, 255, 170], + # [0, 255, 255], + # [0, 170, 255], + # [0, 85, 255], + # [0, 0, 255], + # [85, 0, 255], + [170, 0, 255], + [255, 0, 255], + [255, 0, 170], + [255, 0, 85], + # foot + # [200, 200, 0], + # [100, 100, 0], + ] + + H, W, C = img.shape + stickwidth = max(int(min(H, W) / stick_width_norm), 1) + + for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): + keypoint1 = kp2ds_body[k1_index - 1] + keypoint2 = kp2ds_body[k2_index - 1] + + if keypoint1[-1] < threshold or keypoint2[-1] < threshold: + continue + + Y = np.array([keypoint1[0], keypoint2[0]]) + X = np.array([keypoint1[1], keypoint2[1]]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly( + (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1 + ) + cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) + + for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)): + if keypoint[-1] < threshold: + continue + x, y = keypoint[0], keypoint[1] + # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) + + if draw_hand: + img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold) + img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold) + + kp2ds_body[:, 0] /= W + kp2ds_body[:, 1] /= H + + if data_to_json is not None: + if idx == -1: + data_to_json.append( + { + "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + ) + else: + data_to_json[idx] = { + "image_id": "frame_{:05d}.jpg".format(idx + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + return img + + +def draw_nose( + img, + kp2ds, + threshold=0.6, + data_to_json=None, + idx=-1, + kp2ds_lhand=None, + kp2ds_rhand=None, + draw_hand=False, + stick_width_norm=200, +): + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + + new_kep_list = [ + "Nose", + "Neck", + "RShoulder", + "RElbow", + "RWrist", # No.4 + "LShoulder", + "LElbow", + "LWrist", # No.7 + "RHip", + "RKnee", + "RAnkle", # No.10 + "LHip", + "LKnee", + "LAnkle", # No.13 + "REye", + "LEye", + "REar", + "LEar", + "LToe", + "RToe", + ] + # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ + # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 + kp2ds = kp2ds.copy() + kp2ds[1:, 2] = 0 + # kp2ds[0, 2] = 1 + kp2ds_body = kp2ds + # kp2ds_body = kp2ds_body[:18] + + # kp2ds_lhand = kp2ds.copy()[91:112] + # kp2ds_rhand = kp2ds.copy()[112:133] + + limbSeq = [ + # [2, 3], + # [2, 6], # shoulders + # [3, 4], + # [4, 5], # left arm + # [6, 7], + # [7, 8], # right arm + # [2, 9], + # [9, 10], + # [10, 11], # right leg + # [2, 12], + # [12, 13], + # [13, 14], # left leg + # [2, 1], + [1, 15], + [15, 17], + [1, 16], + [16, 18], # face (nose, eyes, ears) + # [14, 19], + # [11, 20], # foot + ] + + colors = [ + # [255, 0, 0], + # [255, 85, 0], + # [255, 170, 0], + # [255, 255, 0], + # [170, 255, 0], + # [85, 255, 0], + # [0, 255, 0], + # [0, 255, 85], + # [0, 255, 170], + # [0, 255, 255], + # [0, 170, 255], + # [0, 85, 255], + # [0, 0, 255], + # [85, 0, 255], + [170, 0, 255], + # [255, 0, 255], + # [255, 0, 170], + # [255, 0, 85], + # foot + # [200, 200, 0], + # [100, 100, 0], + ] + + H, W, C = img.shape + stickwidth = max(int(min(H, W) / stick_width_norm), 1) + + # for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): + # keypoint1 = kp2ds_body[k1_index - 1] + # keypoint2 = kp2ds_body[k2_index - 1] + + # if keypoint1[-1] < threshold or keypoint2[-1] < threshold: + # continue + + # Y = np.array([keypoint1[0], keypoint2[0]]) + # X = np.array([keypoint1[1], keypoint2[1]]) + # mX = np.mean(X) + # mY = np.mean(Y) + # length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + # angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + # polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + # cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) + + for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)): + if keypoint[-1] < threshold: + continue + x, y = keypoint[0], keypoint[1] + # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) + + if draw_hand: + img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold) + img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold) + + kp2ds_body[:, 0] /= W + kp2ds_body[:, 1] /= H + + if data_to_json is not None: + if idx == -1: + data_to_json.append( + { + "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + ) + else: + data_to_json[idx] = { + "image_id": "frame_{:05d}.jpg".format(idx + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + return img + + +def draw_aapose( + img, + kp2ds, + threshold=0.6, + data_to_json=None, + idx=-1, + kp2ds_lhand=None, + kp2ds_rhand=None, + draw_hand=False, + stick_width_norm=200, + draw_head=True, +): + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + + new_kep_list = [ + "Nose", + "Neck", + "RShoulder", + "RElbow", + "RWrist", # No.4 + "LShoulder", + "LElbow", + "LWrist", # No.7 + "RHip", + "RKnee", + "RAnkle", # No.10 + "LHip", + "LKnee", + "LAnkle", # No.13 + "REye", + "LEye", + "REar", + "LEar", + "LToe", + "RToe", + ] + # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ + # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 + kp2ds = kp2ds.copy() + if not draw_head: + kp2ds[[0, 14, 15, 16, 17], 2] = 0 + kp2ds_body = kp2ds + + # kp2ds_lhand = kp2ds.copy()[91:112] + # kp2ds_rhand = kp2ds.copy()[112:133] + + limbSeq = [ + [2, 3], + [2, 6], # shoulders + [3, 4], + [4, 5], # left arm + [6, 7], + [7, 8], # right arm + [2, 9], + [9, 10], + [10, 11], # right leg + [2, 12], + [12, 13], + [13, 14], # left leg + [2, 1], + [1, 15], + [15, 17], + [1, 16], + [16, 18], # face (nose, eyes, ears) + [14, 19], + [11, 20], # foot + ] + + colors = [ + [255, 0, 0], + [255, 85, 0], + [255, 170, 0], + [255, 255, 0], + [170, 255, 0], + [85, 255, 0], + [0, 255, 0], + [0, 255, 85], + [0, 255, 170], + [0, 255, 255], + [0, 170, 255], + [0, 85, 255], + [0, 0, 255], + [85, 0, 255], + [170, 0, 255], + [255, 0, 255], + [255, 0, 170], + [255, 0, 85], + # foot + [200, 200, 0], + [100, 100, 0], + ] + + H, W, C = img.shape + stickwidth = max(int(min(H, W) / stick_width_norm), 1) + + for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): + keypoint1 = kp2ds_body[k1_index - 1] + keypoint2 = kp2ds_body[k2_index - 1] + + if keypoint1[-1] < threshold or keypoint2[-1] < threshold: + continue + + Y = np.array([keypoint1[0], keypoint2[0]]) + X = np.array([keypoint1[1], keypoint2[1]]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly( + (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1 + ) + cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) + + for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)): + if keypoint[-1] < threshold: + continue + x, y = keypoint[0], keypoint[1] + # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) + + if draw_hand: + img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold) + img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold) + + kp2ds_body[:, 0] /= W + kp2ds_body[:, 1] /= H + + if data_to_json is not None: + if idx == -1: + data_to_json.append( + { + "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + ) + else: + data_to_json[idx] = { + "image_id": "frame_{:05d}.jpg".format(idx + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + return img + + +def draw_aapose_new( + img, + kp2ds, + threshold=0.6, + data_to_json=None, + idx=-1, + kp2ds_lhand=None, + kp2ds_rhand=None, + draw_hand=False, + stickwidth_type="v2", + draw_head=True, +): + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + + new_kep_list = [ + "Nose", + "Neck", + "RShoulder", + "RElbow", + "RWrist", # No.4 + "LShoulder", + "LElbow", + "LWrist", # No.7 + "RHip", + "RKnee", + "RAnkle", # No.10 + "LHip", + "LKnee", + "LAnkle", # No.13 + "REye", + "LEye", + "REar", + "LEar", + "LToe", + "RToe", + ] + # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ + # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 + kp2ds = kp2ds.copy() + if not draw_head: + kp2ds[[0, 14, 15, 16, 17], 2] = 0 + kp2ds_body = kp2ds + + # kp2ds_lhand = kp2ds.copy()[91:112] + # kp2ds_rhand = kp2ds.copy()[112:133] + + limbSeq = [ + [2, 3], + [2, 6], # shoulders + [3, 4], + [4, 5], # left arm + [6, 7], + [7, 8], # right arm + [2, 9], + [9, 10], + [10, 11], # right leg + [2, 12], + [12, 13], + [13, 14], # left leg + [2, 1], + [1, 15], + [15, 17], + [1, 16], + [16, 18], # face (nose, eyes, ears) + [14, 19], + [11, 20], # foot + ] + + colors = [ + [255, 0, 0], + [255, 85, 0], + [255, 170, 0], + [255, 255, 0], + [170, 255, 0], + [85, 255, 0], + [0, 255, 0], + [0, 255, 85], + [0, 255, 170], + [0, 255, 255], + [0, 170, 255], + [0, 85, 255], + [0, 0, 255], + [85, 0, 255], + [170, 0, 255], + [255, 0, 255], + [255, 0, 170], + [255, 0, 85], + # foot + [200, 200, 0], + [100, 100, 0], + ] + + H, W, C = img.shape + H, W, C = img.shape + + if stickwidth_type == "v1": + stickwidth = max(int(min(H, W) / 200), 1) + elif stickwidth_type == "v2": + stickwidth = max(int(min(H, W) / 200) - 1, 1) + else: + raise + + for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): + keypoint1 = kp2ds_body[k1_index - 1] + keypoint2 = kp2ds_body[k2_index - 1] + + if keypoint1[-1] < threshold or keypoint2[-1] < threshold: + continue + + Y = np.array([keypoint1[0], keypoint2[0]]) + X = np.array([keypoint1[1], keypoint2[1]]) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly( + (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1 + ) + cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) + + for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)): + if keypoint[-1] < threshold: + continue + x, y = keypoint[0], keypoint[1] + # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) + + if draw_hand: + img = draw_handpose_new( + img, kp2ds_lhand, stickwidth_type=stickwidth_type, hand_score_th=threshold + ) + img = draw_handpose_new( + img, kp2ds_rhand, stickwidth_type=stickwidth_type, hand_score_th=threshold + ) + + kp2ds_body[:, 0] /= W + kp2ds_body[:, 1] /= H + + if data_to_json is not None: + if idx == -1: + data_to_json.append( + { + "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + ) + else: + data_to_json[idx] = { + "image_id": "frame_{:05d}.jpg".format(idx + 1), + "height": H, + "width": W, + "category_id": 1, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + } + return img + + +def draw_bbox(img, bbox, color=(255, 0, 0)): + img = load_image(img) + bbox = [int(bbox_tmp) for bbox_tmp in bbox] + cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2) + return img + + +def draw_kp2ds( + img, kp2ds, threshold=0, color=(255, 0, 0), skeleton=None, reverse=False +): + img = load_image(img, reverse) + + if skeleton is not None: + if skeleton == "coco17": + skeleton_list = [ + [6, 8], + [8, 10], + [5, 7], + [7, 9], + [11, 13], + [13, 15], + [12, 14], + [14, 16], + [5, 6], + [6, 12], + [12, 11], + [11, 5], + ] + color_list = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + ] + elif skeleton == "cocowholebody": + skeleton_list = [ + [6, 8], + [8, 10], + [5, 7], + [7, 9], + [11, 13], + [13, 15], + [12, 14], + [14, 16], + [5, 6], + [6, 12], + [12, 11], + [11, 5], + [15, 17], + [15, 18], + [15, 19], + [16, 20], + [16, 21], + [16, 22], + [91, 92, 93, 94, 95], + [91, 96, 97, 98, 99], + [91, 100, 101, 102, 103], + [91, 104, 105, 106, 107], + [91, 108, 109, 110, 111], + [112, 113, 114, 115, 116], + [112, 117, 118, 119, 120], + [112, 121, 122, 123, 124], + [112, 125, 126, 127, 128], + [112, 129, 130, 131, 132], + ] + color_list = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + ] + else: + color_list = [color] + for _idx, _skeleton in enumerate(skeleton_list): + for i in range(len(_skeleton) - 1): + cv2.line( + img, + (int(kp2ds[_skeleton[i], 0]), int(kp2ds[_skeleton[i], 1])), + (int(kp2ds[_skeleton[i + 1], 0]), int(kp2ds[_skeleton[i + 1], 1])), + color_list[_idx % len(color_list)], + 3, + ) + + for _idx, kp2d in enumerate(kp2ds): + if kp2d[2] > threshold: + cv2.circle(img, (int(kp2d[0]), int(kp2d[1])), 3, color, -1) + # cv2.putText(img, + # str(_idx), + # (int(kp2d[0, i, 0])*1, + # int(kp2d[0, i, 1])*1), + # cv2.FONT_HERSHEY_SIMPLEX, + # 0.75, + # color, + # 2 + # ) + + return img + + +def draw_mask(img, mask, background=0, return_rgba=False): + img = load_image(img) + h, w, _ = img.shape + if type(background) == int: + background = np.ones((h, w, 3)).astype(np.uint8) * 255 * background + backgournd = cv2.resize(background, (w, h)) + img_rgba = np.concatenate([img, mask], -1) + return alphaMerge(img_rgba, background, 0, 0, return_rgba=True) + + +def draw_pcd(pcd_list, save_path=None): + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + + color_list = ["r", "g", "b", "y", "p"] + + for _idx, _pcd in enumerate(pcd_list): + ax.scatter(_pcd[:, 0], _pcd[:, 1], _pcd[:, 2], c=color_list[_idx], marker="o") + + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") + + if save_path is not None: + plt.savefig(save_path) + else: + plt.savefig("tmp.png") + + +def load_image(img, reverse=False): + if type(img) == str: + img = cv2.imread(img) + if reverse: + img = img.astype(np.float32) + img = img[:, :, ::-1] + img = img.astype(np.uint8) + return img + + +def draw_skeleten(meta): + kps = [] + for i, kp in enumerate(meta["keypoints_body"]): + if kp is None: + # if kp is None: + kps.append([0, 0, 0]) + else: + kps.append([*kp, 1]) + kps = np.array(kps) + + kps[:, 0] *= meta["width"] + kps[:, 1] *= meta["height"] + pose_img = np.zeros([meta["height"], meta["width"], 3], dtype=np.uint8) + + pose_img = draw_aapose( + pose_img, + kps, + draw_hand=True, + kp2ds_lhand=meta["keypoints_left_hand"], + kp2ds_rhand=meta["keypoints_right_hand"], + ) + return pose_img + + +def draw_skeleten_with_pncc(pncc: np.ndarray, meta: Dict) -> np.ndarray: + """ + Args: + pncc: [H,W,3] + meta: required keys: keypoints_body: [N, 3] keypoints_left_hand, keypoints_right_hand + Return: + np.ndarray [H, W, 3] + """ + # preprocess keypoints + kps = [] + for i, kp in enumerate(meta["keypoints_body"]): + if kp is None: + # if kp is None: + kps.append([0, 0, 0]) + elif i in [14, 15, 16, 17]: + kps.append([0, 0, 0]) + else: + kps.append([*kp]) + kps = np.stack(kps) + + kps[:, 0] *= pncc.shape[1] + kps[:, 1] *= pncc.shape[0] + + # draw neck + canvas = np.zeros_like(pncc) + if kps[0][2] > 0.6 and kps[1][2] > 0.6: + canvas = draw_ellipse_by_2kp(canvas, kps[0], kps[1], [0, 0, 255]) + + # draw pncc + mask = (pncc > 0).max(axis=2) + canvas[mask] = pncc[mask] + pncc = canvas + + # draw other skeleten + kps[0] = 0 + + meta["keypoints_left_hand"][:, 0] *= meta["width"] + meta["keypoints_left_hand"][:, 1] *= meta["height"] + + meta["keypoints_right_hand"][:, 0] *= meta["width"] + meta["keypoints_right_hand"][:, 1] *= meta["height"] + pose_img = draw_aapose( + pncc, + kps, + draw_hand=True, + kp2ds_lhand=meta["keypoints_left_hand"], + kp2ds_rhand=meta["keypoints_right_hand"], + ) + return pose_img + + +FACE_CUSTOM_STYLE = { + "eyeball": {"indexs": [68, 69], "color": [255, 255, 255], "connect": False}, + "left_eyebrow": {"indexs": [17, 18, 19, 20, 21], "color": [0, 255, 0]}, + "right_eyebrow": {"indexs": [22, 23, 24, 25, 26], "color": [0, 0, 255]}, + "left_eye": { + "indexs": [36, 37, 38, 39, 40, 41], + "color": [255, 255, 0], + "close": True, + }, + "right_eye": { + "indexs": [42, 43, 44, 45, 46, 47], + "color": [255, 0, 255], + "close": True, + }, + "mouth_outside": { + "indexs": list(range(48, 60)), + "color": [100, 255, 50], + "close": True, + }, + "mouth_inside": { + "indexs": [60, 61, 62, 63, 64, 65, 66, 67], + "color": [255, 100, 50], + "close": True, + }, +} + + +def draw_face_kp(img, kps, thickness=2, style=FACE_CUSTOM_STYLE): + """ + Args: + img: [H, W, 3] + kps: [70, 2] + """ + img = img.copy() + for key, item in style.items(): + pts = np.array(kps[item["indexs"]]).astype(np.int32) + connect = item.get("connect", True) + color = item["color"] + close = item.get("close", False) + if connect: + cv2.polylines(img, [pts], close, color, thickness=thickness) + else: + for kp in pts: + kp = np.array(kp).astype(np.int32) + cv2.circle(img, kp, thickness * 2, color=color, thickness=-1) + return img + + +def draw_traj(metas: List[AAPoseMeta], threshold=0.6): + + colors = [ + [255, 0, 0], + [255, 85, 0], + [255, 170, 0], + [255, 255, 0], + [170, 255, 0], + [85, 255, 0], + [0, 255, 0], + [0, 255, 85], + [0, 255, 170], + [0, 255, 255], + [0, 170, 255], + [0, 85, 255], + [0, 0, 255], + [85, 0, 255], + [170, 0, 255], + [255, 0, 255], + [255, 0, 170], + [255, 0, 85], + [100, 255, 50], + [255, 100, 50], + # foot + [200, 200, 0], + [100, 100, 0], + ] + limbSeq = [ + [1, 2], + [1, 5], # shoulders + [2, 3], + [3, 4], # left arm + [5, 6], + [6, 7], # right arm + [1, 8], + [8, 9], + [9, 10], # right leg + [1, 11], + [11, 12], + [12, 13], # left leg + # face (nose, eyes, ears) + [13, 18], + [10, 19], # foot + ] + + face_seq = [[1, 0], [0, 14], [14, 16], [0, 15], [15, 17]] + kp_body = np.array([meta.kps_body for meta in metas]) + kp_body_p = np.array([meta.kps_body_p for meta in metas]) + + face_seq = random.sample(face_seq, 2) + + kp_lh = np.array([meta.kps_lhand for meta in metas]) + kp_rh = np.array([meta.kps_rhand for meta in metas]) + + kp_lh_p = np.array([meta.kps_lhand_p for meta in metas]) + kp_rh_p = np.array([meta.kps_rhand_p for meta in metas]) + + # kp_lh = np.concatenate([kp_lh, kp_lh_p], axis=-1) + # kp_rh = np.concatenate([kp_rh, kp_rh_p], axis=-1) + + new_limbSeq = [] + key_point_list = [] + for _idx, ((k1_index, k2_index)) in enumerate(limbSeq): + + vis = ( + (kp_body_p[:, k1_index] > threshold) + * (kp_body_p[:, k2_index] > threshold) + * 1 + ) + if vis.sum() * 1.0 / vis.shape[0] > 0.4: + new_limbSeq.append([k1_index, k2_index]) + + for _idx, ((k1_index, k2_index)) in enumerate(limbSeq): + + keypoint1 = kp_body[:, k1_index - 1] + keypoint2 = kp_body[:, k2_index - 1] + interleave = random.randint(4, 7) + randind = random.randint(0, interleave - 1) + # randind = random.rand(range(interleave), sampling_num) + + Y = np.array([keypoint1[:, 0], keypoint2[:, 0]]) + X = np.array([keypoint1[:, 1], keypoint2[:, 1]]) + + vis = (keypoint1[:, -1] > threshold) * (keypoint2[:, -1] > threshold) * 1 + + # for randidx in randind: + t = randind / interleave + x = (1 - t) * Y[0, :] + t * Y[1, :] + y = (1 - t) * X[0, :] + t * X[1, :] + + # np.array([1]) + x = x.astype(int) + y = y.astype(int) + + new_array = np.array([x, y, vis]).T + + key_point_list.append(new_array) + + indx_lh = random.randint(0, kp_lh.shape[1] - 1) + lh = kp_lh[:, indx_lh, :] + lh_p = kp_lh_p[:, indx_lh : indx_lh + 1] + lh = np.concatenate([lh, lh_p], axis=-1) + + indx_rh = random.randint(0, kp_rh.shape[1] - 1) + rh = kp_rh[:, random.randint(0, kp_rh.shape[1] - 1), :] + rh_p = kp_rh_p[:, indx_rh : indx_rh + 1] + rh = np.concatenate([rh, rh_p], axis=-1) + + lh[-1, :] = (lh[-1, :] > threshold) * 1 + rh[-1, :] = (rh[-1, :] > threshold) * 1 + + # print(rh.shape, new_array.shape) + # exit() + key_point_list.append(lh.astype(int)) + key_point_list.append(rh.astype(int)) + + key_points_list = np.stack(key_point_list) + num_points = len(key_points_list) + sample_colors = random.sample(colors, num_points) + + stickwidth = max(int(min(metas[0].width, metas[0].height) / 150), 2) + + image_list_ori = [] + for i in range(key_points_list.shape[-2]): + _image_vis = np.zeros((metas[0].width, metas[0].height, 3)) + points = key_points_list[:, i, :] + for idx, point in enumerate(points): + x, y, vis = point + if vis == 1: + cv2.circle( + _image_vis, (x, y), stickwidth, sample_colors[idx], thickness=-1 + ) + + image_list_ori.append(_image_vis) + + return image_list_ori + + return [np.zeros([meta.width, meta.height, 3], dtype=np.uint8) for meta in metas] + + +if __name__ == "__main__": + meta = { + "image_id": "00472.jpg", + "height": 540, + "width": 414, + "category_id": 1, + "keypoints_body": [ + [0.5084776947463768, 0.11350188078703703], + [0.504467655495169, 0.20419560185185184], + [0.3982016153381642, 0.198046875], + [0.3841664779589372, 0.34869068287037036], + [0.3901815368357488, 0.4670536747685185], + [0.610733695652174, 0.2103443287037037], + [0.6167487545289855, 0.3517650462962963], + [0.6448190292874396, 0.4762767650462963], + [0.4523371452294686, 0.47320240162037036], + [0.4503321256038647, 0.6776475694444445], + [0.47639738073671495, 0.8544234664351852], + [0.5766483620169082, 0.47320240162037036], + [0.5666232638888888, 0.6761103877314815], + [0.534542949879227, 0.863646556712963], + [0.4864224788647343, 0.09505570023148148], + [0.5285278910024155, 0.09351851851851851], + [0.46236224335748793, 0.10581597222222222], + [0.5586031853864735, 0.10274160879629629], + [0.4994551064311594, 0.9405056423611111], + [0.4152442821557971, 0.9312825520833333], + ], + "keypoints_left_hand": [ + [267.78515625, 263.830078125, 1.2840936183929443], + [265.294921875, 269.640625, 1.2546794414520264], + [263.634765625, 277.111328125, 1.2863062620162964], + [262.8046875, 285.412109375, 1.267038345336914], + [261.14453125, 292.8828125, 1.280144453048706], + [273.595703125, 281.26171875, 1.2592815160751343], + [271.10546875, 291.22265625, 1.3256099224090576], + [265.294921875, 294.54296875, 1.2368024587631226], + [261.14453125, 294.54296875, 0.9771889448165894], + [274.42578125, 282.091796875, 1.250044584274292], + [269.4453125, 291.22265625, 1.2571144104003906], + [264.46484375, 292.8828125, 1.177802324295044], + [260.314453125, 292.052734375, 0.9283463358879089], + [273.595703125, 282.091796875, 1.1834490299224854], + [269.4453125, 290.392578125, 1.188171625137329], + [265.294921875, 290.392578125, 1.192609429359436], + [261.974609375, 289.5625, 0.9366656541824341], + [271.935546875, 281.26171875, 1.0946396589279175], + [268.615234375, 287.072265625, 0.9906131029129028], + [265.294921875, 287.90234375, 1.0219476222991943], + [262.8046875, 287.072265625, 0.9240120053291321], + ], + "keypoints_right_hand": [ + [161.53515625, 258.849609375, 1.2069408893585205], + [168.17578125, 263.0, 1.1846840381622314], + [173.986328125, 269.640625, 1.1435924768447876], + [173.986328125, 277.94140625, 1.1802611351013184], + [173.986328125, 286.2421875, 1.2599592208862305], + [165.685546875, 275.451171875, 1.0633569955825806], + [167.345703125, 286.2421875, 1.1693341732025146], + [169.8359375, 291.22265625, 1.2698509693145752], + [170.666015625, 294.54296875, 1.0619274377822876], + [160.705078125, 276.28125, 1.0995020866394043], + [163.1953125, 287.90234375, 1.2735884189605713], + [166.515625, 291.22265625, 1.339503526687622], + [169.005859375, 294.54296875, 1.0835273265838623], + [157.384765625, 277.111328125, 1.0866981744766235], + [161.53515625, 287.072265625, 1.2468621730804443], + [164.025390625, 289.5625, 1.2817761898040771], + [166.515625, 292.052734375, 1.099466323852539], + [155.724609375, 277.111328125, 1.1065717935562134], + [159.044921875, 285.412109375, 1.1924479007720947], + [160.705078125, 287.072265625, 1.1304771900177002], + [162.365234375, 287.90234375, 1.0040509700775146], + ], + } + demo_meta = AAPoseMeta(meta) + res = draw_traj([demo_meta] * 5) + cv2.imwrite("traj.png", res[0][..., ::-1]) diff --git a/wan/modules/animate/preprocess/pose2d.py b/wan/modules/animate/preprocess/pose2d.py new file mode 100644 index 00000000..7b572546 --- /dev/null +++ b/wan/modules/animate/preprocess/pose2d.py @@ -0,0 +1,505 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import os +import cv2 +from typing import Union, List + +import numpy as np +import torch +import onnxruntime + +from .pose2d_utils import ( + read_img, + box_convert_simple, + bbox_from_detector, + crop, + keypoints_from_heatmaps, + load_pose_metas_from_kp2ds_seq, +) + + +class SimpleOnnxInference(object): + def __init__(self, checkpoint, device="cuda", reverse_input=False, **kwargs): + if isinstance(device, str): + device = torch.device(device) + if device.type == "cuda": + device = "{}:{}".format(device.type, device.index) + providers = [ + ( + "CUDAExecutionProvider", + { + "device_id": ( + device[-1:] + if device[-1] in [str(_i) for _i in range(10)] + else "0" + ) + }, + ), + "CPUExecutionProvider", + ] + else: + providers = ["CPUExecutionProvider"] + self.device = device + if not os.path.exists(checkpoint): + raise RuntimeError("{} is not existed!".format(checkpoint)) + + if os.path.isdir(checkpoint): + checkpoint = os.path.join(checkpoint, "end2end.onnx") + + self.session = onnxruntime.InferenceSession(checkpoint, providers=providers) + self.input_name = self.session.get_inputs()[0].name + self.output_name = self.session.get_outputs()[0].name + self.input_resolution = ( + self.session.get_inputs()[0].shape[2:] + if not reverse_input + else self.session.get_inputs()[0].shape[2:][::-1] + ) + self.input_resolution = np.array(self.input_resolution) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def get_output_names(self): + output_names = [] + for node in self.session.get_outputs(): + output_names.append(node.name) + return output_names + + def set_device(self, device): + if isinstance(device, str): + device = torch.device(device) + if device.type == "cuda": + device = "{}:{}".format(device.type, device.index) + providers = [ + ( + "CUDAExecutionProvider", + { + "device_id": ( + device[-1:] + if device[-1] in [str(_i) for _i in range(10)] + else "0" + ) + }, + ), + "CPUExecutionProvider", + ] + else: + providers = ["CPUExecutionProvider"] + self.session.set_providers(providers) + self.device = device + + +class Yolo(SimpleOnnxInference): + def __init__( + self, + checkpoint, + device="cuda", + threshold_conf=0.05, + threshold_multi_persons=0.1, + input_resolution=(640, 640), + threshold_iou=0.5, + threshold_bbox_shape_ratio=0.4, + cat_id=[1], + select_type="max", + strict=True, + sorted_func=None, + **kwargs, + ): + super(Yolo, self).__init__(checkpoint, device=device, **kwargs) + + model_inputs = self.session.get_inputs() + input_shape = model_inputs[0].shape + + self.input_width = 640 + self.input_height = 640 + + self.threshold_multi_persons = threshold_multi_persons + self.threshold_conf = threshold_conf + self.threshold_iou = threshold_iou + self.threshold_bbox_shape_ratio = threshold_bbox_shape_ratio + self.input_resolution = input_resolution + self.cat_id = cat_id + self.select_type = select_type + self.strict = strict + self.sorted_func = sorted_func + + def preprocess(self, input_image): + """ + Preprocesses the input image before performing inference. + + Returns: + image_data: Preprocessed image data ready for inference. + """ + img = read_img(input_image) + # Get the height and width of the input image + img_height, img_width = img.shape[:2] + # Resize the image to match the input shape + img = cv2.resize(img, (self.input_resolution[1], self.input_resolution[0])) + # Normalize the image data by dividing it by 255.0 + image_data = np.array(img) / 255.0 + # Transpose the image to have the channel dimension as the first dimension + image_data = np.transpose(image_data, (2, 0, 1)) # Channel first + # Expand the dimensions of the image data to match the expected input shape + # image_data = np.expand_dims(image_data, axis=0).astype(np.float32) + image_data = image_data.astype(np.float32) + # Return the preprocessed image data + return image_data, np.array([img_height, img_width]) + + def postprocess(self, output, shape_raw, cat_id=[1]): + """ + Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs. + + Args: + input_image (numpy.ndarray): The input image. + output (numpy.ndarray): The output of the model. + + Returns: + numpy.ndarray: The input image with detections drawn on it. + """ + # Transpose and squeeze the output to match the expected shape + + outputs = np.squeeze(output) + if len(outputs.shape) == 1: + outputs = outputs[None] + if output.shape[-1] != 6 and output.shape[1] == 84: + outputs = np.transpose(outputs) + + # Get the number of rows in the outputs array + rows = outputs.shape[0] + + # Calculate the scaling factors for the bounding box coordinates + x_factor = shape_raw[1] / self.input_width + y_factor = shape_raw[0] / self.input_height + + # Lists to store the bounding boxes, scores, and class IDs of the detections + boxes = [] + scores = [] + class_ids = [] + + if outputs.shape[-1] == 6: + max_scores = outputs[:, 4] + classid = outputs[:, -1] + + threshold_conf_masks = max_scores >= self.threshold_conf + classid_masks = classid[threshold_conf_masks] != 3.14159 + + max_scores = max_scores[threshold_conf_masks][classid_masks] + classid = classid[threshold_conf_masks][classid_masks] + + boxes = outputs[:, :4][threshold_conf_masks][classid_masks] + boxes[:, [0, 2]] *= x_factor + boxes[:, [1, 3]] *= y_factor + boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + boxes = boxes.astype(np.int32) + + else: + classes_scores = outputs[:, 4:] + max_scores = np.amax(classes_scores, -1) + threshold_conf_masks = max_scores >= self.threshold_conf + + classid = np.argmax(classes_scores[threshold_conf_masks], -1) + + classid_masks = classid != 3.14159 + + classes_scores = classes_scores[threshold_conf_masks][classid_masks] + max_scores = max_scores[threshold_conf_masks][classid_masks] + classid = classid[classid_masks] + + xywh = outputs[:, :4][threshold_conf_masks][classid_masks] + + x = xywh[:, 0:1] + y = xywh[:, 1:2] + w = xywh[:, 2:3] + h = xywh[:, 3:4] + + left = (x - w / 2) * x_factor + top = (y - h / 2) * y_factor + width = w * x_factor + height = h * y_factor + boxes = np.concatenate([left, top, width, height], axis=-1).astype(np.int32) + + boxes = boxes.tolist() + scores = max_scores.tolist() + class_ids = classid.tolist() + + # Apply non-maximum suppression to filter out overlapping bounding boxes + indices = cv2.dnn.NMSBoxes( + boxes, scores, self.threshold_conf, self.threshold_iou + ) + # Iterate over the selected indices after non-maximum suppression + + results = [] + for i in indices: + # Get the box, score, and class ID corresponding to the index + box = box_convert_simple(boxes[i], "xywh2xyxy") + score = scores[i] + class_id = class_ids[i] + results.append(box + [score] + [class_id]) + # # Draw the detection on the input image + + # Return the modified input image + return np.array(results) + + def process_results(self, results, shape_raw, cat_id=[1], single_person=True): + if isinstance(results, tuple): + det_results = results[0] + else: + det_results = results + + person_results = [] + person_count = 0 + if len(results): + max_idx = -1 + max_bbox_size = shape_raw[0] * shape_raw[1] * -10 + max_bbox_shape = -1 + + bboxes = [] + idx_list = [] + for i in range(results.shape[0]): + bbox = results[i] + if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf): + idx_list.append(i) + bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1]))) + if bbox_shape > max_bbox_shape: + max_bbox_shape = bbox_shape + + results = results[idx_list] + + for i in range(results.shape[0]): + bbox = results[i] + bboxes.append(bbox) + if self.select_type == "max": + bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1])) + elif self.select_type == "center": + bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2)) * -1 + bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1]))) + if bbox_size > max_bbox_size: + if ( + (self.strict or max_idx != -1) + and bbox_shape + < max_bbox_shape * self.threshold_bbox_shape_ratio + ): + continue + max_bbox_size = bbox_size + max_bbox_shape = bbox_shape + max_idx = i + + if self.sorted_func is not None and len(bboxes) > 0: + max_idx = self.sorted_func(bboxes, shape_raw) + bbox = bboxes[max_idx] + if self.select_type == "max": + max_bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1])) + elif self.select_type == "center": + max_bbox_size = ( + abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2) + ) * -1 + + if max_idx != -1: + person_count = 1 + + if max_idx != -1: + person = {} + person["bbox"] = results[max_idx, :5] + person["track_id"] = int(0) + person_results.append(person) + + for i in range(results.shape[0]): + bbox = results[i] + if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf): + if self.select_type == "max": + bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1])) + elif self.select_type == "center": + bbox_size = ( + abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2) + ) * -1 + if ( + i != max_idx + and bbox_size > max_bbox_size * self.threshold_multi_persons + and bbox_size < max_bbox_size + ): + person_count += 1 + if not single_person: + person = {} + person["bbox"] = results[i, :5] + person["track_id"] = int(person_count - 1) + person_results.append(person) + return person_results + else: + return None + + def postprocess_threading( + self, outputs, shape_raw, person_results, i, single_person=True, **kwargs + ): + result = self.postprocess(outputs[i], shape_raw[i], cat_id=self.cat_id) + result = self.process_results( + result, shape_raw[i], cat_id=self.cat_id, single_person=single_person + ) + if result is not None and len(result) != 0: + person_results[i] = result + + def forward(self, img, shape_raw, **kwargs): + """ + Performs inference using an ONNX model and returns the output image with drawn detections. + + Returns: + output_img: The output image with drawn detections. + """ + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + shape_raw = shape_raw.cpu().numpy() + + outputs = self.session.run(None, {self.session.get_inputs()[0].name: img})[0] + person_results = [ + [ + { + "bbox": np.array( + [0.0, 0.0, 1.0 * shape_raw[i][1], 1.0 * shape_raw[i][0], -1] + ), + "track_id": -1, + } + ] + for i in range(len(outputs)) + ] + + for i in range(len(outputs)): + self.postprocess_threading(outputs, shape_raw, person_results, i, **kwargs) + return person_results + + +class ViTPose(SimpleOnnxInference): + def __init__(self, checkpoint, device="cuda", **kwargs): + super(ViTPose, self).__init__(checkpoint, device=device) + + def forward(self, img, center, scale, **kwargs): + heatmaps = self.session.run([], {self.session.get_inputs()[0].name: img})[0] + points, prob = keypoints_from_heatmaps( + heatmaps=heatmaps, + center=center, + scale=scale * 200, + unbiased=True, + use_udp=False, + ) + return np.concatenate([points, prob], axis=2) + + @staticmethod + def preprocess( + img, bbox=None, input_resolution=(256, 192), rescale=1.25, mask=None, **kwargs + ): + if ( + bbox is None + or bbox[-1] <= 0 + or (bbox[2] - bbox[0]) < 10 + or (bbox[3] - bbox[1]) < 10 + ): + bbox = np.array([0, 0, img.shape[1], img.shape[0]]) + + bbox_xywh = bbox + if mask is not None: + img = np.where(mask > 128, img, mask) + + if isinstance(input_resolution, int): + center, scale = bbox_from_detector( + bbox_xywh, (input_resolution, input_resolution), rescale=rescale + ) + img, new_shape, old_xy, new_xy = crop( + img, center, scale, (input_resolution, input_resolution) + ) + else: + center, scale = bbox_from_detector( + bbox_xywh, input_resolution, rescale=rescale + ) + img, new_shape, old_xy, new_xy = crop( + img, center, scale, (input_resolution[0], input_resolution[1]) + ) + + IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406]) + IMG_NORM_STD = np.array([0.229, 0.224, 0.225]) + img_norm = (img / 255.0 - IMG_NORM_MEAN) / IMG_NORM_STD + img_norm = img_norm.transpose(2, 0, 1).astype(np.float32) + return img_norm, np.array(center), np.array(scale) + + +class Pose2d: + def __init__(self, checkpoint, detector_checkpoint=None, device="cuda", **kwargs): + + if detector_checkpoint is not None: + self.detector = Yolo(detector_checkpoint, device) + else: + self.detector = None + + self.model = ViTPose(checkpoint, device) + self.device = device + + def load_images(self, inputs): + """ + Load images from various input types. + + Args: + inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path, + single image array, or list of image arrays + + Returns: + List[np.ndarray]: List of RGB image arrays + + Raises: + ValueError: If file format is unsupported or image cannot be read + """ + if isinstance(inputs, str): + if inputs.lower().endswith((".mp4", ".avi", ".mov", ".mkv")): + cap = cv2.VideoCapture(inputs) + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + cap.release() + images = frames + elif inputs.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")): + img = cv2.cvtColor(cv2.imread(inputs), cv2.COLOR_BGR2RGB) + if img is None: + raise ValueError(f"Cannot read image: {inputs}") + images = [img] + else: + raise ValueError(f"Unsupported file format: {inputs}") + + elif isinstance(inputs, np.ndarray): + images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs] + elif isinstance(inputs, list): + images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs] + return images + + def __call__( + self, + inputs: Union[str, np.ndarray, List[np.ndarray]], + return_image: bool = False, + **kwargs, + ): + """ + Process input and estimate 2D keypoints. + + Args: + inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path, + single image array, or list of image arrays + **kwargs: Additional arguments for processing + + Returns: + np.ndarray: Array of detected 2D keypoints for all input images + """ + images = self.load_images(inputs) + H, W = images[0].shape[:2] + if self.detector is not None: + bboxes = [] + for _image in images: + img, shape = self.detector.preprocess(_image) + bboxes.append(self.detector(img[None], shape[None])[0][0]["bbox"]) + else: + bboxes = [None] * len(images) + + kp2ds = [] + for _image, _bbox in zip(images, bboxes): + img, center, scale = self.model.preprocess(_image, _bbox) + kp2ds.append(self.model(img[None], center[None], scale[None])) + kp2ds = np.concatenate(kp2ds, 0) + metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H) + return metas diff --git a/wan/modules/animate/preprocess/pose2d_utils.py b/wan/modules/animate/preprocess/pose2d_utils.py new file mode 100644 index 00000000..eb42ef17 --- /dev/null +++ b/wan/modules/animate/preprocess/pose2d_utils.py @@ -0,0 +1,1202 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import warnings +import cv2 +import numpy as np +from typing import List +from PIL import Image + + +def box_convert_simple(box, convert_type="xyxy2xywh"): + if convert_type == "xyxy2xywh": + return [box[0], box[1], box[2] - box[0], box[3] - box[1]] + elif convert_type == "xywh2xyxy": + return [box[0], box[1], box[2] + box[0], box[3] + box[1]] + elif convert_type == "xyxy2ctwh": + return [ + (box[0] + box[2]) / 2, + (box[1] + box[3]) / 2, + box[2] - box[0], + box[3] - box[1], + ] + elif convert_type == "ctwh2xyxy": + return [ + box[0] - box[2] // 2, + box[1] - box[3] // 2, + box[0] + (box[2] - box[2] // 2), + box[1] + (box[3] - box[3] // 2), + ] + + +def read_img(image, convert="RGB", check_exist=False): + if isinstance(image, str): + if check_exist and not osp.exists(image): + return None + try: + img = Image.open(image) + if convert: + img = img.convert(convert) + except: + raise IOError("File error: ", image) + return np.asarray(img) + else: + if isinstance(image, np.ndarray): + if convert: + return image[..., ::-1] + else: + if convert: + img = img.convert(convert) + return np.asarray(img) + + +class AAPoseMeta: + def __init__(self, meta=None, kp2ds=None): + self.image_id = "" + self.height = 0 + self.width = 0 + + self.kps_body: np.ndarray = None + self.kps_lhand: np.ndarray = None + self.kps_rhand: np.ndarray = None + self.kps_face: np.ndarray = None + self.kps_body_p: np.ndarray = None + self.kps_lhand_p: np.ndarray = None + self.kps_rhand_p: np.ndarray = None + self.kps_face_p: np.ndarray = None + + if meta is not None: + self.load_from_meta(meta) + elif kp2ds is not None: + self.load_from_kp2ds(kp2ds) + + def is_valid(self, kp, p, threshold): + x, y = kp + if x < 0 or y < 0 or x > self.width or y > self.height or p < threshold: + return False + else: + return True + + def get_bbox(self, kp, kp_p, threshold=0.5): + kps = kp[kp_p > threshold] + if kps.size == 0: + return 0, 0, 0, 0 + x0, y0 = kps.min(axis=0) + x1, y1 = kps.max(axis=0) + return x0, y0, x1, y1 + + def crop(self, x0, y0, x1, y1): + all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] + for kps in all_kps: + if kps is not None: + kps[:, 0] -= x0 + kps[:, 1] -= y0 + self.width = x1 - x0 + self.height = y1 - y0 + return self + + def resize(self, width, height): + scale_x = width / self.width + scale_y = height / self.height + all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] + for kps in all_kps: + if kps is not None: + kps[:, 0] *= scale_x + kps[:, 1] *= scale_y + self.width = width + self.height = height + return self + + def get_kps_body_with_p(self, normalize=False): + kps_body = self.kps_body.copy() + if normalize: + kps_body = kps_body / np.array([self.width, self.height]) + + return np.concatenate([kps_body, self.kps_body_p[:, None]]) + + @staticmethod + def from_kps_face(kps_face: np.ndarray, height: int, width: int): + + pose_meta = AAPoseMeta() + pose_meta.kps_face = kps_face[:, :2] + if kps_face.shape[1] == 3: + pose_meta.kps_face_p = kps_face[:, 2] + else: + pose_meta.kps_face_p = kps_face[:, 0] * 0 + 1 + pose_meta.height = height + pose_meta.width = width + return pose_meta + + @staticmethod + def from_kps_body(kps_body: np.ndarray, height: int, width: int): + + pose_meta = AAPoseMeta() + pose_meta.kps_body = kps_body[:, :2] + pose_meta.kps_body_p = kps_body[:, 2] + pose_meta.height = height + pose_meta.width = width + return pose_meta + + @staticmethod + def from_humanapi_meta(meta): + pose_meta = AAPoseMeta() + width, height = meta["width"], meta["height"] + pose_meta.width = width + pose_meta.height = height + pose_meta.kps_body = meta["keypoints_body"][:, :2] * (width, height) + pose_meta.kps_body_p = meta["keypoints_body"][:, 2] + pose_meta.kps_lhand = meta["keypoints_left_hand"][:, :2] * (width, height) + pose_meta.kps_lhand_p = meta["keypoints_left_hand"][:, 2] + pose_meta.kps_rhand = meta["keypoints_right_hand"][:, :2] * (width, height) + pose_meta.kps_rhand_p = meta["keypoints_right_hand"][:, 2] + if "keypoints_face" in meta: + pose_meta.kps_face = meta["keypoints_face"][:, :2] * (width, height) + pose_meta.kps_face_p = meta["keypoints_face"][:, 2] + return pose_meta + + def load_from_meta(self, meta, norm_body=True, norm_hand=False): + + self.image_id = meta.get("image_id", "00000.png") + self.height = meta["height"] + self.width = meta["width"] + kps_body_p = [] + kps_body = [] + for kp in meta["keypoints_body"]: + if kp is None: + kps_body.append([0, 0]) + kps_body_p.append(0) + else: + kps_body.append(kp) + kps_body_p.append(1) + + self.kps_body = np.array(kps_body) + self.kps_body[:, 0] *= self.width + self.kps_body[:, 1] *= self.height + self.kps_body_p = np.array(kps_body_p) + + self.kps_lhand = np.array(meta["keypoints_left_hand"])[:, :2] + self.kps_lhand_p = np.array(meta["keypoints_left_hand"])[:, 2] + self.kps_rhand = np.array(meta["keypoints_right_hand"])[:, :2] + self.kps_rhand_p = np.array(meta["keypoints_right_hand"])[:, 2] + + @staticmethod + def load_from_kp2ds(kp2ds: List[np.ndarray], width: int, height: int): + """input 133x3 numpy keypoints and output AAPoseMeta + + Args: + kp2ds (List[np.ndarray]): _description_ + width (int): _description_ + height (int): _description_ + + Returns: + _type_: _description_ + """ + pose_meta = AAPoseMeta() + pose_meta.width = width + pose_meta.height = height + kps_body = ( + kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + + kp2ds[ + [0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21] + ] + ) / 2 + kps_lhand = kp2ds[91:112] + kps_rhand = kp2ds[112:133] + kps_face = np.concatenate([kp2ds[23 : 23 + 68], kp2ds[1:3]], axis=0) + pose_meta.kps_body = kps_body[:, :2] + pose_meta.kps_body_p = kps_body[:, 2] + pose_meta.kps_lhand = kps_lhand[:, :2] + pose_meta.kps_lhand_p = kps_lhand[:, 2] + pose_meta.kps_rhand = kps_rhand[:, :2] + pose_meta.kps_rhand_p = kps_rhand[:, 2] + pose_meta.kps_face = kps_face[:, :2] + pose_meta.kps_face_p = kps_face[:, 2] + return pose_meta + + @staticmethod + def from_dwpose(dwpose_det_res, height, width): + pose_meta = AAPoseMeta() + pose_meta.kps_body = dwpose_det_res["bodies"]["candidate"] + pose_meta.kps_body_p = dwpose_det_res["bodies"]["score"] + pose_meta.kps_body[:, 0] *= width + pose_meta.kps_body[:, 1] *= height + + pose_meta.kps_lhand, pose_meta.kps_rhand = dwpose_det_res["hands"] + pose_meta.kps_lhand[:, 0] *= width + pose_meta.kps_lhand[:, 1] *= height + pose_meta.kps_rhand[:, 0] *= width + pose_meta.kps_rhand[:, 1] *= height + pose_meta.kps_lhand_p, pose_meta.kps_rhand_p = dwpose_det_res["hands_score"] + + pose_meta.kps_face = dwpose_det_res["faces"][0] + pose_meta.kps_face[:, 0] *= width + pose_meta.kps_face[:, 1] *= height + pose_meta.kps_face_p = dwpose_det_res["faces_score"][0] + return pose_meta + + def save_json(self): + pass + + def draw_aapose( + self, img, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True + ): + from .human_visualization import draw_aapose_by_meta + + return draw_aapose_by_meta( + img, self, threshold, stick_width_norm, draw_hand, draw_head + ) + + def translate(self, x0, y0): + all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] + for kps in all_kps: + if kps is not None: + kps[:, 0] -= x0 + kps[:, 1] -= y0 + + def scale(self, sx, sy): + all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] + for kps in all_kps: + if kps is not None: + kps[:, 0] *= sx + kps[:, 1] *= sy + + def padding_resize2(self, height=512, width=512): + """kps will be changed inplace""" + + all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] + + ori_height, ori_width = self.height, self.width + + if (ori_height / ori_width) > (height / width): + new_width = int(height / ori_height * ori_width) + padding = int((width - new_width) / 2) + padding_width = padding + padding_height = 0 + scale = height / ori_height + + for kps in all_kps: + if kps is not None: + kps[:, 0] = kps[:, 0] * scale + padding + kps[:, 1] = kps[:, 1] * scale + + else: + new_height = int(width / ori_width * ori_height) + padding = int((height - new_height) / 2) + padding_width = 0 + padding_height = padding + scale = width / ori_width + for kps in all_kps: + if kps is not None: + kps[:, 1] = kps[:, 1] * scale + padding + kps[:, 0] = kps[:, 0] * scale + + self.width = width + self.height = height + return self + + +def transform_preds(coords, center, scale, output_size, use_udp=False): + """Get final keypoint predictions from heatmaps and apply scaling and + translation to map them back to the image. + + Note: + num_keypoints: K + + Args: + coords (np.ndarray[K, ndims]): + + * If ndims=2, corrds are predicted keypoint location. + * If ndims=4, corrds are composed of (x, y, scores, tags) + * If ndims=5, corrds are composed of (x, y, scores, tags, + flipped_tags) + + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + use_udp (bool): Use unbiased data processing + + Returns: + np.ndarray: Predicted coordinates in the images. + """ + assert coords.shape[1] in (2, 4, 5) + assert len(center) == 2 + assert len(scale) == 2 + assert len(output_size) == 2 + + # Recover the scale which is normalized by a factor of 200. + # scale = scale * 200.0 + + if use_udp: + scale_x = scale[0] / (output_size[0] - 1.0) + scale_y = scale[1] / (output_size[1] - 1.0) + else: + scale_x = scale[0] / output_size[0] + scale_y = scale[1] / output_size[1] + + target_coords = np.ones_like(coords) + target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5 + target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5 + + return target_coords + + +def _calc_distances(preds, targets, mask, normalize): + """Calculate the normalized distances between preds and target. + + Note: + batch_size: N + num_keypoints: K + dimension of keypoints: D (normally, D=2 or D=3) + + Args: + preds (np.ndarray[N, K, D]): Predicted keypoint location. + targets (np.ndarray[N, K, D]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + normalize (np.ndarray[N, D]): Typical value is heatmap_size + + Returns: + np.ndarray[K, N]: The normalized distances. \ + If target keypoints are missing, the distance is -1. + """ + N, K, _ = preds.shape + # set mask=0 when normalize==0 + _mask = mask.copy() + _mask[np.where((normalize == 0).sum(1))[0], :] = False + distances = np.full((N, K), -1, dtype=np.float32) + # handle invalid values + normalize[np.where(normalize <= 0)] = 1e6 + distances[_mask] = np.linalg.norm( + ((preds - targets) / normalize[:, None, :])[_mask], axis=-1 + ) + return distances.T + + +def _distance_acc(distances, thr=0.5): + """Return the percentage below the distance threshold, while ignoring + distances values with -1. + + Note: + batch_size: N + Args: + distances (np.ndarray[N, ]): The normalized distances. + thr (float): Threshold of the distances. + + Returns: + float: Percentage of distances below the threshold. \ + If all target keypoints are missing, return -1. + """ + distance_valid = distances != -1 + num_distance_valid = distance_valid.sum() + if num_distance_valid > 0: + return (distances[distance_valid] < thr).sum() / num_distance_valid + return -1 + + +def _get_max_preds(heatmaps): + """Get keypoint predictions from score maps. + + Note: + batch_size: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + + Returns: + tuple: A tuple containing aggregated results. + + - preds (np.ndarray[N, K, 2]): Predicted keypoint location. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + assert isinstance(heatmaps, np.ndarray), "heatmaps should be numpy.ndarray" + assert heatmaps.ndim == 4, "batch_images should be 4-ndim" + + N, K, _, W = heatmaps.shape + heatmaps_reshaped = heatmaps.reshape((N, K, -1)) + idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1)) + maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1)) + + preds = np.tile(idx, (1, 1, 2)).astype(np.float32) + preds[:, :, 0] = preds[:, :, 0] % W + preds[:, :, 1] = preds[:, :, 1] // W + + preds = np.where(np.tile(maxvals, (1, 1, 2)) > 0.0, preds, -1) + return preds, maxvals + + +def _get_max_preds_3d(heatmaps): + """Get keypoint predictions from 3D score maps. + + Note: + batch size: N + num keypoints: K + heatmap depth size: D + heatmap height: H + heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps. + + Returns: + tuple: A tuple containing aggregated results. + + - preds (np.ndarray[N, K, 3]): Predicted keypoint location. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + assert isinstance(heatmaps, np.ndarray), "heatmaps should be numpy.ndarray" + assert heatmaps.ndim == 5, "heatmaps should be 5-ndim" + + N, K, D, H, W = heatmaps.shape + heatmaps_reshaped = heatmaps.reshape((N, K, -1)) + idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1)) + maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1)) + + preds = np.zeros((N, K, 3), dtype=np.float32) + _idx = idx[..., 0] + preds[..., 2] = _idx // (H * W) + preds[..., 1] = (_idx // W) % H + preds[..., 0] = _idx % W + + preds = np.where(maxvals > 0.0, preds, -1) + return preds, maxvals + + +def pose_pck_accuracy(output, target, mask, thr=0.05, normalize=None): + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints from heatmaps. + + Note: + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + - batch_size: N + - num_keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + output (np.ndarray[N, K, H, W]): Model output heatmaps. + target (np.ndarray[N, K, H, W]): Groundtruth heatmaps. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + thr (float): Threshold of PCK calculation. Default 0.05. + normalize (np.ndarray[N, 2]): Normalization factor for H&W. + + Returns: + tuple: A tuple containing keypoint accuracy. + + - np.ndarray[K]: Accuracy of each keypoint. + - float: Averaged accuracy across all keypoints. + - int: Number of valid keypoints. + """ + N, K, H, W = output.shape + if K == 0: + return None, 0, 0 + if normalize is None: + normalize = np.tile(np.array([[H, W]]), (N, 1)) + + pred, _ = _get_max_preds(output) + gt, _ = _get_max_preds(target) + return keypoint_pck_accuracy(pred, gt, mask, thr, normalize) + + +def keypoint_pck_accuracy(pred, gt, mask, thr, normalize): + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints for coordinates. + + Note: + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + thr (float): Threshold of PCK calculation. + normalize (np.ndarray[N, 2]): Normalization factor for H&W. + + Returns: + tuple: A tuple containing keypoint accuracy. + + - acc (np.ndarray[K]): Accuracy of each keypoint. + - avg_acc (float): Averaged accuracy across all keypoints. + - cnt (int): Number of valid keypoints. + """ + distances = _calc_distances(pred, gt, mask, normalize) + + acc = np.array([_distance_acc(d, thr) for d in distances]) + valid_acc = acc[acc >= 0] + cnt = len(valid_acc) + avg_acc = valid_acc.mean() if cnt > 0 else 0 + return acc, avg_acc, cnt + + +def keypoint_auc(pred, gt, mask, normalize, num_step=20): + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints for coordinates. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + normalize (float): Normalization factor. + + Returns: + float: Area under curve. + """ + nor = np.tile(np.array([[normalize, normalize]]), (pred.shape[0], 1)) + x = [1.0 * i / num_step for i in range(num_step)] + y = [] + for thr in x: + _, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor) + y.append(avg_acc) + + auc = 0 + for i in range(num_step): + auc += 1.0 / num_step * y[i] + return auc + + +def keypoint_nme(pred, gt, mask, normalize_factor): + """Calculate the normalized mean error (NME). + + Note: + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + normalize_factor (np.ndarray[N, 2]): Normalization factor. + + Returns: + float: normalized mean error + """ + distances = _calc_distances(pred, gt, mask, normalize_factor) + distance_valid = distances[distances != -1] + return distance_valid.sum() / max(1, len(distance_valid)) + + +def keypoint_epe(pred, gt, mask): + """Calculate the end-point error. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + + Returns: + float: Average end-point error. + """ + + distances = _calc_distances( + pred, gt, mask, np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32) + ) + distance_valid = distances[distances != -1] + return distance_valid.sum() / max(1, len(distance_valid)) + + +def _taylor(heatmap, coord): + """Distribution aware coordinate decoding method. + + Note: + - heatmap height: H + - heatmap width: W + + Args: + heatmap (np.ndarray[H, W]): Heatmap of a particular joint type. + coord (np.ndarray[2,]): Coordinates of the predicted keypoints. + + Returns: + np.ndarray[2,]: Updated coordinates. + """ + H, W = heatmap.shape[:2] + px, py = int(coord[0]), int(coord[1]) + if 1 < px < W - 2 and 1 < py < H - 2: + dx = 0.5 * (heatmap[py][px + 1] - heatmap[py][px - 1]) + dy = 0.5 * (heatmap[py + 1][px] - heatmap[py - 1][px]) + dxx = 0.25 * (heatmap[py][px + 2] - 2 * heatmap[py][px] + heatmap[py][px - 2]) + dxy = 0.25 * ( + heatmap[py + 1][px + 1] + - heatmap[py - 1][px + 1] + - heatmap[py + 1][px - 1] + + heatmap[py - 1][px - 1] + ) + dyy = 0.25 * ( + heatmap[py + 2 * 1][px] - 2 * heatmap[py][px] + heatmap[py - 2 * 1][px] + ) + derivative = np.array([[dx], [dy]]) + hessian = np.array([[dxx, dxy], [dxy, dyy]]) + if dxx * dyy - dxy**2 != 0: + hessianinv = np.linalg.inv(hessian) + offset = -hessianinv @ derivative + offset = np.squeeze(np.array(offset.T), axis=0) + coord += offset + return coord + + +def post_dark_udp(coords, batch_heatmaps, kernel=3): + """DARK post-pocessing. Implemented by udp. Paper ref: Huang et al. The + Devil is in the Details: Delving into Unbiased Data Processing for Human + Pose Estimation (CVPR 2020). Zhang et al. Distribution-Aware Coordinate + Representation for Human Pose Estimation (CVPR 2020). + + Note: + - batch size: B + - num keypoints: K + - num persons: N + - height of heatmaps: H + - width of heatmaps: W + + B=1 for bottom_up paradigm where all persons share the same heatmap. + B=N for top_down paradigm where each person has its own heatmaps. + + Args: + coords (np.ndarray[N, K, 2]): Initial coordinates of human pose. + batch_heatmaps (np.ndarray[B, K, H, W]): batch_heatmaps + kernel (int): Gaussian kernel size (K) for modulation. + + Returns: + np.ndarray([N, K, 2]): Refined coordinates. + """ + if not isinstance(batch_heatmaps, np.ndarray): + batch_heatmaps = batch_heatmaps.cpu().numpy() + B, K, H, W = batch_heatmaps.shape + N = coords.shape[0] + assert B == 1 or B == N + for heatmaps in batch_heatmaps: + for heatmap in heatmaps: + cv2.GaussianBlur(heatmap, (kernel, kernel), 0, heatmap) + np.clip(batch_heatmaps, 0.001, 50, batch_heatmaps) + np.log(batch_heatmaps, batch_heatmaps) + + batch_heatmaps_pad = np.pad( + batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="edge" + ).flatten() + + index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (W + 2) + index += (W + 2) * (H + 2) * np.arange(0, B * K).reshape(-1, K) + index = index.astype(int).reshape(-1, 1) + i_ = batch_heatmaps_pad[index] + ix1 = batch_heatmaps_pad[index + 1] + iy1 = batch_heatmaps_pad[index + W + 2] + ix1y1 = batch_heatmaps_pad[index + W + 3] + ix1_y1_ = batch_heatmaps_pad[index - W - 3] + ix1_ = batch_heatmaps_pad[index - 1] + iy1_ = batch_heatmaps_pad[index - 2 - W] + + dx = 0.5 * (ix1 - ix1_) + dy = 0.5 * (iy1 - iy1_) + derivative = np.concatenate([dx, dy], axis=1) + derivative = derivative.reshape(N, K, 2, 1) + dxx = ix1 - 2 * i_ + ix1_ + dyy = iy1 - 2 * i_ + iy1_ + dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) + hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1) + hessian = hessian.reshape(N, K, 2, 2) + hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) + coords -= np.einsum("ijmn,ijnk->ijmk", hessian, derivative).squeeze() + return coords + + +def _gaussian_blur(heatmaps, kernel=11): + """Modulate heatmap distribution with Gaussian. + sigma = 0.3*((kernel_size-1)*0.5-1)+0.8 + sigma~=3 if k=17 + sigma=2 if k=11; + sigma~=1.5 if k=7; + sigma~=1 if k=3; + + Note: + - batch_size: N + - num_keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + kernel (int): Gaussian kernel size (K) for modulation, which should + match the heatmap gaussian sigma when training. + K=17 for sigma=3 and k=11 for sigma=2. + + Returns: + np.ndarray ([N, K, H, W]): Modulated heatmap distribution. + """ + assert kernel % 2 == 1 + + border = (kernel - 1) // 2 + batch_size = heatmaps.shape[0] + num_joints = heatmaps.shape[1] + height = heatmaps.shape[2] + width = heatmaps.shape[3] + for i in range(batch_size): + for j in range(num_joints): + origin_max = np.max(heatmaps[i, j]) + dr = np.zeros((height + 2 * border, width + 2 * border), dtype=np.float32) + dr[border:-border, border:-border] = heatmaps[i, j].copy() + dr = cv2.GaussianBlur(dr, (kernel, kernel), 0) + heatmaps[i, j] = dr[border:-border, border:-border].copy() + heatmaps[i, j] *= origin_max / np.max(heatmaps[i, j]) + return heatmaps + + +def keypoints_from_regression(regression_preds, center, scale, img_size): + """Get final keypoint predictions from regression vectors and transform + them back to the image. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + regression_preds (np.ndarray[N, K, 2]): model prediction. + center (np.ndarray[N, 2]): Center of the bounding box (x, y). + scale (np.ndarray[N, 2]): Scale of the bounding box + wrt height/width. + img_size (list(img_width, img_height)): model input image size. + + Returns: + tuple: + + - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + N, K, _ = regression_preds.shape + preds, maxvals = regression_preds, np.ones((N, K, 1), dtype=np.float32) + + preds = preds * img_size + + # Transform back to the image + for i in range(N): + preds[i] = transform_preds(preds[i], center[i], scale[i], img_size) + + return preds, maxvals + + +def keypoints_from_heatmaps( + heatmaps, + center, + scale, + unbiased=False, + post_process="default", + kernel=11, + valid_radius_factor=0.0546875, + use_udp=False, + target_type="GaussianHeatmap", +): + """Get final keypoint predictions from heatmaps and transform them back to + the image. + + Note: + - batch size: N + - num keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + center (np.ndarray[N, 2]): Center of the bounding box (x, y). + scale (np.ndarray[N, 2]): Scale of the bounding box + wrt height/width. + post_process (str/None): Choice of methods to post-process + heatmaps. Currently supported: None, 'default', 'unbiased', + 'megvii'. + unbiased (bool): Option to use unbiased decoding. Mutually + exclusive with megvii. + Note: this arg is deprecated and unbiased=True can be replaced + by post_process='unbiased' + Paper ref: Zhang et al. Distribution-Aware Coordinate + Representation for Human Pose Estimation (CVPR 2020). + kernel (int): Gaussian kernel size (K) for modulation, which should + match the heatmap gaussian sigma when training. + K=17 for sigma=3 and k=11 for sigma=2. + valid_radius_factor (float): The radius factor of the positive area + in classification heatmap for UDP. + use_udp (bool): Use unbiased data processing. + target_type (str): 'GaussianHeatmap' or 'CombinedTarget'. + GaussianHeatmap: Classification target with gaussian distribution. + CombinedTarget: The combination of classification target + (response map) and regression target (offset map). + Paper ref: Huang et al. The Devil is in the Details: Delving into + Unbiased Data Processing for Human Pose Estimation (CVPR 2020). + + Returns: + tuple: A tuple containing keypoint predictions and scores. + + - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + # Avoid being affected + heatmaps = heatmaps.copy() + + # detect conflicts + if unbiased: + assert post_process not in [False, None, "megvii"] + if post_process in ["megvii", "unbiased"]: + assert kernel > 0 + if use_udp: + assert not post_process == "megvii" + + # normalize configs + if post_process is False: + warnings.warn( + "post_process=False is deprecated, " "please use post_process=None instead", + DeprecationWarning, + ) + post_process = None + elif post_process is True: + if unbiased is True: + warnings.warn( + "post_process=True, unbiased=True is deprecated," + " please use post_process='unbiased' instead", + DeprecationWarning, + ) + post_process = "unbiased" + else: + warnings.warn( + "post_process=True, unbiased=False is deprecated, " + "please use post_process='default' instead", + DeprecationWarning, + ) + post_process = "default" + elif post_process == "default": + if unbiased is True: + warnings.warn( + "unbiased=True is deprecated, please use " + "post_process='unbiased' instead", + DeprecationWarning, + ) + post_process = "unbiased" + + # start processing + if post_process == "megvii": + heatmaps = _gaussian_blur(heatmaps, kernel=kernel) + + N, K, H, W = heatmaps.shape + if use_udp: + if target_type.lower() == "GaussianHeatMap".lower(): + preds, maxvals = _get_max_preds(heatmaps) + preds = post_dark_udp(preds, heatmaps, kernel=kernel) + elif target_type.lower() == "CombinedTarget".lower(): + for person_heatmaps in heatmaps: + for i, heatmap in enumerate(person_heatmaps): + kt = 2 * kernel + 1 if i % 3 == 0 else kernel + cv2.GaussianBlur(heatmap, (kt, kt), 0, heatmap) + # valid radius is in direct proportion to the height of heatmap. + valid_radius = valid_radius_factor * H + offset_x = heatmaps[:, 1::3, :].flatten() * valid_radius + offset_y = heatmaps[:, 2::3, :].flatten() * valid_radius + heatmaps = heatmaps[:, ::3, :] + preds, maxvals = _get_max_preds(heatmaps) + index = preds[..., 0] + preds[..., 1] * W + index += W * H * np.arange(0, N * K / 3) + index = index.astype(int).reshape(N, K // 3, 1) + preds += np.concatenate((offset_x[index], offset_y[index]), axis=2) + else: + raise ValueError( + "target_type should be either " "'GaussianHeatmap' or 'CombinedTarget'" + ) + else: + preds, maxvals = _get_max_preds(heatmaps) + if post_process == "unbiased": # alleviate biased coordinate + # apply Gaussian distribution modulation. + heatmaps = np.log(np.maximum(_gaussian_blur(heatmaps, kernel), 1e-10)) + for n in range(N): + for k in range(K): + preds[n][k] = _taylor(heatmaps[n][k], preds[n][k]) + elif post_process is not None: + # add +/-0.25 shift to the predicted locations for higher acc. + for n in range(N): + for k in range(K): + heatmap = heatmaps[n][k] + px = int(preds[n][k][0]) + py = int(preds[n][k][1]) + if 1 < px < W - 1 and 1 < py < H - 1: + diff = np.array( + [ + heatmap[py][px + 1] - heatmap[py][px - 1], + heatmap[py + 1][px] - heatmap[py - 1][px], + ] + ) + preds[n][k] += np.sign(diff) * 0.25 + if post_process == "megvii": + preds[n][k] += 0.5 + + # Transform back to the image + for i in range(N): + preds[i] = transform_preds( + preds[i], center[i], scale[i], [W, H], use_udp=use_udp + ) + + if post_process == "megvii": + maxvals = maxvals / 255.0 + 0.5 + + return preds, maxvals + + +def keypoints_from_heatmaps3d(heatmaps, center, scale): + """Get final keypoint predictions from 3d heatmaps and transform them back + to the image. + + Note: + - batch size: N + - num keypoints: K + - heatmap depth size: D + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps. + center (np.ndarray[N, 2]): Center of the bounding box (x, y). + scale (np.ndarray[N, 2]): Scale of the bounding box + wrt height/width. + + Returns: + tuple: A tuple containing keypoint predictions and scores. + + - preds (np.ndarray[N, K, 3]): Predicted 3d keypoint location \ + in images. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + N, K, D, H, W = heatmaps.shape + preds, maxvals = _get_max_preds_3d(heatmaps) + # Transform back to the image + for i in range(N): + preds[i, :, :2] = transform_preds(preds[i, :, :2], center[i], scale[i], [W, H]) + return preds, maxvals + + +def multilabel_classification_accuracy(pred, gt, mask, thr=0.5): + """Get multi-label classification accuracy. + + Note: + - batch size: N + - label number: L + + Args: + pred (np.ndarray[N, L, 2]): model predicted labels. + gt (np.ndarray[N, L, 2]): ground-truth labels. + mask (np.ndarray[N, 1] or np.ndarray[N, L] ): reliability of + ground-truth labels. + + Returns: + float: multi-label classification accuracy. + """ + # we only compute accuracy on the samples with ground-truth of all labels. + valid = (mask > 0).min(axis=1) if mask.ndim == 2 else (mask > 0) + pred, gt = pred[valid], gt[valid] + + if pred.shape[0] == 0: + acc = 0.0 # when no sample is with gt labels, set acc to 0. + else: + # The classification of a sample is regarded as correct + # only if it's correct for all labels. + acc = (((pred - thr) * (gt - thr)) > 0).all(axis=1).mean() + return acc + + +def get_transform(center, scale, res, rot=0): + """Generate transformation matrix.""" + # res: (height, width), (rows, cols) + crop_aspect_ratio = res[0] / float(res[1]) + h = 200 * scale + w = h / crop_aspect_ratio + t = np.zeros((3, 3)) + t[0, 0] = float(res[1]) / w + t[1, 1] = float(res[0]) / h + t[0, 2] = res[1] * (-float(center[0]) / w + 0.5) + t[1, 2] = res[0] * (-float(center[1]) / h + 0.5) + t[2, 2] = 1 + if not rot == 0: + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3, 3)) + rot_rad = rot * np.pi / 180 + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0, :2] = [cs, -sn] + rot_mat[1, :2] = [sn, cs] + rot_mat[2, 2] = 1 + # Need to rotate around center + t_mat = np.eye(3) + t_mat[0, 2] = -res[1] / 2 + t_mat[1, 2] = -res[0] / 2 + t_inv = t_mat.copy() + t_inv[:2, 2] *= -1 + t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t))) + return t + + +def transform(pt, center, scale, res, invert=0, rot=0): + """Transform pixel location to different reference.""" + t = get_transform(center, scale, res, rot=rot) + if invert: + t = np.linalg.inv(t) + new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T + new_pt = np.dot(t, new_pt) + return np.array([round(new_pt[0]), round(new_pt[1])], dtype=int) + 1 + + +def bbox_from_detector(bbox, input_resolution=(224, 224), rescale=1.25): + """ + Get center and scale of bounding box from bounding box. + The expected format is [min_x, min_y, max_x, max_y]. + """ + CROP_IMG_HEIGHT, CROP_IMG_WIDTH = input_resolution + CROP_ASPECT_RATIO = CROP_IMG_HEIGHT / float(CROP_IMG_WIDTH) + + # center + center_x = (bbox[0] + bbox[2]) / 2.0 + center_y = (bbox[1] + bbox[3]) / 2.0 + center = np.array([center_x, center_y]) + + # scale + bbox_w = bbox[2] - bbox[0] + bbox_h = bbox[3] - bbox[1] + bbox_size = max(bbox_w * CROP_ASPECT_RATIO, bbox_h) + + scale = np.array([bbox_size / CROP_ASPECT_RATIO, bbox_size]) / 200.0 + # scale = bbox_size / 200.0 + # adjust bounding box tightness + scale *= rescale + return center, scale + + +def crop(img, center, scale, res): + """ + Crop image according to the supplied bounding box. + res: [rows, cols] + """ + # Upper left point + ul = np.array(transform([1, 1], center, max(scale), res, invert=1)) - 1 + # Bottom right point + br = ( + np.array(transform([res[1] + 1, res[0] + 1], center, max(scale), res, invert=1)) + - 1 + ) + + # Padding so that when rotated proper amount of context is included + pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(new_shape, dtype=np.float32) + + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] + new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(len(img[0]), br[0]) + old_y = max(0, ul[1]), min(len(img), br[1]) + try: + new_img[new_y[0] : new_y[1], new_x[0] : new_x[1]] = img[ + old_y[0] : old_y[1], old_x[0] : old_x[1] + ] + except Exception as e: + print(e) + + new_img = cv2.resize(new_img, (res[1], res[0])) # (cols, rows) + return new_img, new_shape, (old_x, old_y), (new_x, new_y) # , ul, br + + +def split_kp2ds_for_aa(kp2ds, ret_face=False): + kp2ds_body = ( + kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]] + ) / 2 + kp2ds_lhand = kp2ds[91:112] + kp2ds_rhand = kp2ds[112:133] + kp2ds_face = kp2ds[22:91] + if ret_face: + return ( + kp2ds_body.copy(), + kp2ds_lhand.copy(), + kp2ds_rhand.copy(), + kp2ds_face.copy(), + ) + return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy() + + +def load_pose_metas_from_kp2ds_seq_list(kp2ds_seq, width, height): + metas = [] + for kps in kp2ds_seq: + if len(kps) != 1: + return None + kps = kps[0].copy() + kps[:, 0] /= width + kps[:, 1] /= height + kp2ds_body, kp2ds_lhand, kp2ds_rhand, kp2ds_face = split_kp2ds_for_aa( + kps, ret_face=True + ) + + if kp2ds_body[:, :2].min(axis=1).max() < 0: + kp2ds_body = last_kp2ds_body + last_kp2ds_body = kp2ds_body + + meta = { + "width": width, + "height": height, + "keypoints_body": kp2ds_body.tolist(), + "keypoints_left_hand": kp2ds_lhand.tolist(), + "keypoints_right_hand": kp2ds_rhand.tolist(), + "keypoints_face": kp2ds_face.tolist(), + } + metas.append(meta) + return metas + + +def load_pose_metas_from_kp2ds_seq(kp2ds_seq, width, height): + metas = [] + for kps in kp2ds_seq: + kps = kps.copy() + kps[:, 0] /= width + kps[:, 1] /= height + kp2ds_body, kp2ds_lhand, kp2ds_rhand, kp2ds_face = split_kp2ds_for_aa( + kps, ret_face=True + ) + + # 排除全部小于0的情况 + if kp2ds_body[:, :2].min(axis=1).max() < 0: + kp2ds_body = last_kp2ds_body + last_kp2ds_body = kp2ds_body + + meta = { + "width": width, + "height": height, + "keypoints_body": kp2ds_body, + "keypoints_left_hand": kp2ds_lhand, + "keypoints_right_hand": kp2ds_rhand, + "keypoints_face": kp2ds_face, + } + metas.append(meta) + return metas diff --git a/wan/modules/animate/preprocess/preprocess_data.py b/wan/modules/animate/preprocess/preprocess_data.py new file mode 100644 index 00000000..f077d5c8 --- /dev/null +++ b/wan/modules/animate/preprocess/preprocess_data.py @@ -0,0 +1,140 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import os +import argparse +from .process_pipepline import ProcessPipeline + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="The preprocessing pipeline for Wan-animate." + ) + + parser.add_argument( + "--ckpt_path", + type=str, + default=None, + help="The path to the preprocessing model's checkpoint directory. ", + ) + + parser.add_argument( + "--video_path", type=str, default=None, help="The path to the driving video." + ) + parser.add_argument( + "--refer_path", + type=str, + default=None, + help="The path to the refererence image.", + ) + parser.add_argument( + "--save_path", + type=str, + default=None, + help="The path to save the processed results.", + ) + + parser.add_argument( + "--resolution_area", + type=int, + nargs=2, + default=[1280, 720], + help="The target resolution for processing, specified as [width, height]. To handle different aspect ratios, the video is resized to have a total area equivalent to width * height, while preserving the original aspect ratio.", + ) + parser.add_argument( + "--fps", + type=int, + default=30, + help="The target FPS for processing the driving video. Set to -1 to use the video's original FPS.", + ) + + parser.add_argument( + "--replace_flag", + action="store_true", + default=False, + help="Whether to use replacement mode.", + ) + parser.add_argument( + "--retarget_flag", + action="store_true", + default=False, + help="Whether to use pose retargeting. Currently only supported in animation mode", + ) + parser.add_argument( + "--use_flux", + action="store_true", + default=False, + help="Whether to use image editing in pose retargeting. Recommended if the character in the reference image or the first frame of the driving video is not in a standard, front-facing pose", + ) + + # Parameters for the mask strategy in replacement mode. These control the mask's size and shape. Refer to https://arxiv.org/pdf/2502.06145 + parser.add_argument( + "--iterations", + type=int, + default=3, + help="Number of iterations for mask dilation.", + ) + parser.add_argument( + "--k", type=int, default=7, help="Number of kernel size for mask dilation." + ) + parser.add_argument( + "--w_len", + type=int, + default=1, + help="The number of subdivisions for the grid along the 'w' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed.", + ) + parser.add_argument( + "--h_len", + type=int, + default=1, + help="The number of subdivisions for the grid along the 'h' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed.", + ) + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = _parse_args() + args_dict = vars(args) + print(args_dict) + + assert ( + len(args.resolution_area) == 2 + ), "resolution_area should be a list of two integers [width, height]" + assert ( + not args.use_flux or args.retarget_flag + ), "Image editing with FLUX can only be used when pose retargeting is enabled." + + pose2d_checkpoint_path = os.path.join( + args.ckpt_path, "pose2d/vitpose_h_wholebody.onnx" + ) + det_checkpoint_path = os.path.join(args.ckpt_path, "det/yolov10m.onnx") + + sam2_checkpoint_path = ( + os.path.join(args.ckpt_path, "sam2/sam2_hiera_large.pt") + if args.replace_flag + else None + ) + flux_kontext_path = ( + os.path.join(args.ckpt_path, "FLUX.1-Kontext-dev") if args.use_flux else None + ) + process_pipeline = ProcessPipeline( + det_checkpoint_path=det_checkpoint_path, + pose2d_checkpoint_path=pose2d_checkpoint_path, + sam_checkpoint_path=sam2_checkpoint_path, + flux_kontext_path=flux_kontext_path, + ) + os.makedirs(args.save_path, exist_ok=True) + process_pipeline( + video_path=args.video_path, + refer_image_path=args.refer_path, + output_path=args.save_path, + resolution_area=args.resolution_area, + fps=args.fps, + iterations=args.iterations, + k=args.k, + w_len=args.w_len, + h_len=args.h_len, + retarget_flag=args.retarget_flag, + use_flux=args.use_flux, + replace_flag=args.replace_flag, + ) diff --git a/wan/modules/animate/preprocess/process_pipepline.py b/wan/modules/animate/preprocess/process_pipepline.py new file mode 100644 index 00000000..1cca4c9d --- /dev/null +++ b/wan/modules/animate/preprocess/process_pipepline.py @@ -0,0 +1,486 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import os +import numpy as np +import shutil +import torch +from diffusers import FluxKontextPipeline +import cv2 +from loguru import logger +from PIL import Image + +try: + import moviepy.editor as mpy +except ImportError: # pragma: no cover - fallback path + import moviepy as mpy + +from decord import VideoReader + +from .pose2d import Pose2d +from .pose2d_utils import AAPoseMeta +from .utils import ( + get_aug_mask, + get_face_bboxes, + get_frame_indices, + get_mask_body_img, + padding_resize, + resize_by_area, +) +from .human_visualization import draw_aapose_by_meta_new +from .retarget_pose import get_retarget_pose +import sam2.modeling.sam.transformer as transformer + +transformer.USE_FLASH_ATTN = False +transformer.MATH_KERNEL_ON = True +transformer.OLD_GPU = True +from .sam_utils import build_sam2_video_predictor + + +class ProcessPipeline: + def __init__( + self, + det_checkpoint_path, + pose2d_checkpoint_path, + sam_checkpoint_path, + flux_kontext_path, + ): + self.pose2d = Pose2d( + checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path + ) + + model_cfg = "sam2_hiera_l.yaml" + if sam_checkpoint_path is not None: + self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path) + if flux_kontext_path is not None: + self.flux_kontext = FluxKontextPipeline.from_pretrained( + flux_kontext_path, torch_dtype=torch.bfloat16 + ).to("cuda") + + def __call__( + self, + video_path, + refer_image_path, + output_path, + resolution_area=[1280, 720], + fps=30, + iterations=3, + k=7, + w_len=1, + h_len=1, + retarget_flag=False, + use_flux=False, + replace_flag=False, + ): + if replace_flag: + + video_reader = VideoReader(video_path) + frame_num = len(video_reader) + print("frame_num: {}".format(frame_num)) + + video_fps = video_reader.get_avg_fps() + print("video_fps: {}".format(video_fps)) + print("fps: {}".format(fps)) + + # TODO: Maybe we can switch to PyAV later, which can get accurate frame num + duration = video_reader.get_frame_timestamp(-1)[-1] + expected_frame_num = int(duration * video_fps + 0.5) + ratio = abs((frame_num - expected_frame_num) / frame_num) + if ratio > 0.1: + print( + "Warning: The difference between the actual number of frames and the expected number of frames is two large" + ) + frame_num = expected_frame_num + + if fps == -1: + fps = video_fps + + target_num = int(frame_num / video_fps * fps) + print("target_num: {}".format(target_num)) + idxs = get_frame_indices(frame_num, video_fps, target_num, fps) + frames = video_reader.get_batch(idxs).asnumpy() + + frames = [ + resize_by_area( + frame, resolution_area[0] * resolution_area[1], divisor=16 + ) + for frame in frames + ] + height, width = frames[0].shape[:2] + logger.info(f"Processing pose meta") + + tpl_pose_metas = self.pose2d(frames) + + face_images = [] + for idx, meta in enumerate(tpl_pose_metas): + face_bbox_for_image = get_face_bboxes( + meta["keypoints_face"][:, :2], + scale=1.3, + image_shape=(frames[0].shape[0], frames[0].shape[1]), + ) + + x1, x2, y1, y2 = face_bbox_for_image + face_image = frames[idx][y1:y2, x1:x2] + face_image = cv2.resize(face_image, (512, 512)) + face_images.append(face_image) + + logger.info(f"Processing reference image: {refer_image_path}") + refer_img = cv2.imread(refer_image_path) + src_ref_path = os.path.join(output_path, "src_ref.png") + shutil.copy(refer_image_path, src_ref_path) + refer_img = refer_img[..., ::-1] + + refer_img = padding_resize(refer_img, height, width) + logger.info(f"Processing template video: {video_path}") + tpl_retarget_pose_metas = [ + AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas + ] + cond_images = [] + + for idx, meta in enumerate(tpl_retarget_pose_metas): + canvas = np.zeros_like(refer_img) + conditioning_image = draw_aapose_by_meta_new(canvas, meta) + cond_images.append(conditioning_image) + masks = self.get_mask(frames, 400, tpl_pose_metas) + + bg_images = [] + aug_masks = [] + + for frame, mask in zip(frames, masks): + if iterations > 0: + _, each_mask = get_mask_body_img( + frame, mask, iterations=iterations, k=k + ) + each_aug_mask = get_aug_mask(each_mask, w_len=w_len, h_len=h_len) + else: + each_aug_mask = mask + + each_bg_image = frame * (1 - each_aug_mask[:, :, None]) + bg_images.append(each_bg_image) + aug_masks.append(each_aug_mask) + + src_face_path = os.path.join(output_path, "src_face.mp4") + mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path) + + src_pose_path = os.path.join(output_path, "src_pose.mp4") + mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path) + + src_bg_path = os.path.join(output_path, "src_bg.mp4") + mpy.ImageSequenceClip(bg_images, fps=fps).write_videofile(src_bg_path) + + aug_masks_new = [ + np.stack([mask * 255, mask * 255, mask * 255], axis=2) + for mask in aug_masks + ] + src_mask_path = os.path.join(output_path, "src_mask.mp4") + mpy.ImageSequenceClip(aug_masks_new, fps=fps).write_videofile(src_mask_path) + return True + else: + logger.info(f"Processing reference image: {refer_image_path}") + refer_img = cv2.imread(refer_image_path) + src_ref_path = os.path.join(output_path, "src_ref.png") + shutil.copy(refer_image_path, src_ref_path) + refer_img = refer_img[..., ::-1] + + refer_img = resize_by_area( + refer_img, resolution_area[0] * resolution_area[1], divisor=16 + ) + + refer_pose_meta = self.pose2d([refer_img])[0] + + logger.info(f"Processing template video: {video_path}") + video_reader = VideoReader(video_path) + frame_num = len(video_reader) + print("frame_num: {}".format(frame_num)) + + video_fps = video_reader.get_avg_fps() + print("video_fps: {}".format(video_fps)) + print("fps: {}".format(fps)) + + # TODO: Maybe we can switch to PyAV later, which can get accurate frame num + duration = video_reader.get_frame_timestamp(-1)[-1] + expected_frame_num = int(duration * video_fps + 0.5) + ratio = abs((frame_num - expected_frame_num) / frame_num) + if ratio > 0.1: + print( + "Warning: The difference between the actual number of frames and the expected number of frames is two large" + ) + frame_num = expected_frame_num + + if fps == -1: + fps = video_fps + + target_num = int(frame_num / video_fps * fps) + print("target_num: {}".format(target_num)) + idxs = get_frame_indices(frame_num, video_fps, target_num, fps) + frames = video_reader.get_batch(idxs).asnumpy() + + logger.info(f"Processing pose meta") + + tpl_pose_meta0 = self.pose2d(frames[:1])[0] + tpl_pose_metas = self.pose2d(frames) + + face_images = [] + for idx, meta in enumerate(tpl_pose_metas): + face_bbox_for_image = get_face_bboxes( + meta["keypoints_face"][:, :2], + scale=1.3, + image_shape=(frames[0].shape[0], frames[0].shape[1]), + ) + + x1, x2, y1, y2 = face_bbox_for_image + face_image = frames[idx][y1:y2, x1:x2] + face_image = cv2.resize(face_image, (512, 512)) + face_images.append(face_image) + + if retarget_flag: + if use_flux: + tpl_prompt, refer_prompt = self.get_editing_prompts( + tpl_pose_metas, refer_pose_meta + ) + refer_input = Image.fromarray(refer_img) + refer_edit = self.flux_kontext( + image=refer_input, + height=refer_img.shape[0], + width=refer_img.shape[1], + prompt=refer_prompt, + guidance_scale=2.5, + num_inference_steps=28, + ).images[0] + + refer_edit = Image.fromarray( + padding_resize( + np.array(refer_edit), refer_img.shape[0], refer_img.shape[1] + ) + ) + refer_edit_path = os.path.join(output_path, "refer_edit.png") + refer_edit.save(refer_edit_path) + refer_edit_pose_meta = self.pose2d([np.array(refer_edit)])[0] + + tpl_img = frames[1] + tpl_input = Image.fromarray(tpl_img) + + tpl_edit = self.flux_kontext( + image=tpl_input, + height=tpl_img.shape[0], + width=tpl_img.shape[1], + prompt=tpl_prompt, + guidance_scale=2.5, + num_inference_steps=28, + ).images[0] + + tpl_edit = Image.fromarray( + padding_resize( + np.array(tpl_edit), tpl_img.shape[0], tpl_img.shape[1] + ) + ) + tpl_edit_path = os.path.join(output_path, "tpl_edit.png") + tpl_edit.save(tpl_edit_path) + tpl_edit_pose_meta0 = self.pose2d([np.array(tpl_edit)])[0] + tpl_retarget_pose_metas = get_retarget_pose( + tpl_pose_meta0, + refer_pose_meta, + tpl_pose_metas, + tpl_edit_pose_meta0, + refer_edit_pose_meta, + ) + else: + tpl_retarget_pose_metas = get_retarget_pose( + tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, None, None + ) + else: + tpl_retarget_pose_metas = [ + AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas + ] + + cond_images = [] + for idx, meta in enumerate(tpl_retarget_pose_metas): + if retarget_flag: + canvas = np.zeros_like(refer_img) + conditioning_image = draw_aapose_by_meta_new(canvas, meta) + else: + canvas = np.zeros_like(frames[0]) + conditioning_image = draw_aapose_by_meta_new(canvas, meta) + conditioning_image = padding_resize( + conditioning_image, refer_img.shape[0], refer_img.shape[1] + ) + + cond_images.append(conditioning_image) + + src_face_path = os.path.join(output_path, "src_face.mp4") + mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path) + + src_pose_path = os.path.join(output_path, "src_pose.mp4") + mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path) + return True + + def get_editing_prompts(self, tpl_pose_metas, refer_pose_meta): + arm_visible = False + leg_visible = False + for tpl_pose_meta in tpl_pose_metas: + tpl_keypoints = tpl_pose_meta["keypoints_body"] + if ( + tpl_keypoints[3].all() != 0 + or tpl_keypoints[4].all() != 0 + or tpl_keypoints[6].all() != 0 + or tpl_keypoints[7].all() != 0 + ): + if ( + ( + tpl_keypoints[3][0] <= 1 + and tpl_keypoints[3][1] <= 1 + and tpl_keypoints[3][2] >= 0.75 + ) + or ( + tpl_keypoints[4][0] <= 1 + and tpl_keypoints[4][1] <= 1 + and tpl_keypoints[4][2] >= 0.75 + ) + or ( + tpl_keypoints[6][0] <= 1 + and tpl_keypoints[6][1] <= 1 + and tpl_keypoints[6][2] >= 0.75 + ) + or ( + tpl_keypoints[7][0] <= 1 + and tpl_keypoints[7][1] <= 1 + and tpl_keypoints[7][2] >= 0.75 + ) + ): + arm_visible = True + if ( + tpl_keypoints[9].all() != 0 + or tpl_keypoints[12].all() != 0 + or tpl_keypoints[10].all() != 0 + or tpl_keypoints[13].all() != 0 + ): + if ( + ( + tpl_keypoints[9][0] <= 1 + and tpl_keypoints[9][1] <= 1 + and tpl_keypoints[9][2] >= 0.75 + ) + or ( + tpl_keypoints[12][0] <= 1 + and tpl_keypoints[12][1] <= 1 + and tpl_keypoints[12][2] >= 0.75 + ) + or ( + tpl_keypoints[10][0] <= 1 + and tpl_keypoints[10][1] <= 1 + and tpl_keypoints[10][2] >= 0.75 + ) + or ( + tpl_keypoints[13][0] <= 1 + and tpl_keypoints[13][1] <= 1 + and tpl_keypoints[13][2] >= 0.75 + ) + ): + leg_visible = True + if arm_visible and leg_visible: + break + + if leg_visible: + if tpl_pose_meta["width"] > tpl_pose_meta["height"]: + tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image." + else: + tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image." + + if refer_pose_meta["width"] > refer_pose_meta["height"]: + refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image." + else: + refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image." + elif arm_visible: + if tpl_pose_meta["width"] > tpl_pose_meta["height"]: + tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image." + else: + tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image." + + if refer_pose_meta["width"] > refer_pose_meta["height"]: + refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image." + else: + refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image." + else: + tpl_prompt = "Change the person to face forward." + refer_prompt = "Change the person to face forward." + + return tpl_prompt, refer_prompt + + def get_mask(self, frames, th_step, kp2ds_all): + frame_num = len(frames) + if frame_num < th_step: + num_step = 1 + else: + num_step = (frame_num + th_step) // th_step + + all_mask = [] + for index in range(num_step): + each_frames = frames[index * th_step : (index + 1) * th_step] + + kp2ds = kp2ds_all[index * th_step : (index + 1) * th_step] + if len(each_frames) > 4: + key_frame_num = 4 + elif 4 >= len(each_frames) > 0: + key_frame_num = 1 + else: + continue + + key_frame_step = len(kp2ds) // key_frame_num + key_frame_index_list = list(range(0, len(kp2ds), key_frame_step)) + + key_points_index = [0, 1, 2, 5, 8, 11, 10, 13] + key_frame_body_points_list = [] + for key_frame_index in key_frame_index_list: + keypoints_body_list = [] + body_key_points = kp2ds[key_frame_index]["keypoints_body"] + for each_index in key_points_index: + each_keypoint = body_key_points[each_index] + if None is each_keypoint: + continue + keypoints_body_list.append(each_keypoint) + + keypoints_body = np.array(keypoints_body_list)[:, :2] + wh = np.array([[kp2ds[0]["width"], kp2ds[0]["height"]]]) + points = (keypoints_body * wh).astype(np.int32) + key_frame_body_points_list.append(points) + + inference_state = self.predictor.init_state_v2(frames=each_frames) + self.predictor.reset_state(inference_state) + ann_obj_id = 1 + for ann_frame_idx, points in zip( + key_frame_index_list, key_frame_body_points_list + ): + labels = np.array([1] * points.shape[0], np.int32) + _, out_obj_ids, out_mask_logits = self.predictor.add_new_points( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=ann_obj_id, + points=points, + labels=labels, + ) + + video_segments = {} + for ( + out_frame_idx, + out_obj_ids, + out_mask_logits, + ) in self.predictor.propagate_in_video(inference_state): + video_segments[out_frame_idx] = { + out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() + for i, out_obj_id in enumerate(out_obj_ids) + } + + for out_frame_idx in range(len(video_segments)): + for out_obj_id, out_mask in video_segments[out_frame_idx].items(): + out_mask = out_mask[0].astype(np.uint8) + all_mask.append(out_mask) + + return all_mask + + def convert_list_to_array(self, metas): + metas_list = [] + for meta in metas: + for key, value in meta.items(): + if type(value) is list: + value = np.array(value) + meta[key] = value + metas_list.append(meta) + return metas_list diff --git a/wan/modules/animate/preprocess/retarget_pose.py b/wan/modules/animate/preprocess/retarget_pose.py new file mode 100644 index 00000000..0589b57e --- /dev/null +++ b/wan/modules/animate/preprocess/retarget_pose.py @@ -0,0 +1,1166 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import os +import cv2 +import numpy as np +import json +from tqdm import tqdm +import math +from typing import NamedTuple, List +import copy +from .pose2d_utils import AAPoseMeta + + +# load skeleton name and bone lines +keypoint_list = [ + "Nose", + "Neck", + "RShoulder", + "RElbow", + "RWrist", # No.4 + "LShoulder", + "LElbow", + "LWrist", # No.7 + "RHip", + "RKnee", + "RAnkle", # No.10 + "LHip", + "LKnee", + "LAnkle", # No.13 + "REye", + "LEye", + "REar", + "LEar", + "LToe", + "RToe", +] + + +limbSeq = [ + [2, 3], + [2, 6], # shoulders + [3, 4], + [4, 5], # left arm + [6, 7], + [7, 8], # right arm + [2, 9], + [9, 10], + [10, 11], # right leg + [2, 12], + [12, 13], + [13, 14], # left leg + [2, 1], + [1, 15], + [15, 17], + [1, 16], + [16, 18], # face (nose, eyes, ears) + [14, 19], # left foot + [11, 20], # right foot +] + +eps = 0.01 + + +class Keypoint(NamedTuple): + x: float + y: float + score: float = 1.0 + id: int = -1 + + +# for each limb, calculate src & dst bone's length +# and calculate their ratios +def get_length(skeleton, limb): + + k1_index, k2_index = limb + + H, W = skeleton["height"], skeleton["width"] + keypoints = skeleton["keypoints_body"] + keypoint1 = keypoints[k1_index - 1] + keypoint2 = keypoints[k2_index - 1] + + if keypoint1 is None or keypoint2 is None: + return None, None, None + + X = np.array([keypoint1[0], keypoint2[0]]) * float(W) + Y = np.array([keypoint1[1], keypoint2[1]]) * float(H) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + + return X, Y, length + + +def get_handpose_meta(keypoints, delta, src_H, src_W): + + new_keypoints = [] + + for idx, keypoint in enumerate(keypoints): + if keypoint is None: + new_keypoints.append(None) + continue + if keypoint.score == 0: + new_keypoints.append(None) + continue + + x, y = keypoint.x, keypoint.y + x = int(x * src_W + delta[0]) + y = int(y * src_H + delta[1]) + + new_keypoints.append( + Keypoint( + x=x, + y=y, + score=keypoint.score, + ) + ) + + return new_keypoints + + +def deal_hand_keypoints(hand_res, r_ratio, l_ratio, hand_score_th=0.5): + + left_hand = [] + right_hand = [] + + left_delta_x = hand_res["left"][0][0] * (l_ratio - 1) + left_delta_y = hand_res["left"][0][1] * (l_ratio - 1) + + right_delta_x = hand_res["right"][0][0] * (r_ratio - 1) + right_delta_y = hand_res["right"][0][1] * (r_ratio - 1) + + length = len(hand_res["left"]) + + for i in range(length): + # left hand + if hand_res["left"][i][2] < hand_score_th: + left_hand.append( + Keypoint( + x=-1, + y=-1, + score=0, + ) + ) + else: + left_hand.append( + Keypoint( + x=hand_res["left"][i][0] * l_ratio - left_delta_x, + y=hand_res["left"][i][1] * l_ratio - left_delta_y, + score=hand_res["left"][i][2], + ) + ) + + # right hand + if hand_res["right"][i][2] < hand_score_th: + right_hand.append( + Keypoint( + x=-1, + y=-1, + score=0, + ) + ) + else: + right_hand.append( + Keypoint( + x=hand_res["right"][i][0] * r_ratio - right_delta_x, + y=hand_res["right"][i][1] * r_ratio - right_delta_y, + score=hand_res["right"][i][2], + ) + ) + + return right_hand, left_hand + + +def get_scaled_pose( + canvas, + src_canvas, + keypoints, + keypoints_hand, + bone_ratio_list, + delta_ground_x, + delta_ground_y, + rescaled_src_ground_x, + body_flag, + id, + scale_min, + threshold=0.4, +): + + H, W = canvas + src_H, src_W = src_canvas + + new_length_list = [] + angle_list = [] + + # keypoints from 0-1 to H/W range + for idx in range(len(keypoints)): + if keypoints[idx] is None or len(keypoints[idx]) == 0: + continue + + keypoints[idx] = [ + keypoints[idx][0] * src_W, + keypoints[idx][1] * src_H, + keypoints[idx][2], + ] + + # first traverse, get new_length_list and angle_list + for idx, (k1_index, k2_index) in enumerate(limbSeq): + keypoint1 = keypoints[k1_index - 1] + keypoint2 = keypoints[k2_index - 1] + + if ( + keypoint1 is None + or keypoint2 is None + or len(keypoint1) == 0 + or len(keypoint2) == 0 + ): + new_length_list.append(None) + angle_list.append(None) + continue + + Y = np.array([keypoint1[0], keypoint2[0]]) # * float(W) + X = np.array([keypoint1[1], keypoint2[1]]) # * float(H) + + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + + new_length = length * bone_ratio_list[idx] + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + + new_length_list.append(new_length) + angle_list.append(angle) + + # Keep foot length within 0.5x calf length + foot_lower_leg_ratio = 0.5 + if new_length_list[8] != None and new_length_list[18] != None: + if new_length_list[18] > new_length_list[8] * foot_lower_leg_ratio: + new_length_list[18] = new_length_list[8] * foot_lower_leg_ratio + + if new_length_list[11] != None and new_length_list[17] != None: + if new_length_list[17] > new_length_list[11] * foot_lower_leg_ratio: + new_length_list[17] = new_length_list[11] * foot_lower_leg_ratio + + # second traverse, calculate new keypoints + rescale_keypoints = keypoints.copy() + + for idx, (k1_index, k2_index) in enumerate(limbSeq): + # update dst_keypoints + start_keypoint = rescale_keypoints[k1_index - 1] + new_length = new_length_list[idx] + angle = angle_list[idx] + + if ( + rescale_keypoints[k1_index - 1] is None + or rescale_keypoints[k2_index - 1] is None + or len(rescale_keypoints[k1_index - 1]) == 0 + or len(rescale_keypoints[k2_index - 1]) == 0 + ): + continue + + # calculate end_keypoint + delta_x = new_length * math.cos(math.radians(angle)) + delta_y = new_length * math.sin(math.radians(angle)) + + end_keypoint_x = start_keypoint[0] - delta_x + end_keypoint_y = start_keypoint[1] - delta_y + + # update keypoints + rescale_keypoints[k2_index - 1] = [ + end_keypoint_x, + end_keypoint_y, + rescale_keypoints[k2_index - 1][2], + ] + + if id == 0: + if ( + body_flag == "full_body" + and rescale_keypoints[8] != None + and rescale_keypoints[11] != None + ): + delta_ground_x_offset_first_frame = ( + rescale_keypoints[8][0] + rescale_keypoints[11][0] + ) / 2 - rescaled_src_ground_x + delta_ground_x += delta_ground_x_offset_first_frame + elif body_flag == "half_body" and rescale_keypoints[1] != None: + delta_ground_x_offset_first_frame = ( + rescale_keypoints[1][0] - rescaled_src_ground_x + ) + delta_ground_x += delta_ground_x_offset_first_frame + + # offset all keypoints + for idx in range(len(rescale_keypoints)): + if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0: + continue + rescale_keypoints[idx][0] -= delta_ground_x + rescale_keypoints[idx][1] -= delta_ground_y + + # rescale keypoints to original size + rescale_keypoints[idx][0] /= scale_min + rescale_keypoints[idx][1] /= scale_min + + # Scale hand proportions based on body skeletal ratios + r_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min + l_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min + left_hand, right_hand = deal_hand_keypoints( + keypoints_hand, r_ratio, l_ratio, hand_score_th=threshold + ) + + left_hand_new = left_hand.copy() + right_hand_new = right_hand.copy() + + if rescale_keypoints[4] == None and rescale_keypoints[7] == None: + pass + + elif rescale_keypoints[4] == None and rescale_keypoints[7] != None: + right_hand_delta = np.array(rescale_keypoints[7][:2]) - np.array( + keypoints[7][:2] + ) + right_hand_new = get_handpose_meta(right_hand, right_hand_delta, src_H, src_W) + + elif rescale_keypoints[4] != None and rescale_keypoints[7] == None: + left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array( + keypoints[4][:2] + ) + left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W) + + else: + # get left_hand and right_hand offset + left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array( + keypoints[4][:2] + ) + right_hand_delta = np.array(rescale_keypoints[7][:2]) - np.array( + keypoints[7][:2] + ) + + if keypoints[4][0] != None and left_hand[0].x != -1: + left_hand_root_offset = np.array( + ( + keypoints[4][0] - left_hand[0].x * src_W, + keypoints[4][1] - left_hand[0].y * src_H, + ) + ) + left_hand_delta += left_hand_root_offset + + if keypoints[7][0] != None and right_hand[0].x != -1: + right_hand_root_offset = np.array( + ( + keypoints[7][0] - right_hand[0].x * src_W, + keypoints[7][1] - right_hand[0].y * src_H, + ) + ) + right_hand_delta += right_hand_root_offset + + dis_left_hand = ( + (keypoints[4][0] - left_hand[0].x * src_W) ** 2 + + (keypoints[4][1] - left_hand[0].y * src_H) ** 2 + ) ** 0.5 + dis_right_hand = ( + (keypoints[7][0] - left_hand[0].x * src_W) ** 2 + + (keypoints[7][1] - left_hand[0].y * src_H) ** 2 + ) ** 0.5 + + if dis_left_hand > dis_right_hand: + right_hand_new = get_handpose_meta( + left_hand, right_hand_delta, src_H, src_W + ) + left_hand_new = get_handpose_meta(right_hand, left_hand_delta, src_H, src_W) + else: + left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W) + right_hand_new = get_handpose_meta( + right_hand, right_hand_delta, src_H, src_W + ) + + # get normalized keypoints_body + norm_body_keypoints = [] + for body_keypoint in rescale_keypoints: + if body_keypoint != None: + norm_body_keypoints.append( + [body_keypoint[0] / W, body_keypoint[1] / H, body_keypoint[2]] + ) + else: + norm_body_keypoints.append(None) + + frame_info = { + "height": H, + "width": W, + "keypoints_body": norm_body_keypoints, + "keypoints_left_hand": left_hand_new, + "keypoints_right_hand": right_hand_new, + } + + return frame_info + + +def rescale_skeleton(H, W, keypoints, bone_ratio_list): + + rescale_keypoints = keypoints.copy() + + new_length_list = [] + angle_list = [] + + # keypoints from 0-1 to H/W range + for idx in range(len(rescale_keypoints)): + if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0: + continue + + rescale_keypoints[idx] = [ + rescale_keypoints[idx][0] * W, + rescale_keypoints[idx][1] * H, + ] + + # first traverse, get new_length_list and angle_list + for idx, (k1_index, k2_index) in enumerate(limbSeq): + keypoint1 = rescale_keypoints[k1_index - 1] + keypoint2 = rescale_keypoints[k2_index - 1] + + if ( + keypoint1 is None + or keypoint2 is None + or len(keypoint1) == 0 + or len(keypoint2) == 0 + ): + new_length_list.append(None) + angle_list.append(None) + continue + + Y = np.array([keypoint1[0], keypoint2[0]]) # * float(W) + X = np.array([keypoint1[1], keypoint2[1]]) # * float(H) + + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + + new_length = length * bone_ratio_list[idx] + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + + new_length_list.append(new_length) + angle_list.append(angle) + + # # second traverse, calculate new keypoints + for idx, (k1_index, k2_index) in enumerate(limbSeq): + # update dst_keypoints + start_keypoint = rescale_keypoints[k1_index - 1] + new_length = new_length_list[idx] + angle = angle_list[idx] + + if ( + rescale_keypoints[k1_index - 1] is None + or rescale_keypoints[k2_index - 1] is None + or len(rescale_keypoints[k1_index - 1]) == 0 + or len(rescale_keypoints[k2_index - 1]) == 0 + ): + continue + + # calculate end_keypoint + delta_x = new_length * math.cos(math.radians(angle)) + delta_y = new_length * math.sin(math.radians(angle)) + + end_keypoint_x = start_keypoint[0] - delta_x + end_keypoint_y = start_keypoint[1] - delta_y + + # update keypoints + rescale_keypoints[k2_index - 1] = [end_keypoint_x, end_keypoint_y] + + return rescale_keypoints + + +def fix_lack_keypoints_use_sym(skeleton): + + keypoints = skeleton["keypoints_body"] + H, W = skeleton["height"], skeleton["width"] + + limb_points_list = [ + [3, 4, 5], + [6, 7, 8], + [12, 13, 14, 19], + [9, 10, 11, 20], + ] + + for limb_points in limb_points_list: + miss_flag = False + for point in limb_points: + if keypoints[point - 1] is None: + miss_flag = True + continue + if miss_flag: + skeleton["keypoints_body"][point - 1] = None + + repair_limb_seq_left = [ + [3, 4], + [4, 5], # left arm + [12, 13], + [13, 14], # left leg + [14, 19], # left foot + ] + + repair_limb_seq_right = [ + [6, 7], + [7, 8], # right arm + [9, 10], + [10, 11], # right leg + [11, 20], # right foot + ] + + repair_limb_seq = [repair_limb_seq_left, repair_limb_seq_right] + + for idx_part, part in enumerate(repair_limb_seq): + for idx, limb in enumerate(part): + + k1_index, k2_index = limb + keypoint1 = keypoints[k1_index - 1] + keypoint2 = keypoints[k2_index - 1] + + if keypoint1 != None and keypoint2 is None: + # reference to symmetric limb + sym_limb = repair_limb_seq[1 - idx_part][idx] + k1_index_sym, k2_index_sym = sym_limb + keypoint1_sym = keypoints[k1_index_sym - 1] + keypoint2_sym = keypoints[k2_index_sym - 1] + ref_length = 0 + + if keypoint1_sym != None and keypoint2_sym != None: + X = np.array([keypoint1_sym[0], keypoint2_sym[0]]) * float(W) + Y = np.array([keypoint1_sym[1], keypoint2_sym[1]]) * float(H) + ref_length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + else: + ref_length_left, ref_length_right = 0, 0 + if keypoints[1] != None and keypoints[8] != None: + X = np.array([keypoints[1][0], keypoints[8][0]]) * float(W) + Y = np.array([keypoints[1][1], keypoints[8][1]]) * float(H) + ref_length_left = ( + (X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2 + ) ** 0.5 + if idx <= 1: # arms + ref_length_left /= 2 + + if keypoints[1] != None and keypoints[11] != None: + X = np.array([keypoints[1][0], keypoints[11][0]]) * float(W) + Y = np.array([keypoints[1][1], keypoints[11][1]]) * float(H) + ref_length_right = ( + (X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2 + ) ** 0.5 + if idx <= 1: # arms + ref_length_right /= 2 + elif idx == 4: # foot + ref_length_right /= 5 + + ref_length = max(ref_length_left, ref_length_right) + + if ref_length != 0: + skeleton["keypoints_body"][k2_index - 1] = [0, 0] # init + skeleton["keypoints_body"][k2_index - 1][0] = skeleton[ + "keypoints_body" + ][k1_index - 1][0] + skeleton["keypoints_body"][k2_index - 1][1] = ( + skeleton["keypoints_body"][k1_index - 1][1] + ref_length / H + ) + return skeleton + + +def rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list): + + modify_bone_list = [[0, 1], [2, 4], [3, 5], [6, 9], [7, 10], [8, 11], [17, 18]] + + for modify_bone in modify_bone_list: + new_ratio = max(ratio_list[modify_bone[0]], ratio_list[modify_bone[1]]) + ratio_list[modify_bone[0]] = new_ratio + ratio_list[modify_bone[1]] = new_ratio + + if ratio_list[13] != None and ratio_list[15] != None: + ratio_eye_avg = (ratio_list[13] + ratio_list[15]) / 2 + ratio_list[13] = ratio_eye_avg + ratio_list[15] = ratio_eye_avg + + if ratio_list[14] != None and ratio_list[16] != None: + ratio_eye_avg = (ratio_list[14] + ratio_list[16]) / 2 + ratio_list[14] = ratio_eye_avg + ratio_list[16] = ratio_eye_avg + + return ratio_list, src_length_list, dst_length_list + + +def check_full_body(keypoints, threshold=0.4): + + body_flag = "half_body" + + # 1. If ankle points exist, confidence is greater than the threshold, and points do not exceed the frame, return full_body + if ( + keypoints[10] != None + and keypoints[13] != None + and keypoints[8] != None + and keypoints[11] != None + ): + if ( + (keypoints[10][1] <= 1 and keypoints[13][1] <= 1) + and (keypoints[10][2] >= threshold and keypoints[13][2] >= threshold) + and (keypoints[8][1] <= 1 and keypoints[11][1] <= 1) + and (keypoints[8][2] >= threshold and keypoints[11][2] >= threshold) + ): + body_flag = "full_body" + return body_flag + + # 2. If hip points exist, return three_quarter_body + if keypoints[8] != None and keypoints[11] != None: + if (keypoints[8][1] <= 1 and keypoints[11][1] <= 1) and ( + keypoints[8][2] >= threshold and keypoints[11][2] >= threshold + ): + body_flag = "three_quarter_body" + return body_flag + + return body_flag + + +def check_full_body_both(flag1, flag2): + body_flag_dict = {"full_body": 2, "three_quarter_body": 1, "half_body": 0} + + body_flag_dict_reverse = {2: "full_body", 1: "three_quarter_body", 0: "half_body"} + + flag1_num = body_flag_dict[flag1] + flag2_num = body_flag_dict[flag2] + flag_both_num = min(flag1_num, flag2_num) + return body_flag_dict_reverse[flag_both_num] + + +def write_to_poses( + data_to_json, + none_idx, + dst_shape, + bone_ratio_list, + delta_ground_x, + delta_ground_y, + rescaled_src_ground_x, + body_flag, + scale_min, +): + outputs = [] + length = len(data_to_json) + for id in tqdm(range(length)): + + src_height, src_width = data_to_json[id]["height"], data_to_json[id]["width"] + width, height = dst_shape + keypoints = data_to_json[id]["keypoints_body"] + for idx in range(len(keypoints)): + if idx in none_idx: + keypoints[idx] = None + new_keypoints = keypoints.copy() + + # get hand keypoints + keypoints_hand = { + "left": data_to_json[id]["keypoints_left_hand"], + "right": data_to_json[id]["keypoints_right_hand"], + } + # Normalize hand coordinates to 0-1 range + for hand_idx in range(len(data_to_json[id]["keypoints_left_hand"])): + data_to_json[id]["keypoints_left_hand"][hand_idx][0] = ( + data_to_json[id]["keypoints_left_hand"][hand_idx][0] / src_width + ) + data_to_json[id]["keypoints_left_hand"][hand_idx][1] = ( + data_to_json[id]["keypoints_left_hand"][hand_idx][1] / src_height + ) + + for hand_idx in range(len(data_to_json[id]["keypoints_right_hand"])): + data_to_json[id]["keypoints_right_hand"][hand_idx][0] = ( + data_to_json[id]["keypoints_right_hand"][hand_idx][0] / src_width + ) + data_to_json[id]["keypoints_right_hand"][hand_idx][1] = ( + data_to_json[id]["keypoints_right_hand"][hand_idx][1] / src_height + ) + + frame_info = get_scaled_pose( + (height, width), + (src_height, src_width), + new_keypoints, + keypoints_hand, + bone_ratio_list, + delta_ground_x, + delta_ground_y, + rescaled_src_ground_x, + body_flag, + id, + scale_min, + ) + outputs.append(frame_info) + + return outputs + + +def calculate_scale_ratio(skeleton, skeleton_edit, scale_ratio_flag): + if scale_ratio_flag: + + headw = max( + skeleton["keypoints_body"][0][0], + skeleton["keypoints_body"][14][0], + skeleton["keypoints_body"][15][0], + skeleton["keypoints_body"][16][0], + skeleton["keypoints_body"][17][0], + ) - min( + skeleton["keypoints_body"][0][0], + skeleton["keypoints_body"][14][0], + skeleton["keypoints_body"][15][0], + skeleton["keypoints_body"][16][0], + skeleton["keypoints_body"][17][0], + ) + headw_edit = max( + skeleton_edit["keypoints_body"][0][0], + skeleton_edit["keypoints_body"][14][0], + skeleton_edit["keypoints_body"][15][0], + skeleton_edit["keypoints_body"][16][0], + skeleton_edit["keypoints_body"][17][0], + ) - min( + skeleton_edit["keypoints_body"][0][0], + skeleton_edit["keypoints_body"][14][0], + skeleton_edit["keypoints_body"][15][0], + skeleton_edit["keypoints_body"][16][0], + skeleton_edit["keypoints_body"][17][0], + ) + headw_ratio = headw / headw_edit + + _, _, shoulder = get_length(skeleton, [6, 3]) + _, _, shoulder_edit = get_length(skeleton_edit, [6, 3]) + shoulder_ratio = shoulder / shoulder_edit + + return max(headw_ratio, shoulder_ratio) + + else: + return 1 + + +def retarget_pose( + src_skeleton, + dst_skeleton, + all_src_skeleton, + src_skeleton_edit, + dst_skeleton_edit, + threshold=0.4, +): + + if src_skeleton_edit is not None and dst_skeleton_edit is not None: + use_edit_for_base = True + else: + use_edit_for_base = False + + src_skeleton_ori = copy.deepcopy(src_skeleton) + + dst_skeleton_ori_h, dst_skeleton_ori_w = ( + dst_skeleton["height"], + dst_skeleton["width"], + ) + if ( + src_skeleton["keypoints_body"][0] != None + and src_skeleton["keypoints_body"][10] != None + and src_skeleton["keypoints_body"][13] != None + and dst_skeleton["keypoints_body"][0] != None + and dst_skeleton["keypoints_body"][10] != None + and dst_skeleton["keypoints_body"][13] != None + and src_skeleton["keypoints_body"][0][2] > 0.5 + and src_skeleton["keypoints_body"][10][2] > 0.5 + and src_skeleton["keypoints_body"][13][2] > 0.5 + and dst_skeleton["keypoints_body"][0][2] > 0.5 + and dst_skeleton["keypoints_body"][10][2] > 0.5 + and dst_skeleton["keypoints_body"][13][2] > 0.5 + ): + + src_height = src_skeleton["height"] * abs( + ( + src_skeleton["keypoints_body"][10][1] + + src_skeleton["keypoints_body"][13][1] + ) + / 2 + - src_skeleton["keypoints_body"][0][1] + ) + dst_height = dst_skeleton["height"] * abs( + ( + dst_skeleton["keypoints_body"][10][1] + + dst_skeleton["keypoints_body"][13][1] + ) + / 2 + - dst_skeleton["keypoints_body"][0][1] + ) + scale_min = 1.0 * src_height / dst_height + elif ( + src_skeleton["keypoints_body"][0] != None + and src_skeleton["keypoints_body"][8] != None + and src_skeleton["keypoints_body"][11] != None + and dst_skeleton["keypoints_body"][0] != None + and dst_skeleton["keypoints_body"][8] != None + and dst_skeleton["keypoints_body"][11] != None + and src_skeleton["keypoints_body"][0][2] > 0.5 + and src_skeleton["keypoints_body"][8][2] > 0.5 + and src_skeleton["keypoints_body"][11][2] > 0.5 + and dst_skeleton["keypoints_body"][0][2] > 0.5 + and dst_skeleton["keypoints_body"][8][2] > 0.5 + and dst_skeleton["keypoints_body"][11][2] > 0.5 + ): + + src_height = src_skeleton["height"] * abs( + ( + src_skeleton["keypoints_body"][8][1] + + src_skeleton["keypoints_body"][11][1] + ) + / 2 + - src_skeleton["keypoints_body"][0][1] + ) + dst_height = dst_skeleton["height"] * abs( + ( + dst_skeleton["keypoints_body"][8][1] + + dst_skeleton["keypoints_body"][11][1] + ) + / 2 + - dst_skeleton["keypoints_body"][0][1] + ) + scale_min = 1.0 * src_height / dst_height + else: + scale_min = np.sqrt(src_skeleton["height"] * src_skeleton["width"]) / np.sqrt( + dst_skeleton["height"] * dst_skeleton["width"] + ) + + if use_edit_for_base: + scale_ratio_flag = False + if ( + src_skeleton_edit["keypoints_body"][0] != None + and src_skeleton_edit["keypoints_body"][10] != None + and src_skeleton_edit["keypoints_body"][13] != None + and dst_skeleton_edit["keypoints_body"][0] != None + and dst_skeleton_edit["keypoints_body"][10] != None + and dst_skeleton_edit["keypoints_body"][13] != None + and src_skeleton_edit["keypoints_body"][0][2] > 0.5 + and src_skeleton_edit["keypoints_body"][10][2] > 0.5 + and src_skeleton_edit["keypoints_body"][13][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][0][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][10][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][13][2] > 0.5 + ): + + src_height_edit = src_skeleton_edit["height"] * abs( + ( + src_skeleton_edit["keypoints_body"][10][1] + + src_skeleton_edit["keypoints_body"][13][1] + ) + / 2 + - src_skeleton_edit["keypoints_body"][0][1] + ) + dst_height_edit = dst_skeleton_edit["height"] * abs( + ( + dst_skeleton_edit["keypoints_body"][10][1] + + dst_skeleton_edit["keypoints_body"][13][1] + ) + / 2 + - dst_skeleton_edit["keypoints_body"][0][1] + ) + scale_min_edit = 1.0 * src_height_edit / dst_height_edit + elif ( + src_skeleton_edit["keypoints_body"][0] != None + and src_skeleton_edit["keypoints_body"][8] != None + and src_skeleton_edit["keypoints_body"][11] != None + and dst_skeleton_edit["keypoints_body"][0] != None + and dst_skeleton_edit["keypoints_body"][8] != None + and dst_skeleton_edit["keypoints_body"][11] != None + and src_skeleton_edit["keypoints_body"][0][2] > 0.5 + and src_skeleton_edit["keypoints_body"][8][2] > 0.5 + and src_skeleton_edit["keypoints_body"][11][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][0][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][8][2] > 0.5 + and dst_skeleton_edit["keypoints_body"][11][2] > 0.5 + ): + + src_height_edit = src_skeleton_edit["height"] * abs( + ( + src_skeleton_edit["keypoints_body"][8][1] + + src_skeleton_edit["keypoints_body"][11][1] + ) + / 2 + - src_skeleton_edit["keypoints_body"][0][1] + ) + dst_height_edit = dst_skeleton_edit["height"] * abs( + ( + dst_skeleton_edit["keypoints_body"][8][1] + + dst_skeleton_edit["keypoints_body"][11][1] + ) + / 2 + - dst_skeleton_edit["keypoints_body"][0][1] + ) + scale_min_edit = 1.0 * src_height_edit / dst_height_edit + else: + scale_min_edit = np.sqrt( + src_skeleton_edit["height"] * src_skeleton_edit["width"] + ) / np.sqrt(dst_skeleton_edit["height"] * dst_skeleton_edit["width"]) + scale_ratio_flag = True + + # Flux may change the scale, compensate for it here + ratio_src = calculate_scale_ratio( + src_skeleton, src_skeleton_edit, scale_ratio_flag + ) + ratio_dst = calculate_scale_ratio( + dst_skeleton, dst_skeleton_edit, scale_ratio_flag + ) + + dst_skeleton_edit["height"] = int(dst_skeleton_edit["height"] * scale_min_edit) + dst_skeleton_edit["width"] = int(dst_skeleton_edit["width"] * scale_min_edit) + for idx in range(len(dst_skeleton_edit["keypoints_left_hand"])): + dst_skeleton_edit["keypoints_left_hand"][idx][0] *= scale_min_edit + dst_skeleton_edit["keypoints_left_hand"][idx][1] *= scale_min_edit + for idx in range(len(dst_skeleton_edit["keypoints_right_hand"])): + dst_skeleton_edit["keypoints_right_hand"][idx][0] *= scale_min_edit + dst_skeleton_edit["keypoints_right_hand"][idx][1] *= scale_min_edit + + dst_skeleton["height"] = int(dst_skeleton["height"] * scale_min) + dst_skeleton["width"] = int(dst_skeleton["width"] * scale_min) + for idx in range(len(dst_skeleton["keypoints_left_hand"])): + dst_skeleton["keypoints_left_hand"][idx][0] *= scale_min + dst_skeleton["keypoints_left_hand"][idx][1] *= scale_min + for idx in range(len(dst_skeleton["keypoints_right_hand"])): + dst_skeleton["keypoints_right_hand"][idx][0] *= scale_min + dst_skeleton["keypoints_right_hand"][idx][1] *= scale_min + + dst_body_flag = check_full_body(dst_skeleton["keypoints_body"], threshold) + src_body_flag = check_full_body(src_skeleton_ori["keypoints_body"], threshold) + body_flag = check_full_body_both(dst_body_flag, src_body_flag) + # print('body_flag: ', body_flag) + + if use_edit_for_base: + src_skeleton_edit = fix_lack_keypoints_use_sym(src_skeleton_edit) + dst_skeleton_edit = fix_lack_keypoints_use_sym(dst_skeleton_edit) + else: + src_skeleton = fix_lack_keypoints_use_sym(src_skeleton) + dst_skeleton = fix_lack_keypoints_use_sym(dst_skeleton) + + none_idx = [] + for idx in range(len(dst_skeleton["keypoints_body"])): + if ( + dst_skeleton["keypoints_body"][idx] == None + or src_skeleton["keypoints_body"][idx] == None + ): + src_skeleton["keypoints_body"][idx] = None + dst_skeleton["keypoints_body"][idx] = None + none_idx.append(idx) + + # get bone ratio list + ratio_list, src_length_list, dst_length_list = [], [], [] + for idx, limb in enumerate(limbSeq): + if use_edit_for_base: + src_X, src_Y, src_length = get_length(src_skeleton_edit, limb) + dst_X, dst_Y, dst_length = get_length(dst_skeleton_edit, limb) + + if src_X is None or src_Y is None or dst_X is None or dst_Y is None: + ratio = -1 + else: + ratio = 1.0 * dst_length * ratio_dst / src_length / ratio_src + + else: + src_X, src_Y, src_length = get_length(src_skeleton, limb) + dst_X, dst_Y, dst_length = get_length(dst_skeleton, limb) + + if src_X is None or src_Y is None or dst_X is None or dst_Y is None: + ratio = -1 + else: + ratio = 1.0 * dst_length / src_length + + ratio_list.append(ratio) + src_length_list.append(src_length) + dst_length_list.append(dst_length) + + for idx, ratio in enumerate(ratio_list): + if ratio == -1: + if ratio_list[0] != -1 and ratio_list[1] != -1: + ratio_list[idx] = (ratio_list[0] + ratio_list[1]) / 2 + + # Consider adding constraints when Flux fails to correct head pose, causing neck issues. + # if ratio_list[12] > (ratio_list[0]+ratio_list[1])/2*1.25: + # ratio_list[12] = (ratio_list[0]+ratio_list[1])/2*1.25 + + ratio_list, src_length_list, dst_length_list = rescale_shorten_skeleton( + ratio_list, src_length_list, dst_length_list + ) + + rescaled_src_skeleton_ori = rescale_skeleton( + src_skeleton_ori["height"], + src_skeleton_ori["width"], + src_skeleton_ori["keypoints_body"], + ratio_list, + ) + + # get global translation offset_x and offset_y + if body_flag == "full_body": + # print('use foot mark.') + dst_ground_y = ( + max( + dst_skeleton["keypoints_body"][10][1], + dst_skeleton["keypoints_body"][13][1], + ) + * dst_skeleton["height"] + ) + # The midpoint between toe and ankle + if ( + dst_skeleton["keypoints_body"][18] != None + and dst_skeleton["keypoints_body"][19] != None + ): + right_foot_mid = ( + dst_skeleton["keypoints_body"][10][1] + + dst_skeleton["keypoints_body"][19][1] + ) / 2 + left_foot_mid = ( + dst_skeleton["keypoints_body"][13][1] + + dst_skeleton["keypoints_body"][18][1] + ) / 2 + dst_ground_y = max(left_foot_mid, right_foot_mid) * dst_skeleton["height"] + + rescaled_src_ground_y = max( + rescaled_src_skeleton_ori[10][1], rescaled_src_skeleton_ori[13][1] + ) + delta_ground_y = rescaled_src_ground_y - dst_ground_y + + dst_ground_x = ( + ( + dst_skeleton["keypoints_body"][8][0] + + dst_skeleton["keypoints_body"][11][0] + ) + * dst_skeleton["width"] + / 2 + ) + rescaled_src_ground_x = ( + rescaled_src_skeleton_ori[8][0] + rescaled_src_skeleton_ori[11][0] + ) / 2 + delta_ground_x = rescaled_src_ground_x - dst_ground_x + delta_x, delta_y = delta_ground_x, delta_ground_y + + else: + # print('use neck mark.') + # use neck keypoint as mark + src_neck_y = rescaled_src_skeleton_ori[1][1] + dst_neck_y = dst_skeleton["keypoints_body"][1][1] + delta_neck_y = src_neck_y - dst_neck_y * dst_skeleton["height"] + + src_neck_x = rescaled_src_skeleton_ori[1][0] + dst_neck_x = dst_skeleton["keypoints_body"][1][0] + delta_neck_x = src_neck_x - dst_neck_x * dst_skeleton["width"] + delta_x, delta_y = delta_neck_x, delta_neck_y + rescaled_src_ground_x = src_neck_x + + dst_shape = (dst_skeleton_ori_w, dst_skeleton_ori_h) + output = write_to_poses( + all_src_skeleton, + none_idx, + dst_shape, + ratio_list, + delta_x, + delta_y, + rescaled_src_ground_x, + body_flag, + scale_min, + ) + return output + + +def get_retarget_pose( + tpl_pose_meta0, + refer_pose_meta, + tpl_pose_metas, + tql_edit_pose_meta0, + refer_edit_pose_meta, +): + + for key, value in tpl_pose_meta0.items(): + if type(value) is np.ndarray: + if key in ["keypoints_left_hand", "keypoints_right_hand"]: + value = value * np.array( + [[tpl_pose_meta0["width"], tpl_pose_meta0["height"], 1.0]] + ) + if not isinstance(value, list): + value = value.tolist() + tpl_pose_meta0[key] = value + + for key, value in refer_pose_meta.items(): + if type(value) is np.ndarray: + if key in ["keypoints_left_hand", "keypoints_right_hand"]: + value = value * np.array( + [[refer_pose_meta["width"], refer_pose_meta["height"], 1.0]] + ) + if not isinstance(value, list): + value = value.tolist() + refer_pose_meta[key] = value + + tpl_pose_metas_new = [] + for meta in tpl_pose_metas: + for key, value in meta.items(): + if type(value) is np.ndarray: + if key in ["keypoints_left_hand", "keypoints_right_hand"]: + value = value * np.array([[meta["width"], meta["height"], 1.0]]) + if not isinstance(value, list): + value = value.tolist() + meta[key] = value + tpl_pose_metas_new.append(meta) + + if tql_edit_pose_meta0 is not None: + for key, value in tql_edit_pose_meta0.items(): + if type(value) is np.ndarray: + if key in ["keypoints_left_hand", "keypoints_right_hand"]: + value = value * np.array( + [ + [ + tql_edit_pose_meta0["width"], + tql_edit_pose_meta0["height"], + 1.0, + ] + ] + ) + if not isinstance(value, list): + value = value.tolist() + tql_edit_pose_meta0[key] = value + + if refer_edit_pose_meta is not None: + for key, value in refer_edit_pose_meta.items(): + if type(value) is np.ndarray: + if key in ["keypoints_left_hand", "keypoints_right_hand"]: + value = value * np.array( + [ + [ + refer_edit_pose_meta["width"], + refer_edit_pose_meta["height"], + 1.0, + ] + ] + ) + if not isinstance(value, list): + value = value.tolist() + refer_edit_pose_meta[key] = value + + retarget_tpl_pose_metas = retarget_pose( + tpl_pose_meta0, + refer_pose_meta, + tpl_pose_metas_new, + tql_edit_pose_meta0, + refer_edit_pose_meta, + ) + + pose_metas = [] + for meta in retarget_tpl_pose_metas: + pose_meta = AAPoseMeta() + width, height = meta["width"], meta["height"] + pose_meta.width = width + pose_meta.height = height + pose_meta.kps_body = np.array(meta["keypoints_body"])[:, :2] * (width, height) + pose_meta.kps_body_p = np.array(meta["keypoints_body"])[:, 2] + + kps_lhand = [] + kps_lhand_p = [] + for each_kps_lhand in meta["keypoints_left_hand"]: + if each_kps_lhand is not None: + kps_lhand.append([each_kps_lhand.x, each_kps_lhand.y]) + kps_lhand_p.append(each_kps_lhand.score) + else: + kps_lhand.append([None, None]) + kps_lhand_p.append(0.0) + + pose_meta.kps_lhand = np.array(kps_lhand) + pose_meta.kps_lhand_p = np.array(kps_lhand_p) + + kps_rhand = [] + kps_rhand_p = [] + for each_kps_rhand in meta["keypoints_right_hand"]: + if each_kps_rhand is not None: + kps_rhand.append([each_kps_rhand.x, each_kps_rhand.y]) + kps_rhand_p.append(each_kps_rhand.score) + else: + kps_rhand.append([None, None]) + kps_rhand_p.append(0.0) + + pose_meta.kps_rhand = np.array(kps_rhand) + pose_meta.kps_rhand_p = np.array(kps_rhand_p) + + pose_metas.append(pose_meta) + + return pose_metas diff --git a/wan/modules/animate/preprocess/sam_utils.py b/wan/modules/animate/preprocess/sam_utils.py new file mode 100644 index 00000000..cbb1fcb2 --- /dev/null +++ b/wan/modules/animate/preprocess/sam_utils.py @@ -0,0 +1,157 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +# This file wraps and extends sam2.utils.misc for custom modifications. + +from sam2.utils import misc as sam2_misc +from sam2.utils.misc import * +from PIL import Image +import numpy as np +import torch +from tqdm import tqdm +import os + +import logging + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf + +from sam2.utils.misc import AsyncVideoFrameLoader, _load_img_as_tensor +from sam2.build_sam import _load_checkpoint + + +def _load_img_v2_as_tensor(img, image_size): + img_pil = Image.fromarray(img.astype(np.uint8)) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + frame_names=None, +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError("Only JPEG frames are supported at this moment") + if frame_names is None: + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, image_size, offload_video_to_cpu, img_mean, img_std + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.cuda() + img_mean = img_mean.cuda() + img_std = img_std.cuda() + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_v2( + frames, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + frame_names=None, +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + num_frames = len(frames) + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, frame in enumerate(tqdm(frames, desc="video frame")): + images[n], video_height, video_width = _load_img_v2_as_tensor(frame, image_size) + if not offload_video_to_cpu: + images = images.cuda() + img_mean = img_mean.cuda() + img_std = img_std.cuda() + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def build_sam2_video_predictor( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, +): + hydra_overrides = [ + "++model._target_=video_predictor.SAM2VideoPredictor", + ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + "++model.fill_hole_area=8", + ] + + hydra_overrides.extend(hydra_overrides_extra) + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model diff --git a/wan/modules/animate/preprocess/utils.py b/wan/modules/animate/preprocess/utils.py new file mode 100644 index 00000000..8b525977 --- /dev/null +++ b/wan/modules/animate/preprocess/utils.py @@ -0,0 +1,250 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import os +import cv2 +import math +import random +import numpy as np + + +def get_mask_boxes(mask): + """ + + Args: + mask: [h, w] + Returns: + + """ + y_coords, x_coords = np.nonzero(mask) + x_min = x_coords.min() + x_max = x_coords.max() + y_min = y_coords.min() + y_max = y_coords.max() + bbox = np.array([x_min, y_min, x_max, y_max]).astype(np.int32) + return bbox + + +def get_aug_mask(body_mask, w_len=10, h_len=20): + body_bbox = get_mask_boxes(body_mask) + + bbox_wh = body_bbox[2:4] - body_bbox[0:2] + w_slice = np.int32(bbox_wh[0] / w_len) + h_slice = np.int32(bbox_wh[1] / h_len) + + for each_w in range(body_bbox[0], body_bbox[2], w_slice): + w_start = min(each_w, body_bbox[2]) + w_end = min((each_w + w_slice), body_bbox[2]) + # print(w_start, w_end) + for each_h in range(body_bbox[1], body_bbox[3], h_slice): + h_start = min(each_h, body_bbox[3]) + h_end = min((each_h + h_slice), body_bbox[3]) + if body_mask[h_start:h_end, w_start:w_end].sum() > 0: + body_mask[h_start:h_end, w_start:w_end] = 1 + + return body_mask + + +def get_mask_body_img(img_copy, hand_mask, k=7, iterations=1): + kernel = np.ones((k, k), np.uint8) + dilation = cv2.dilate(hand_mask, kernel, iterations=iterations) + mask_hand_img = img_copy * (1 - dilation[:, :, None]) + + return mask_hand_img, dilation + + +def get_face_bboxes(kp2ds, scale, image_shape, ratio_aug): + h, w = image_shape + kp2ds_face = kp2ds.copy()[23:91, :2] + + min_x, min_y = np.min(kp2ds_face, axis=0) + max_x, max_y = np.max(kp2ds_face, axis=0) + + initial_width = max_x - min_x + initial_height = max_y - min_y + + initial_area = initial_width * initial_height + + expanded_area = initial_area * scale + + new_width = np.sqrt(expanded_area * (initial_width / initial_height)) + new_height = np.sqrt(expanded_area * (initial_height / initial_width)) + + delta_width = (new_width - initial_width) / 2 + delta_height = (new_height - initial_height) / 4 + + if ratio_aug: + if random.random() > 0.5: + delta_width += random.uniform(0, initial_width // 10) + else: + delta_height += random.uniform(0, initial_height // 10) + + expanded_min_x = max(min_x - delta_width, 0) + expanded_max_x = min(max_x + delta_width, w) + expanded_min_y = max(min_y - 3 * delta_height, 0) + expanded_max_y = min(max_y + delta_height, h) + + return [ + int(expanded_min_x), + int(expanded_max_x), + int(expanded_min_y), + int(expanded_max_y), + ] + + +def calculate_new_size(orig_w, orig_h, target_area, divisor=64): + + target_ratio = orig_w / orig_h + + def check_valid(w, h): + + if w <= 0 or h <= 0: + return False + return w * h <= target_area and w % divisor == 0 and h % divisor == 0 + + def get_ratio_diff(w, h): + + return abs(w / h - target_ratio) + + def round_to_64(value, round_up=False, divisor=64): + + if round_up: + return divisor * ((value + (divisor - 1)) // divisor) + return divisor * (value // divisor) + + possible_sizes = [] + + max_area_h = int(np.sqrt(target_area / target_ratio)) + max_area_w = int(max_area_h * target_ratio) + + max_h = round_to_64(max_area_h, round_up=True, divisor=divisor) + max_w = round_to_64(max_area_w, round_up=True, divisor=divisor) + + for h in range(divisor, max_h + divisor, divisor): + ideal_w = h * target_ratio + + w_down = round_to_64(ideal_w) + w_up = round_to_64(ideal_w, round_up=True) + + for w in [w_down, w_up]: + if check_valid(w, h, divisor): + possible_sizes.append((w, h, get_ratio_diff(w, h))) + + if not possible_sizes: + raise ValueError("Can not find suitable size") + + possible_sizes.sort(key=lambda x: (-x[0] * x[1], x[2])) + + best_w, best_h, _ = possible_sizes[0] + return int(best_w), int(best_h) + + +def resize_by_area( + image, target_area, keep_aspect_ratio=True, divisor=64, padding_color=(0, 0, 0) +): + h, w = image.shape[:2] + try: + new_w, new_h = calculate_new_size(w, h, target_area, divisor) + except: + aspect_ratio = w / h + + if keep_aspect_ratio: + new_h = math.sqrt(target_area / aspect_ratio) + new_w = target_area / new_h + else: + new_w = new_h = math.sqrt(target_area) + + new_w, new_h = int((new_w // divisor) * divisor), int( + (new_h // divisor) * divisor + ) + + interpolation = cv2.INTER_AREA if (new_w * new_h < w * h) else cv2.INTER_LINEAR + + resized_image = padding_resize( + image, + height=new_h, + width=new_w, + padding_color=padding_color, + interpolation=interpolation, + ) + return resized_image + + +def padding_resize( + img_ori, + height=512, + width=512, + padding_color=(0, 0, 0), + interpolation=cv2.INTER_LINEAR, +): + ori_height = img_ori.shape[0] + ori_width = img_ori.shape[1] + channel = img_ori.shape[2] + + img_pad = np.zeros((height, width, channel)) + if channel == 1: + img_pad[:, :, 0] = padding_color[0] + else: + img_pad[:, :, 0] = padding_color[0] + img_pad[:, :, 1] = padding_color[1] + img_pad[:, :, 2] = padding_color[2] + + if (ori_height / ori_width) > (height / width): + new_width = int(height / ori_height * ori_width) + img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation) + padding = int((width - new_width) / 2) + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + img_pad[:, padding : padding + new_width, :] = img + else: + new_height = int(width / ori_width * ori_height) + img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation) + padding = int((height - new_height) / 2) + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + img_pad[padding : padding + new_height, :, :] = img + + img_pad = np.uint8(img_pad) + + return img_pad + + +def get_frame_indices(frame_num, video_fps, clip_length, train_fps): + + start_frame = 0 + times = np.arange(0, clip_length) / train_fps + frame_indices = start_frame + np.round(times * video_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, frame_num - 1) + + return frame_indices.tolist() + + +def get_face_bboxes(kp2ds, scale, image_shape): + h, w = image_shape + kp2ds_face = kp2ds.copy()[1:] * (w, h) + + min_x, min_y = np.min(kp2ds_face, axis=0) + max_x, max_y = np.max(kp2ds_face, axis=0) + + initial_width = max_x - min_x + initial_height = max_y - min_y + + initial_area = initial_width * initial_height + + expanded_area = initial_area * scale + + new_width = np.sqrt(expanded_area * (initial_width / initial_height)) + new_height = np.sqrt(expanded_area * (initial_height / initial_width)) + + delta_width = (new_width - initial_width) / 2 + delta_height = (new_height - initial_height) / 4 + + expanded_min_x = max(min_x - delta_width, 0) + expanded_max_x = min(max_x + delta_width, w) + expanded_min_y = max(min_y - 3 * delta_height, 0) + expanded_max_y = min(max_y + delta_height, h) + + return [ + int(expanded_min_x), + int(expanded_max_x), + int(expanded_min_y), + int(expanded_max_y), + ] diff --git a/wan/modules/animate/preprocess/video_predictor.py b/wan/modules/animate/preprocess/video_predictor.py new file mode 100644 index 00000000..ce643b2f --- /dev/null +++ b/wan/modules/animate/preprocess/video_predictor.py @@ -0,0 +1,158 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +# A wrapper for sam2 functions +from collections import OrderedDict + +import torch +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.sam2_video_predictor import SAM2VideoPredictor as _SAM2VideoPredictor +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores + +from .sam_utils import load_video_frames, load_video_frames_v2 + + +class SAM2VideoPredictor(_SAM2VideoPredictor): + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + frame_names=None, + ): + """Initialize a inference state.""" + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + frame_names=frame_names, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = torch.device("cuda") + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = torch.device("cuda") + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + @torch.inference_mode() + def init_state_v2( + self, + frames, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + frame_names=None, + ): + """Initialize a inference state.""" + images, video_height, video_width = load_video_frames_v2( + frames=frames, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + frame_names=frame_names, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = torch.device("cuda") + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = torch.device("cuda") + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state diff --git a/wan/modules/animate/xlm_roberta.py b/wan/modules/animate/xlm_roberta.py new file mode 100644 index 00000000..47728fc7 --- /dev/null +++ b/wan/modules/animate/xlm_roberta.py @@ -0,0 +1,175 @@ +# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ["XLMRoberta", "xlm_roberta_large"] + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + + # compute attention + p = self.dropout.p if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, mask, p) + x = x.permute(0, 2, 1, 3).reshape(b, s, c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.post_norm = post_norm + self.eps = eps + + # layers + self.attn = SelfAttention(dim, num_heads, dropout, eps) + self.norm1 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4), + nn.GELU(), + nn.Linear(dim * 4, dim), + nn.Dropout(dropout), + ) + self.norm2 = nn.LayerNorm(dim, eps=eps) + + def forward(self, x, mask): + if self.post_norm: + x = self.norm1(x + self.attn(x, mask)) + x = self.norm2(x + self.ffn(x)) + else: + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XLMRoberta(nn.Module): + """ + XLMRobertaModel with no pooler and no LM head. + """ + + def __init__( + self, + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5, + ): + super().__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.type_size = type_size + self.pad_id = pad_id + self.dim = dim + self.num_heads = num_heads + self.num_layers = num_layers + self.post_norm = post_norm + self.eps = eps + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) + self.type_embedding = nn.Embedding(type_size, dim) + self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) + self.dropout = nn.Dropout(dropout) + + # blocks + self.blocks = nn.ModuleList( + [ + AttentionBlock(dim, num_heads, post_norm, dropout, eps) + for _ in range(num_layers) + ] + ) + + # norm layer + self.norm = nn.LayerNorm(dim, eps=eps) + + def forward(self, ids): + """ + ids: [B, L] of torch.LongTensor. + """ + b, s = ids.shape + mask = ids.ne(self.pad_id).long() + + # embeddings + x = ( + self.token_embedding(ids) + + self.type_embedding(torch.zeros_like(ids)) + + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) + ) + if self.post_norm: + x = self.norm(x) + x = self.dropout(x) + + # blocks + mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min) + for block in self.blocks: + x = block(x, mask) + + # output + if not self.post_norm: + x = self.norm(x) + return x + + +def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs): + """ + XLMRobertaLarge adapted from Huggingface. + """ + # params + cfg = dict( + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5, + ) + cfg.update(**kwargs) + + # init a model on device + with torch.device(device): + model = XLMRoberta(**cfg) + return model diff --git a/wan/modules/s2v/__init__.py b/wan/modules/s2v/__init__.py new file mode 100644 index 00000000..19c2970f --- /dev/null +++ b/wan/modules/s2v/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from .audio_encoder import AudioEncoder +from .model_s2v import WanModel_S2V + +__all__ = ["WanModel_S2V", "AudioEncoder"] diff --git a/wan/modules/s2v/audio_encoder.py b/wan/modules/s2v/audio_encoder.py new file mode 100644 index 00000000..937be15b --- /dev/null +++ b/wan/modules/s2v/audio_encoder.py @@ -0,0 +1,193 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math + +import librosa +import numpy as np +import torch +import torch.nn.functional as F +from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor + + +def get_sample_indices( + original_fps, total_frames, target_fps, num_sample, fixed_start=None +): + required_duration = num_sample / target_fps + required_origin_frames = int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if not fixed_start is None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + features: shape=[1, T, 512] + input_fps: fps for audio, f_a + output_fps: fps for video, f_m + output_len: video length + """ + features = features.transpose(1, 2) # [1, 512, T] + seq_len = features.shape[2] / float(input_fps) # T/f_a + if output_len is None: + output_len = int(seq_len * output_fps) # f_m*T/f_a + output_features = F.interpolate( + features, size=output_len, align_corners=True, mode="linear" + ) # [1, 512, output_len] + return output_features.transpose(1, 2) # [1, output_len, 512] + + +class AudioEncoder: + + def __init__(self, device="cpu", model_id="facebook/wav2vec2-base-960h"): + # load pretrained model + self.processor = Wav2Vec2Processor.from_pretrained(model_id) + self.model = Wav2Vec2ForCTC.from_pretrained(model_id) + + self.model = self.model.to(device) + + self.video_rate = 30 + + def extract_audio_feat( + self, audio_path, return_all_layers=False, dtype=torch.float32 + ): + audio_input, sample_rate = librosa.load(audio_path, sr=16000) + + input_values = self.processor( + audio_input, sampling_rate=sample_rate, return_tensors="pt" + ).input_values + + # INFERENCE + + # retrieve logits & take argmax + res = self.model(input_values.to(self.model.device), output_hidden_states=True) + if return_all_layers: + feat = torch.cat(res.hidden_states) + else: + feat = res.hidden_states[-1] + feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate) + + z = feat.to(dtype) # Encoding for the motion + return z + + def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1 + + bucket_num = min_batch_num * batch_frames + batch_idx = [stride * i for i in range(bucket_num)] + batch_audio_eb = [] + for bi in batch_idx: + if bi < audio_frame_num: + audio_sample_stride = 2 + chosen_idx = list( + range( + bi - m * audio_sample_stride, + bi + (m + 1) * audio_sample_stride, + audio_sample_stride, + ) + ) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [ + audio_frame_num - 1 if c >= audio_frame_num else c + for c in chosen_idx + ] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten( + start_dim=-2, end_dim=-1 + ) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = ( + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) + if not return_all_layers + else torch.zeros( + [num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device + ) + ) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = self.video_rate / fps + + min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1 + + bucket_num = min_batch_num * batch_frames + padd_audio_num = ( + math.ceil(min_batch_num * batch_frames / fps * self.video_rate) + - audio_frame_num + ) + batch_idx = get_sample_indices( + original_fps=self.video_rate, + total_frames=audio_frame_num + padd_audio_num, + target_fps=fps, + num_sample=bucket_num, + fixed_start=0, + ) + batch_audio_eb = [] + audio_sample_stride = int(self.video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + + chosen_idx = list( + range( + bi - m * audio_sample_stride, + bi + (m + 1) * audio_sample_stride, + audio_sample_stride, + ) + ) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [ + audio_frame_num - 1 if c >= audio_frame_num else c + for c in chosen_idx + ] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten( + start_dim=-2, end_dim=-1 + ) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = ( + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) + if not return_all_layers + else torch.zeros( + [num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device + ) + ) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num diff --git a/wan/modules/s2v/audio_utils.py b/wan/modules/s2v/audio_utils.py new file mode 100644 index 00000000..07428b7a --- /dev/null +++ b/wan/modules/s2v/audio_utils.py @@ -0,0 +1,119 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +from typing import Tuple, Union + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.models.attention import AdaLayerNorm + +from ..model import WanAttentionBlock, WanCrossAttention +from .auxi_blocks import MotionEncoder_tc + + +class CausalAudioEncoder(nn.Module): + + def __init__( + self, + dim=5120, + num_layers=25, + out_dim=2048, + video_rate=8, + num_token=4, + need_global=False, + ): + super().__init__() + self.encoder = MotionEncoder_tc( + in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global + ) + weight = torch.ones((1, num_layers, 1, 1)) * 0.01 + + self.weights = torch.nn.Parameter(weight) + self.act = torch.nn.SiLU() + + def forward(self, features): + with amp.autocast(dtype=torch.float32): + # features B * num_layers * dim * video_length + weights = self.act(self.weights) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim + + return res # b f n dim + + +class AudioCrossAttention(WanCrossAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class AudioInjector_WAN(nn.Module): + + def __init__( + self, + all_modules, + all_modules_names, + dim=2048, + num_heads=32, + inject_layer=[0, 27], + root_net=None, + enable_adain=False, + adain_dim=2048, + need_adain_ont=False, + ): + super().__init__() + num_injector_layers = len(inject_layer) + self.injected_block_id = {} + audio_injector_id = 0 + for mod_name, mod in zip(all_modules_names, all_modules): + if isinstance(mod, WanAttentionBlock): + for inject_id in inject_layer: + if f"transformer_blocks.{inject_id}" in mod_name: + self.injected_block_id[inject_id] = audio_injector_id + audio_injector_id += 1 + + self.injector = nn.ModuleList( + [ + AudioCrossAttention( + dim=dim, + num_heads=num_heads, + qk_norm=True, + ) + for _ in range(audio_injector_id) + ] + ) + self.injector_pre_norm_feat = nn.ModuleList( + [ + nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) + for _ in range(audio_injector_id) + ] + ) + self.injector_pre_norm_vec = nn.ModuleList( + [ + nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) + for _ in range(audio_injector_id) + ] + ) + if enable_adain: + self.injector_adain_layers = nn.ModuleList( + [ + AdaLayerNorm( + output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1 + ) + for _ in range(audio_injector_id) + ] + ) + if need_adain_ont: + self.injector_adain_output_layers = nn.ModuleList( + [nn.Linear(dim, dim) for _ in range(audio_injector_id)] + ) diff --git a/wan/modules/s2v/auxi_blocks.py b/wan/modules/s2v/auxi_blocks.py new file mode 100644 index 00000000..b275d109 --- /dev/null +++ b/wan/modules/s2v/auxi_blocks.py @@ -0,0 +1,239 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import importlib.metadata +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models import ModelMixin +from diffusers.utils import is_torch_version, logging +from einops import rearrange + +try: + from flash_attn import flash_attn_func, flash_attn_qkvpacked_func +except ImportError: + flash_attn_func = None + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="flash", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. + + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal + ) + elif mode == "flash": + x = flash_attn_func( + q, + k, + v, + ) + # x with shape [(bxs), a, d] + x = x.view( + batch_size, max_seqlen_q, x.shape[-2], x.shape[-1] + ) # reshape x to [b, s, a, d] + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert ( + attn_mask is None + ), "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril( + diagonal=0 + ) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + # TODO: Maybe force q and k to be float32 to avoid numerical overflow + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__( + self, + chan_in, + chan_out, + kernel_size=3, + stride=1, + dilation=1, + pad_mode="replicate", + **kwargs, + ): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d( + chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs + ) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class MotionEncoder_tc(nn.Module): + + def __init__( + self, + in_dim: int, + hidden_dim: int, + num_heads=int, + need_global=True, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.need_global = need_global + self.conv1_local = CausalConv1d( + in_dim, hidden_dim // 4 * num_heads, 3, stride=1 + ) + if need_global: + self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1) + self.norm1 = nn.LayerNorm( + hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) + self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) + + if need_global: + self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs) + + self.norm1 = nn.LayerNorm( + hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.norm2 = nn.LayerNorm( + hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.norm3 = nn.LayerNorm( + hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + x = rearrange(x, "b t c -> b c t") + x_ori = x.clone() + b, c, t = x.shape + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + if not self.need_global: + return x_local + + x = self.conv1_global(x_ori) + x = rearrange(x, "b c t -> b t c") + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.final_linear(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + + return x, x_local diff --git a/wan/modules/s2v/model_s2v.py b/wan/modules/s2v/model_s2v.py new file mode 100644 index 00000000..25749acb --- /dev/null +++ b/wan/modules/s2v/model_s2v.py @@ -0,0 +1,964 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +import types +from copy import deepcopy + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from einops import rearrange + +from ...distributed.sequence_parallel import ( + distributed_attention, + gather_forward, + get_rank, + get_world_size, +) +from ..model import ( + Head, + WanAttentionBlock, + WanLayerNorm, + WanModel, + WanSelfAttention, + flash_attention, + rope_params, + sinusoidal_embedding_1d, +) +from .audio_utils import AudioInjector_WAN, CausalAudioEncoder +from .motioner import FramePackMotioner, MotionerTransformers +from .s2v_utils import rope_precompute + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def torch_dfs(model: nn.Module, parent_name="root"): + module_names, modules = [], [] + current_name = parent_name if parent_name else "root" + module_names.append(current_name) + modules.append(model) + + for name, child in model.named_children(): + if parent_name: + child_name = f"{parent_name}.{name}" + else: + child_name = name + child_modules, child_names = torch_dfs(child, child_name) + module_names += child_names + modules += child_modules + return modules, module_names + + +@amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs, start=None): + n, c = x.size(2), x.size(3) // 2 + # loop over samples + output = [] + for i, _ in enumerate(x): + s = x.size(1) + x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2)) + freqs_i = freqs[i, :s] + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, s:]]) + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +@amp.autocast(enabled=False) +def rope_apply_usp(x, grid_sizes, freqs): + s, n, c = x.size(1), x.size(2), x.size(3) // 2 + # loop over samples + output = [] + for i, _ in enumerate(x): + s = x.size(1) + # precompute multipliers + x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2)) + freqs_i = freqs[i] + freqs_i_rank = freqs_i + x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) + x_i = torch.cat([x_i, x[i, s:]]) + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +def sp_attn_forward_s2v(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + half_dtypes = (torch.float16, torch.bfloat16) + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + q = rope_apply_usp(q, grid_sizes, freqs) + k = rope_apply_usp(k, grid_sizes, freqs) + + x = distributed_attention( + half(q), + half(k), + half(v), + seq_lens, + window_size=self.window_size, + ) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class Head_S2V(Head): + + def forward(self, x, e): + """ + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, L1, C] + """ + assert e.dtype == torch.float32 + with amp.autocast(dtype=torch.float32): + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) + return x + + +class WanS2VSelfAttention(WanSelfAttention): + + def forward(self, x, seq_lens, grid_sizes, freqs): + """ + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + x = flash_attention( + q=rope_apply(q, grid_sizes, freqs), + k=rope_apply(k, grid_sizes, freqs), + v=v, + k_lens=seq_lens, + window_size=self.window_size, + ) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanS2VAttentionBlock(WanAttentionBlock): + + def __init__( + self, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + ): + super().__init__( + dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps + ) + self.self_attn = WanS2VSelfAttention(dim, num_heads, window_size, qk_norm, eps) + + def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens): + assert e[0].dtype == torch.float32 + seg_idx = e[1].item() + seg_idx = min(max(0, seg_idx), x.size(1)) + seg_idx = [0, seg_idx, x.size(1)] + e = e[0] + modulation = self.modulation.unsqueeze(2) + with amp.autocast(dtype=torch.float32): + e = (modulation + e).chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + e = [element.squeeze(1) for element in e] + norm_x = self.norm1(x).float() + parts = [] + for i in range(2): + parts.append( + norm_x[:, seg_idx[i] : seg_idx[i + 1]] * (1 + e[1][:, i : i + 1]) + + e[0][:, i : i + 1] + ) + norm_x = torch.cat(parts, dim=1) + # self-attention + y = self.self_attn(norm_x, seq_lens, grid_sizes, freqs) + with amp.autocast(dtype=torch.float32): + z = [] + for i in range(2): + z.append(y[:, seg_idx[i] : seg_idx[i + 1]] * e[2][:, i : i + 1]) + y = torch.cat(z, dim=1) + x = x + y + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + x = x + self.cross_attn(self.norm3(x), context, context_lens) + norm2_x = self.norm2(x).float() + parts = [] + for i in range(2): + parts.append( + norm2_x[:, seg_idx[i] : seg_idx[i + 1]] * (1 + e[4][:, i : i + 1]) + + e[3][:, i : i + 1] + ) + norm2_x = torch.cat(parts, dim=1) + y = self.ffn(norm2_x) + with amp.autocast(dtype=torch.float32): + z = [] + for i in range(2): + z.append(y[:, seg_idx[i] : seg_idx[i + 1]] * e[5][:, i : i + 1]) + y = torch.cat(z, dim=1) + x = x + y + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class WanModel_S2V(ModelMixin, ConfigMixin): + ignore_for_config = [ + "args", + "kwargs", + "patch_size", + "cross_attn_norm", + "qk_norm", + "text_dim", + "window_size", + ] + _no_split_modules = ["WanS2VAttentionBlock"] + + @register_to_config + def __init__( + self, + cond_dim=0, + audio_dim=5120, + num_audio_token=4, + enable_adain=False, + adain_mode="attn_norm", + audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27], + zero_init=False, + zero_timestep=False, + enable_motioner=True, + add_last_motion=True, + enable_tsm=False, + trainable_token_pos_emb=False, + motion_token_num=1024, + enable_framepack=False, # Mutually exclusive with enable_motioner + framepack_drop_mode="drop", + model_type="s2v", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + *args, + **kwargs, + ): + super().__init__() + + assert model_type == "s2v" + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size + ) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim) + ) + + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + self.blocks = nn.ModuleList( + [ + WanS2VAttentionBlock( + dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps + ) + for _ in range(num_layers) + ] + ) + + # head + self.head = Head_S2V(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + dim=1, + ) + + # initialize weights + self.init_weights() + + self.use_context_parallel = False # will modify in _configure_model func + + if cond_dim > 0: + self.cond_encoder = nn.Conv3d( + cond_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size + ) + self.enbale_adain = enable_adain + self.casual_audio_encoder = CausalAudioEncoder( + dim=audio_dim, + out_dim=self.dim, + num_token=num_audio_token, + need_global=enable_adain, + ) + all_modules, all_modules_names = torch_dfs( + self.blocks, parent_name="root.transformer_blocks" + ) + self.audio_injector = AudioInjector_WAN( + all_modules, + all_modules_names, + dim=self.dim, + num_heads=self.num_heads, + inject_layer=audio_inject_layers, + root_net=self, + enable_adain=enable_adain, + adain_dim=self.dim, + need_adain_ont=adain_mode != "attn_norm", + ) + self.adain_mode = adain_mode + + self.trainable_cond_mask = nn.Embedding(3, self.dim) + + if zero_init: + self.zero_init_weights() + + self.zero_timestep = ( + zero_timestep # Whether to assign 0 value timestep to ref/motion + ) + + # init motioner + if enable_motioner and enable_framepack: + raise ValueError( + "enable_motioner and enable_framepack are mutually exclusive, please set one of them to False" + ) + self.enable_motioner = enable_motioner + self.add_last_motion = add_last_motion + if enable_motioner: + motioner_dim = 2048 + self.motioner = MotionerTransformers( + patch_size=(2, 4, 4), + dim=motioner_dim, + ffn_dim=motioner_dim, + freq_dim=256, + out_dim=16, + num_heads=16, + num_layers=13, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + motion_token_num=motion_token_num, + enable_tsm=enable_tsm, + motion_stride=4, + expand_ratio=2, + trainable_token_pos_emb=trainable_token_pos_emb, + ) + self.zip_motion_out = torch.nn.Sequential( + WanLayerNorm(motioner_dim), + zero_module(nn.Linear(motioner_dim, self.dim)), + ) + + self.trainable_token_pos_emb = trainable_token_pos_emb + if trainable_token_pos_emb: + d = self.dim // self.num_heads + x = torch.zeros([1, motion_token_num, self.num_heads, d]) + x[..., ::2] = 1 + + gride_sizes = [ + [ + torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor( + [ + 1, + self.motioner.motion_side_len, + self.motioner.motion_side_len, + ] + ) + .unsqueeze(0) + .repeat(1, 1), + torch.tensor( + [ + 1, + self.motioner.motion_side_len, + self.motioner.motion_side_len, + ] + ) + .unsqueeze(0) + .repeat(1, 1), + ] + ] + token_freqs = rope_apply(x, gride_sizes, self.freqs) + token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2) + token_freqs = token_freqs * 0.01 + self.token_freqs = torch.nn.Parameter(token_freqs) + + self.enable_framepack = enable_framepack + if enable_framepack: + self.frame_packer = FramePackMotioner( + inner_dim=self.dim, + num_heads=self.num_heads, + zip_frame_buckets=[1, 2, 16], + drop_mode=framepack_drop_mode, + ) + + def zero_init_weights(self): + with torch.no_grad(): + self.trainable_cond_mask = zero_module(self.trainable_cond_mask) + if hasattr(self, "cond_encoder"): + self.cond_encoder = zero_module(self.cond_encoder) + + for i in range(self.audio_injector.injector.__len__()): + self.audio_injector.injector[i].o = zero_module( + self.audio_injector.injector[i].o + ) + if self.enbale_adain: + self.audio_injector.injector_adain_layers[i].linear = zero_module( + self.audio_injector.injector_adain_layers[i].linear + ) + + def process_motion(self, motion_latents, drop_motion_frames=False): + if drop_motion_frames or motion_latents[0].shape[1] == 0: + return [], [] + self.lat_motion_frames = motion_latents[0].shape[1] + mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents] + batch_size = len(mot) + + mot_remb = [] + flattern_mot = [] + for bs in range(batch_size): + height, width = mot[bs].shape[3], mot[bs].shape[4] + flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous() + motion_grid_sizes = [ + [ + torch.tensor([-self.lat_motion_frames, 0, 0]) + .unsqueeze(0) + .repeat(1, 1), + torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.lat_motion_frames, height, width]) + .unsqueeze(0) + .repeat(1, 1), + ] + ] + motion_rope_emb = rope_precompute( + flat_mot.detach().view( + 1, flat_mot.shape[1], self.num_heads, self.dim // self.num_heads + ), + motion_grid_sizes, + self.freqs, + start=None, + ) + mot_remb.append(motion_rope_emb) + flattern_mot.append(flat_mot) + return flattern_mot, mot_remb + + def process_motion_frame_pack( + self, motion_latents, drop_motion_frames=False, add_last_motion=2 + ): + flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) + if drop_motion_frames: + return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb] + else: + return flattern_mot, mot_remb + + def process_motion_transformer_motioner( + self, motion_latents, drop_motion_frames=False, add_last_motion=True + ): + batch_size, height, width = ( + len(motion_latents), + motion_latents[0].shape[2] // self.patch_size[1], + motion_latents[0].shape[3] // self.patch_size[2], + ) + + freqs = self.freqs + device = self.patch_embedding.weight.device + if freqs.device != device: + freqs = freqs.to(device) + if self.trainable_token_pos_emb: + with amp.autocast(dtype=torch.float64): + token_freqs = self.token_freqs.to(torch.float64) + token_freqs = token_freqs / token_freqs.norm(dim=-1, keepdim=True) + freqs = [freqs, torch.view_as_complex(token_freqs)] + + if not drop_motion_frames and add_last_motion: + last_motion_latent = [u[:, -1:] for u in motion_latents] + last_mot = [ + self.patch_embedding(m.unsqueeze(0)) for m in last_motion_latent + ] + last_mot = [m.flatten(2).transpose(1, 2) for m in last_mot] + last_mot = torch.cat(last_mot) + gride_sizes = [ + [ + torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([0, height, width]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1), + ] + ] + else: + last_mot = torch.zeros( + [batch_size, 0, self.dim], + device=motion_latents[0].device, + dtype=motion_latents[0].dtype, + ) + gride_sizes = [] + + zip_motion = self.motioner(motion_latents) + zip_motion = self.zip_motion_out(zip_motion) + if drop_motion_frames: + zip_motion = zip_motion * 0.0 + zip_motion_grid_sizes = [ + [ + torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor( + [0, self.motioner.motion_side_len, self.motioner.motion_side_len] + ) + .unsqueeze(0) + .repeat(batch_size, 1), + torch.tensor( + [1 if not self.trainable_token_pos_emb else -1, height, width] + ) + .unsqueeze(0) + .repeat(batch_size, 1), + ] + ] + + mot = torch.cat([last_mot, zip_motion], dim=1) + gride_sizes = gride_sizes + zip_motion_grid_sizes + + motion_rope_emb = rope_precompute( + mot.detach().view( + batch_size, mot.shape[1], self.num_heads, self.dim // self.num_heads + ), + gride_sizes, + freqs, + start=None, + ) + return [m.unsqueeze(0) for m in mot], [r.unsqueeze(0) for r in motion_rope_emb] + + def inject_motion( + self, + x, + seq_lens, + rope_embs, + mask_input, + motion_latents, + drop_motion_frames=False, + add_last_motion=True, + ): + # inject the motion frames token to the hidden states + if self.enable_motioner: + mot, mot_remb = self.process_motion_transformer_motioner( + motion_latents, + drop_motion_frames=drop_motion_frames, + add_last_motion=add_last_motion, + ) + elif self.enable_framepack: + mot, mot_remb = self.process_motion_frame_pack( + motion_latents, + drop_motion_frames=drop_motion_frames, + add_last_motion=add_last_motion, + ) + else: + mot, mot_remb = self.process_motion( + motion_latents, drop_motion_frames=drop_motion_frames + ) + + if len(mot) > 0: + x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)] + seq_lens = seq_lens + torch.tensor( + [r.size(1) for r in mot], dtype=torch.long + ) + rope_embs = [torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)] + mask_input = [ + torch.cat( + [ + m, + 2 + * torch.ones( + [1, u.shape[1] - m.shape[1]], device=m.device, dtype=m.dtype + ), + ], + dim=1, + ) + for m, u in zip(mask_input, x) + ] + return x, seq_lens, rope_embs, mask_input + + def after_transformer_block(self, block_idx, hidden_states): + if block_idx in self.audio_injector.injected_block_id.keys(): + audio_attn_id = self.audio_injector.injected_block_id[block_idx] + audio_emb = self.merged_audio_emb # b f n c + num_frames = audio_emb.shape[1] + + if self.use_context_parallel: + hidden_states = gather_forward(hidden_states, dim=1) + + input_hidden_states = hidden_states[ + :, : self.original_seq_len + ].clone() # b (f h w) c + input_hidden_states = rearrange( + input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames + ) + + if self.enbale_adain and self.adain_mode == "attn_norm": + audio_emb_global = self.audio_emb_global + audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c") + adain_hidden_states = self.audio_injector.injector_adain_layers[ + audio_attn_id + ](input_hidden_states, temb=audio_emb_global[:, 0]) + attn_hidden_states = adain_hidden_states + else: + attn_hidden_states = self.audio_injector.injector_pre_norm_feat[ + audio_attn_id + ](input_hidden_states) + audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames) + attn_audio_emb = audio_emb + residual_out = self.audio_injector.injector[audio_attn_id]( + x=attn_hidden_states, + context=attn_audio_emb, + context_lens=torch.ones( + attn_hidden_states.shape[0], + dtype=torch.long, + device=attn_hidden_states.device, + ) + * attn_audio_emb.shape[1], + ) + residual_out = rearrange( + residual_out, "(b t) n c -> b (t n) c", t=num_frames + ) + hidden_states[:, : self.original_seq_len] = ( + hidden_states[:, : self.original_seq_len] + residual_out + ) + + if self.use_context_parallel: + hidden_states = torch.chunk(hidden_states, get_world_size(), dim=1)[ + get_rank() + ] + + return hidden_states + + def forward( + self, + x, + t, + context, + seq_len, + ref_latents, + motion_latents, + cond_states, + audio_input=None, + motion_frames=[17, 5], + add_last_motion=2, + drop_motion_frames=False, + *extra_args, + **extra_kwargs, + ): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. + seq_len: A list of video token lens, no need for this model. + ref_latents A list of reference image for each video with shape [C, 1, H, W]. + motion_latents A list of motion frames for each video with shape [C, T_m, H, W]. + cond_states A list of condition frames (i.e. pose) each with shape [C, T, H, W]. + audio_input The input audio embedding [B, num_wav2vec_layer, C_a, T_a]. + motion_frames The number of motion frames and motion latents frames encoded by vae, i.e. [17, 5] + add_last_motion For the motioner, if add_last_motion > 0, it means that the most recent frame (i.e., the last frame) will be added. + For frame packing, the behavior depends on the value of add_last_motion: + add_last_motion = 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included. + add_last_motion = 1: Both clean_latents_2x and clean_latents_4x are included. + add_last_motion = 2: All motion-related latents are used. + drop_motion_frames Bool, whether drop the motion frames info + """ + add_last_motion = self.add_last_motion * add_last_motion + audio_input = torch.cat( + [audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], + dim=-1, + ) + audio_emb_res = self.casual_audio_encoder(audio_input) + if self.enbale_adain: + audio_emb_global, audio_emb = audio_emb_res + self.audio_emb_global = audio_emb_global[:, motion_frames[1] :].clone() + else: + audio_emb = audio_emb_res + self.merged_audio_emb = audio_emb[:, motion_frames[1] :, :] + + device = self.patch_embedding.weight.device + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + # cond states + cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states] + x = [x_ + pose for x_, pose in zip(x, cond)] + + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x] + ) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + + original_grid_sizes = deepcopy(grid_sizes) + grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]] + + # ref and motion + self.lat_motion_frames = motion_latents[0].shape[1] + + ref = [self.patch_embedding(r.unsqueeze(0)) for r in ref_latents] + batch_size = len(ref) + height, width = ref[0].shape[3], ref[0].shape[4] + ref_grid_sizes = [ + [ + torch.tensor([30, 0, 0]) + .unsqueeze(0) + .repeat(batch_size, 1), # the start index + torch.tensor([31, height, width]) + .unsqueeze(0) + .repeat(batch_size, 1), # the end index + torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1), + ] # the range + ] + + ref = [r.flatten(2).transpose(1, 2) for r in ref] # r: 1 c f h w + self.original_seq_len = seq_lens[0] + + seq_lens = seq_lens + torch.tensor([r.size(1) for r in ref], dtype=torch.long) + + grid_sizes = grid_sizes + ref_grid_sizes + + x = [torch.cat([u, r], dim=1) for u, r in zip(x, ref)] + + # Initialize masks to indicate noisy latent, ref latent, and motion latent. + # However, at this point, only the first two (noisy and ref latents) are marked; + # the marking of motion latent will be implemented inside `inject_motion`. + mask_input = [ + torch.zeros([1, u.shape[1]], dtype=torch.long, device=x[0].device) + for u in x + ] + for i in range(len(mask_input)): + mask_input[i][:, self.original_seq_len :] = 1 + + # compute the rope embeddings for the input + x = torch.cat(x) + b, s, n, d = x.size(0), x.size(1), self.num_heads, self.dim // self.num_heads + self.pre_compute_freqs = rope_precompute( + x.detach().view(b, s, n, d), grid_sizes, self.freqs, start=None + ) + + x = [u.unsqueeze(0) for u in x] + self.pre_compute_freqs = [u.unsqueeze(0) for u in self.pre_compute_freqs] + + x, seq_lens, self.pre_compute_freqs, mask_input = self.inject_motion( + x, + seq_lens, + self.pre_compute_freqs, + mask_input, + motion_latents, + drop_motion_frames=drop_motion_frames, + add_last_motion=add_last_motion, + ) + + x = torch.cat(x, dim=0) + self.pre_compute_freqs = torch.cat(self.pre_compute_freqs, dim=0) + mask_input = torch.cat(mask_input, dim=0) + + x = x + self.trainable_cond_mask(mask_input).to(x.dtype) + + # time embeddings + if self.zero_timestep: + t = torch.cat([t, torch.zeros([1], dtype=t.dtype, device=t.device)]) + with amp.autocast(dtype=torch.float32): + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + if self.zero_timestep: + e = e[:-1] + zero_e0 = e0[-1:] + e0 = e0[:-1] + token_len = x.shape[1] + e0 = torch.cat( + [e0.unsqueeze(2), zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)], + dim=2, + ) + e0 = [e0, self.original_seq_len] + else: + e0 = e0.unsqueeze(2).repeat(1, 1, 2, 1) + e0 = [e0, 0] + + # context + context_lens = None + context = self.text_embedding( + torch.stack( + [ + torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ] + ) + ) + + # grad ckpt args + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs, **kwargs): + if return_dict is not None: + return module(*inputs, **kwargs, return_dict=return_dict) + else: + return module(*inputs, **kwargs) + + return custom_forward + + if self.use_context_parallel: + # sharded tensors for long context attn + sp_rank = get_rank() + x = torch.chunk(x, get_world_size(), dim=1) + sq_size = [u.shape[1] for u in x] + sq_start_size = sum(sq_size[:sp_rank]) + x = x[sp_rank] + # Confirm the application range of the time embedding in e0[0] for each sequence: + # - For tokens before seg_id: apply e0[0][:, :, 0] + # - For tokens after seg_id: apply e0[0][:, :, 1] + sp_size = x.shape[1] + seg_idx = e0[1] - sq_start_size + e0[1] = seg_idx + + self.pre_compute_freqs = torch.chunk( + self.pre_compute_freqs, get_world_size(), dim=1 + ) + self.pre_compute_freqs = self.pre_compute_freqs[sp_rank] + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.pre_compute_freqs, + context=context, + context_lens=context_lens, + ) + for idx, block in enumerate(self.blocks): + x = block(x, **kwargs) + x = self.after_transformer_block(idx, x) + + # Context Parallel + if self.use_context_parallel: + x = gather_forward(x.contiguous(), dim=1) + # unpatchify + x = x[:, : self.original_seq_len] + # head + x = self.head(x, e) + x = self.unpatchify(x, original_grid_sizes) + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + """ + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[: math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum("fhwpqrc->cfphqwr", u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) diff --git a/wan/modules/s2v/motioner.py b/wan/modules/s2v/motioner.py new file mode 100644 index 00000000..ac563a79 --- /dev/null +++ b/wan/modules/s2v/motioner.py @@ -0,0 +1,865 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +from typing import Any, Dict, List, Literal, Optional, Union + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.utils import BaseOutput, is_torch_version +from einops import rearrange, repeat + +from ..model import flash_attention +from .s2v_utils import rope_precompute + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half)) + ) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@amp.autocast(enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)), + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +@amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs, start=None): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + if type(freqs) is list: + trainable_freqs = freqs[1] + freqs = freqs[0] + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + output = x.clone() + seq_bucket = [0] + if not type(grid_sizes) is list: + grid_sizes = [grid_sizes] + for g in grid_sizes: + if not type(g) is list: + g = [torch.zeros_like(g), g] + batch_size = g[0].shape[0] + for i in range(batch_size): + if start is None: + f_o, h_o, w_o = g[0][i] + else: + f_o, h_o, w_o = start[i] + + f, h, w = g[1][i] + t_f, t_h, t_w = g[2][i] + seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o + seq_len = int(seq_f * seq_h * seq_w) + if seq_len > 0: + if t_f > 0: + factor_f, factor_h, factor_w = ( + (t_f / seq_f).item(), + (t_h / seq_h).item(), + (t_w / seq_w).item(), + ) + + if f_o >= 0: + f_sam = ( + np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f) + .astype(int) + .tolist() + ) + else: + f_sam = ( + np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f) + .astype(int) + .tolist() + ) + h_sam = ( + np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h) + .astype(int) + .tolist() + ) + w_sam = ( + np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w) + .astype(int) + .tolist() + ) + + assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 + freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj() + freqs_0 = freqs_0.view(seq_f, 1, 1, -1) + + freqs_i = torch.cat( + [ + freqs_0.expand(seq_f, seq_h, seq_w, -1), + freqs[1][h_sam] + .view(1, seq_h, 1, -1) + .expand(seq_f, seq_h, seq_w, -1), + freqs[2][w_sam] + .view(1, 1, seq_w, -1) + .expand(seq_f, seq_h, seq_w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + elif t_f < 0: + freqs_i = trainable_freqs.unsqueeze(1) + # apply rotary embedding + # precompute multipliers + x_i = torch.view_as_complex( + x[i, seq_bucket[-1] : seq_bucket[-1] + seq_len] + .to(torch.float64) + .reshape(seq_len, n, -1, 2) + ) + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + output[i, seq_bucket[-1] : seq_bucket[-1] + seq_len] = x_i + seq_bucket.append(seq_bucket[-1] + seq_len) + return output.float() + + +class RMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class LayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + x = flash_attention( + q=rope_apply(q, grid_sizes, freqs), + k=rope_apply(k, grid_sizes, freqs), + v=v, + k_lens=seq_lens, + window_size=self.window_size, + ) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class SwinSelfAttention(SelfAttention): + + def forward(self, x, seq_lens, grid_sizes, freqs): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + assert b == 1, "Only support batch_size 1" + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + T, H, W = grid_sizes[0].tolist() + + q = rearrange(q, "b (t h w) n d -> (b t) (h w) n d", t=T, h=H, w=W) + k = rearrange(k, "b (t h w) n d -> (b t) (h w) n d", t=T, h=H, w=W) + v = rearrange(v, "b (t h w) n d -> (b t) (h w) n d", t=T, h=H, w=W) + + ref_q = q[-1:] + q = q[:-1] + + ref_k = repeat(k[-1:], "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d + k = k[:-1] + k = torch.cat([k[:1], k, k[-1:]]) + k = torch.cat([k[1:-1], k[2:], k[:-2], ref_k], dim=1) # (bt) (3hw) n d + + ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=v.shape[0] - 1) + v = v[:-1] + v = torch.cat([v[:1], v, v[-1:]]) + v = torch.cat([v[1:-1], v[2:], v[:-2], ref_v], dim=1) + + # q: b (t h w) n d + # k: b (t h w) n d + out = flash_attention( + q=q, + k=k, + v=v, + # k_lens=torch.tensor([k.shape[1]] * k.shape[0], device=x.device, dtype=torch.long), + window_size=self.window_size, + ) + out = torch.cat([out, ref_v[:1]], axis=0) + out = rearrange(out, "(b t) (h w) n d -> b (t h w) n d", t=T, h=H, w=W) + x = out + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +# Fix the reference frame RoPE to 1,H,W. +# Set the current frame RoPE to 1. +# Set the previous frame RoPE to 0. +class CasualSelfAttention(SelfAttention): + + def forward(self, x, seq_lens, grid_sizes, freqs): + shifting = 3 + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + assert b == 1, "Only support batch_size 1" + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + T, H, W = grid_sizes[0].tolist() + + q = rearrange(q, "b (t h w) n d -> (b t) (h w) n d", t=T, h=H, w=W) + k = rearrange(k, "b (t h w) n d -> (b t) (h w) n d", t=T, h=H, w=W) + v = rearrange(v, "b (t h w) n d -> (b t) (h w) n d", t=T, h=H, w=W) + + ref_q = q[-1:] + q = q[:-1] + + grid_sizes = torch.tensor([[1, H, W]] * q.shape[0], dtype=torch.long) + start = [[shifting, 0, 0]] * q.shape[0] + q = rope_apply(q, grid_sizes, freqs, start=start) + + ref_k = k[-1:] + grid_sizes = torch.tensor([[1, H, W]], dtype=torch.long) + # start = [[shifting, H, W]] + + start = [[shifting + 10, 0, 0]] + ref_k = rope_apply(ref_k, grid_sizes, freqs, start) + ref_k = repeat(ref_k, "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d + + k = k[:-1] + k = torch.cat([*([k[:1]] * shifting), k]) + cat_k = [] + for i in range(shifting): + cat_k.append(k[i : i - shifting]) + cat_k.append(k[shifting:]) + k = torch.cat(cat_k, dim=1) # (bt) (3hw) n d + + grid_sizes = torch.tensor([[shifting + 1, H, W]] * q.shape[0], dtype=torch.long) + k = rope_apply(k, grid_sizes, freqs) + k = torch.cat([k, ref_k], dim=1) + + ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=q.shape[0]) # t hw n d + v = v[:-1] + v = torch.cat([*([v[:1]] * shifting), v]) + cat_v = [] + for i in range(shifting): + cat_v.append(v[i : i - shifting]) + cat_v.append(v[shifting:]) + v = torch.cat(cat_v, dim=1) # (bt) (3hw) n d + v = torch.cat([v, ref_v], dim=1) + + # q: b (t h w) n d + # k: b (t h w) n d + outs = [] + for i in range(q.shape[0]): + out = flash_attention( + q=q[i : i + 1], + k=k[i : i + 1], + v=v[i : i + 1], + window_size=self.window_size, + ) + outs.append(out) + out = torch.cat(outs, dim=0) + out = torch.cat([out, ref_v[:1]], axis=0) + out = rearrange(out, "(b t) (h w) n d -> b (t h w) n d", t=T, h=H, w=W) + x = out + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class MotionerAttentionBlock(nn.Module): + + def __init__( + self, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + self_attn_block="SelfAttention", + ): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = LayerNorm(dim, eps) + if self_attn_block == "SelfAttention": + self.self_attn = SelfAttention(dim, num_heads, window_size, qk_norm, eps) + elif self_attn_block == "SwinSelfAttention": + self.self_attn = SwinSelfAttention( + dim, num_heads, window_size, qk_norm, eps + ) + elif self_attn_block == "CasualSelfAttention": + self.self_attn = CasualSelfAttention( + dim, num_heads, window_size, qk_norm, eps + ) + + self.norm2 = LayerNorm(dim, eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), + nn.GELU(approximate="tanh"), + nn.Linear(ffn_dim, dim), + ) + + def forward( + self, + x, + seq_lens, + grid_sizes, + freqs, + ): + # self-attention + y = self.self_attn(self.norm1(x).float(), seq_lens, grid_sizes, freqs) + x = x + y + y = self.ffn(self.norm2(x).float()) + x = x + y + return x + + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = LayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + def forward(self, x): + x = self.head(self.norm(x)) + return x + + +class MotionerTransformers(nn.Module, PeftAdapterMixin): + + def __init__( + self, + patch_size=(1, 2, 2), + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + self_attn_block="SelfAttention", + motion_token_num=1024, + enable_tsm=False, + motion_stride=4, + expand_ratio=2, + trainable_token_pos_emb=False, + ): + super().__init__() + self.patch_size = patch_size + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + self.enable_tsm = enable_tsm + self.motion_stride = motion_stride + self.expand_ratio = expand_ratio + self.sample_c = self.patch_size[0] + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size + ) + + # blocks + self.blocks = nn.ModuleList( + [ + MotionerAttentionBlock( + dim, + ffn_dim, + num_heads, + window_size, + qk_norm, + cross_attn_norm, + eps, + self_attn_block=self_attn_block, + ) + for _ in range(num_layers) + ] + ) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + dim=1, + ) + + self.gradient_checkpointing = False + + self.motion_side_len = int(math.sqrt(motion_token_num)) + assert self.motion_side_len**2 == motion_token_num + self.token = nn.Parameter(torch.zeros(1, motion_token_num, dim).contiguous()) + + self.trainable_token_pos_emb = trainable_token_pos_emb + if trainable_token_pos_emb: + x = torch.zeros([1, motion_token_num, num_heads, d]) + x[..., ::2] = 1 + + gride_sizes = [ + [ + torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([1, self.motion_side_len, self.motion_side_len]) + .unsqueeze(0) + .repeat(1, 1), + torch.tensor([1, self.motion_side_len, self.motion_side_len]) + .unsqueeze(0) + .repeat(1, 1), + ] + ] + token_freqs = rope_apply(x, gride_sizes, self.freqs) + token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2) + token_freqs = token_freqs * 0.01 + self.token_freqs = torch.nn.Parameter(token_freqs) + + def after_patch_embedding(self, x): + return x + + def forward( + self, + x, + ): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. + """ + # params + motion_frames = x[0].shape[1] + device = self.patch_embedding.weight.device + freqs = self.freqs + if freqs.device != device: + freqs = freqs.to(device) + + if self.trainable_token_pos_emb: + with amp.autocast(dtype=torch.float64): + token_freqs = self.token_freqs.to(torch.float64) + token_freqs = token_freqs / token_freqs.norm(dim=-1, keepdim=True) + freqs = [freqs, torch.view_as_complex(token_freqs)] + + if self.enable_tsm: + sample_idx = [ + sample_indices( + u.shape[1], + stride=self.motion_stride, + expand_ratio=self.expand_ratio, + c=self.sample_c, + ) + for u in x + ] + x = [ + torch.flip(torch.flip(u, [1])[:, idx], [1]) + for idx, u in zip(sample_idx, x) + ] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + x = self.after_patch_embedding(x) + + seq_f, seq_h, seq_w = x[0].shape[-3:] + batch_size = len(x) + if not self.enable_tsm: + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x] + ) + grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]] + seq_f = 0 + else: + grid_sizes = [] + for idx in sample_idx[0][::-1][:: self.sample_c]: + tsm_frame_grid_sizes = [ + [ + torch.tensor([idx, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([idx + 1, seq_h, seq_w]) + .unsqueeze(0) + .repeat(batch_size, 1), + torch.tensor([1, seq_h, seq_w]) + .unsqueeze(0) + .repeat(batch_size, 1), + ] + ] + grid_sizes += tsm_frame_grid_sizes + seq_f = sample_idx[0][-1] + 1 + + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + x = torch.cat([u for u in x]) + + batch_size = len(x) + + token_grid_sizes = [ + [ + torch.tensor([seq_f, 0, 0]).unsqueeze(0).repeat(batch_size, 1), + torch.tensor([seq_f + 1, self.motion_side_len, self.motion_side_len]) + .unsqueeze(0) + .repeat(batch_size, 1), + torch.tensor( + [1 if not self.trainable_token_pos_emb else -1, seq_h, seq_w] + ) + .unsqueeze(0) + .repeat(batch_size, 1), + ] # 第三行代表rope emb的想要覆盖到的范围 + ] + + grid_sizes = grid_sizes + token_grid_sizes + token_unpatch_grid_sizes = torch.stack( + [torch.tensor([1, 32, 32], dtype=torch.long) for b in range(batch_size)] + ) + token_len = self.token.shape[1] + token = self.token.clone().repeat(x.shape[0], 1, 1).contiguous() + seq_lens = seq_lens + torch.tensor([t.size(0) for t in token], dtype=torch.long) + x = torch.cat([x, token], dim=1) + # arguments + kwargs = dict( + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=freqs, + ) + + # grad ckpt args + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs, **kwargs): + if return_dict is not None: + return module(*inputs, **kwargs, return_dict=return_dict) + else: + return module(*inputs, **kwargs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + + for idx, block in enumerate(self.blocks): + if self.training and self.gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + **kwargs, + **ckpt_kwargs, + ) + else: + x = block(x, **kwargs) + # head + out = x[:, -token_len:] + return out + + def unpatchify(self, x, grid_sizes): + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[: math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum("fhwpqrc->cfphqwr", u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + + +class FramePackMotioner(nn.Module): + + def __init__( + self, + inner_dim=1024, + num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design + zip_frame_buckets=[ + 1, + 2, + 16, + ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames + drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) + + self.inner_dim = inner_dim + self.num_heads = num_heads + + assert (inner_dim % num_heads) == 0 and (inner_dim // num_heads) % 2 == 0 + d = inner_dim // num_heads + self.freqs = torch.cat( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + ], + dim=1, + ) + self.drop_mode = drop_mode + + def forward(self, motion_latents, add_last_motion=2): + motion_frames = motion_latents[0].shape[1] + mot = [] + mot_remb = [] + for m in motion_latents: + lat_height, lat_width = m.shape[2], m.shape[3] + padd_lat = torch.zeros( + 16, self.zip_frame_buckets.sum(), lat_height, lat_width + ).to(device=m.device, dtype=m.dtype) + overlap_frame = min(padd_lat.shape[1], m.shape[1]) + if overlap_frame > 0: + padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + zero_end_frame = self.zip_frame_buckets[ + : self.zip_frame_buckets.__len__() - add_last_motion - 1 + ].sum() + padd_lat[:, -zero_end_frame:] = 0 + + padd_lat = padd_lat.unsqueeze(0) + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[ + :, :, -self.zip_frame_buckets.sum() :, :, : + ].split( + list(self.zip_frame_buckets)[::-1], dim=2 + ) # 16, 2 ,1 + + # patchfy + clean_latents_post = ( + self.proj(clean_latents_post).flatten(2).transpose(1, 2) + ) + clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = ( + clean_latents_post[:, :0] + if add_last_motion < 2 + else clean_latents_post + ) + clean_latents_2x = ( + clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x + ) + + motion_lat = torch.cat( + [clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1 + ) + + # rope + start_time_id = -(self.zip_frame_buckets[:1].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[0] + grid_sizes = ( + [] + if add_last_motion < 2 and self.drop_mode == "drop" + else [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 2, lat_width // 2]) + .unsqueeze(0) + .repeat(1, 1), + torch.tensor( + [self.zip_frame_buckets[0], lat_height // 2, lat_width // 2] + ) + .unsqueeze(0) + .repeat(1, 1), + ] + ] + ) + + start_time_id = -(self.zip_frame_buckets[:2].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[1] // 2 + grid_sizes_2x = ( + [] + if add_last_motion < 1 and self.drop_mode == "drop" + else [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 4, lat_width // 4]) + .unsqueeze(0) + .repeat(1, 1), + torch.tensor( + [self.zip_frame_buckets[1], lat_height // 2, lat_width // 2] + ) + .unsqueeze(0) + .repeat(1, 1), + ] + ] + ) + + start_time_id = -(self.zip_frame_buckets[:3].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[2] // 4 + grid_sizes_4x = [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 8, lat_width // 8]) + .unsqueeze(0) + .repeat(1, 1), + torch.tensor( + [self.zip_frame_buckets[2], lat_height // 2, lat_width // 2] + ) + .unsqueeze(0) + .repeat(1, 1), + ] + ] + + grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x + + motion_rope_emb = rope_precompute( + motion_lat.detach().view( + 1, + motion_lat.shape[1], + self.num_heads, + self.inner_dim // self.num_heads, + ), + grid_sizes, + self.freqs, + start=None, + ) + + mot.append(motion_lat) + mot_remb.append(motion_rope_emb) + return mot, mot_remb + + +def sample_indices(N, stride, expand_ratio, c): + indices = [] + current_start = 0 + + while current_start < N: + bucket_width = int(stride * (expand_ratio ** (len(indices) / stride))) + + interval = int(bucket_width / stride * c) + current_end = min(N, current_start + bucket_width) + bucket_samples = [] + for i in range(current_end - 1, current_start - 1, -interval): + for near in range(c): + bucket_samples.append(i - near) + + indices += bucket_samples[::-1] + current_start += bucket_width + + return indices + + +if __name__ == "__main__": + device = "cuda" + model = FramePackMotioner(inner_dim=1024) + batch_size = 2 + num_frame, height, width = (28, 32, 32) + single_input = torch.ones([16, num_frame, height, width], device=device) + for i in range(num_frame): + single_input[:, num_frame - 1 - i] *= i + x = [single_input] * batch_size + model.forward(x) diff --git a/wan/modules/s2v/s2v_utils.py b/wan/modules/s2v/s2v_utils.py new file mode 100644 index 00000000..88e62c50 --- /dev/null +++ b/wan/modules/s2v/s2v_utils.py @@ -0,0 +1,86 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import numpy as np +import torch + + +def rope_precompute(x, grid_sizes, freqs, start=None): + b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2 + + # split freqs + if type(freqs) is list: + trainable_freqs = freqs[1] + freqs = freqs[0] + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64)) + seq_bucket = [0] + if not type(grid_sizes) is list: + grid_sizes = [grid_sizes] + for g in grid_sizes: + if not type(g) is list: + g = [torch.zeros_like(g), g] + batch_size = g[0].shape[0] + for i in range(batch_size): + if start is None: + f_o, h_o, w_o = g[0][i] + else: + f_o, h_o, w_o = start[i] + + f, h, w = g[1][i] + t_f, t_h, t_w = g[2][i] + seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o + seq_len = int(seq_f * seq_h * seq_w) + if seq_len > 0: + if t_f > 0: + factor_f, factor_h, factor_w = ( + (t_f / seq_f).item(), + (t_h / seq_h).item(), + (t_w / seq_w).item(), + ) + # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item()) + if f_o >= 0: + f_sam = ( + np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f) + .astype(int) + .tolist() + ) + else: + f_sam = ( + np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f) + .astype(int) + .tolist() + ) + h_sam = ( + np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h) + .astype(int) + .tolist() + ) + w_sam = ( + np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w) + .astype(int) + .tolist() + ) + + assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 + freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj() + freqs_0 = freqs_0.view(seq_f, 1, 1, -1) + + freqs_i = torch.cat( + [ + freqs_0.expand(seq_f, seq_h, seq_w, -1), + freqs[1][h_sam] + .view(1, seq_h, 1, -1) + .expand(seq_f, seq_h, seq_w, -1), + freqs[2][w_sam] + .view(1, 1, seq_w, -1) + .expand(seq_f, seq_h, seq_w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + elif t_f < 0: + freqs_i = trainable_freqs.unsqueeze(1) + # apply rotary embedding + output[i, seq_bucket[-1] : seq_bucket[-1] + seq_len] = freqs_i + seq_bucket.append(seq_bucket[-1] + seq_len) + return output diff --git a/wan/speech2video.py b/wan/speech2video.py new file mode 100644 index 00000000..2fc6a0f9 --- /dev/null +++ b/wan/speech2video.py @@ -0,0 +1,730 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from copy import deepcopy +from functools import partial + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +import torchvision.transforms.functional as TF +from decord import VideoReader +from PIL import Image +from safetensors import safe_open +from torchvision import transforms +from tqdm import tqdm + +from .distributed.fsdp import shard_model +from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward +from .distributed.util import get_world_size +from .modules.s2v.audio_encoder import AudioEncoder +from .modules.s2v.model_s2v import WanModel_S2V, sp_attn_forward_s2v +from .modules.t5 import T5EncoderModel +from .modules.vae2_1 import Wan2_1_VAE +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + + +def load_safetensors(path): + tensors = {} + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + return tensors + + +class WanS2V: + + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_sp=False, + t5_cpu=False, + init_on_cpu=True, + convert_model_dtype=False, + ): + r""" + Initializes the image-to-video generation model components. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for DiT model + use_sp (`bool`, *optional*, defaults to False): + Enable distribution strategy of sequence parallel. + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + init_on_cpu (`bool`, *optional*, defaults to True): + Enable initializing Transformer Model on CPU. Only works without FSDP or USP. + convert_model_dtype (`bool`, *optional*, defaults to False): + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. + """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.init_on_cpu = init_on_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + if t5_fsdp or dit_fsdp or use_sp: + self.init_on_cpu = False + + shard_fn = partial(shard_model, device_id=device_id) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device("cpu"), + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_fsdp else None, + ) + + self.vae = Wan2_1_VAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device, + ) + + logging.info(f"Creating WanModel from {checkpoint_dir}") + if not dit_fsdp: + self.noise_model = WanModel_S2V.from_pretrained( + checkpoint_dir, torch_dtype=self.param_dtype, device_map=self.device + ) + else: + self.noise_model = WanModel_S2V.from_pretrained( + checkpoint_dir, torch_dtype=self.param_dtype + ) + + self.noise_model = self._configure_model( + model=self.noise_model, + use_sp=use_sp, + dit_fsdp=dit_fsdp, + shard_fn=shard_fn, + convert_model_dtype=convert_model_dtype, + ) + + self.audio_encoder = AudioEncoder( + model_id=os.path.join(checkpoint_dir, "wav2vec2-large-xlsr-53-english") + ) + + if use_sp: + self.sp_size = get_world_size() + else: + self.sp_size = 1 + + self.sample_neg_prompt = config.sample_neg_prompt + self.motion_frames = config.transformer.motion_frames + self.drop_first_motion = config.drop_first_motion + self.fps = config.sample_fps + self.audio_sample_m = 0 + + def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, convert_model_dtype): + """ + Configures a model object. This includes setting evaluation modes, + applying distributed parallel strategy, and handling device placement. + + Args: + model (torch.nn.Module): + The model instance to configure. + use_sp (`bool`): + Enable distribution strategy of sequence parallel. + dit_fsdp (`bool`): + Enable FSDP sharding for DiT model. + shard_fn (callable): + The function to apply FSDP sharding. + convert_model_dtype (`bool`): + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. + + Returns: + torch.nn.Module: + The configured model. + """ + model.eval().requires_grad_(False) + if use_sp: + for block in model.blocks: + block.self_attn.forward = types.MethodType( + sp_attn_forward_s2v, block.self_attn + ) + model.use_context_parallel = True + + if dist.is_initialized(): + dist.barrier() + + if dit_fsdp: + model = shard_fn(model) + else: + if convert_model_dtype: + model.to(self.param_dtype) + if not self.init_on_cpu: + model.to(self.device) + + return model + + def get_size_less_than_area( + self, height, width, target_area=1024 * 704, divisor=64 + ): + if height * width <= target_area: + # If the original image area is already less than or equal to the target, + # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target. + max_upper_area = target_area + min_scale = 0.1 + max_scale = 1.0 + else: + # Resize to fit within the target area and then pad to multiples of `divisor` + max_upper_area = ( + target_area # Maximum allowed total pixel count after padding + ) + d = divisor - 1 + b = d * (height + width) + a = height * width + c = d**2 - max_upper_area + + # Calculate scale boundaries using quadratic equation + min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / ( + 2 * a + ) # Scale when maximum padding is applied + max_scale = math.sqrt( + max_upper_area / (height * width) + ) # Scale without any padding + + # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area + # Use binary search-like iteration to find this scale + find_it = False + for i in range(100): + scale = max_scale - (max_scale - min_scale) * i / 100 + new_height, new_width = int(height * scale), int(width * scale) + + # Pad to make dimensions divisible by 64 + pad_height = (64 - new_height % 64) % 64 + pad_width = (64 - new_width % 64) % 64 + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top + pad_left = pad_width // 2 + pad_right = pad_width - pad_left + + padded_height, padded_width = new_height + pad_height, new_width + pad_width + + if padded_height * padded_width <= max_upper_area: + find_it = True + break + + if find_it: + return padded_height, padded_width + else: + # Fallback: calculate target dimensions based on aspect ratio and divisor alignment + aspect_ratio = width / height + target_width = int((target_area * aspect_ratio) ** 0.5 // divisor * divisor) + target_height = int( + (target_area / aspect_ratio) ** 0.5 // divisor * divisor + ) + + # Ensure the result is not larger than the original resolution + if target_width >= width or target_height >= height: + target_width = int(width // divisor * divisor) + target_height = int(height // divisor * divisor) + + return target_height, target_width + + def prepare_default_cond_input( + self, + map_shape=[3, 12, 64, 64], + motion_frames=5, + lat_motion_frames=2, + enable_mano=False, + enable_kp=False, + enable_pose=False, + ): + default_value = [1.0, -1.0, -1.0] + cond_enable = [enable_mano, enable_kp, enable_pose] + cond = [] + for d, c in zip(default_value, cond_enable): + if c: + map_value = ( + torch.ones(map_shape, dtype=self.param_dtype, device=self.device) + * d + ) + cond_lat = torch.cat( + [map_value[:, :, 0:1].repeat(1, 1, motion_frames, 1, 1), map_value], + dim=2, + ) + cond_lat = torch.stack(self.vae.encode(cond_lat.to(self.param_dtype)))[ + :, :, lat_motion_frames: + ].to(self.param_dtype) + + cond.append(cond_lat) + if len(cond) >= 1: + cond = torch.cat(cond, dim=1) + else: + cond = None + return cond + + def encode_audio(self, audio_path, infer_frames): + z = self.audio_encoder.extract_audio_feat(audio_path, return_all_layers=True) + audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps( + z, fps=self.fps, batch_frames=infer_frames, m=self.audio_sample_m + ) + audio_embed_bucket = audio_embed_bucket.to(self.device, self.param_dtype) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0) + if len(audio_embed_bucket.shape) == 3: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) + elif len(audio_embed_bucket.shape) == 4: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) + return audio_embed_bucket, num_repeat + + def read_last_n_frames(self, video_path, n_frames, target_fps=16, reverse=False): + """ + Read the last `n_frames` from a video at the specified frame rate. + + Parameters: + video_path (str): Path to the video file. + n_frames (int): Number of frames to read. + target_fps (int, optional): Target sampling frame rate. Defaults to 16. + reverse (bool, optional): Whether to read frames in reverse order. + If True, reads the first `n_frames` instead of the last ones. + + Returns: + np.ndarray: A NumPy array of shape [n_frames, H, W, 3], representing the sampled video frames. + """ + vr = VideoReader(video_path) + original_fps = vr.get_avg_fps() + total_frames = len(vr) + + interval = max(1, round(original_fps / target_fps)) + + required_span = (n_frames - 1) * interval + + start_frame = max(0, total_frames - required_span - 1) if not reverse else 0 + + sampled_indices = [] + for i in range(n_frames): + indice = start_frame + i * interval + if indice >= total_frames: + break + else: + sampled_indices.append(indice) + + return vr.get_batch(sampled_indices).asnumpy() + + def load_pose_cond(self, pose_video, num_repeat, infer_frames, size): + HEIGHT, WIDTH = size + if not pose_video is None: + pose_seq = self.read_last_n_frames( + pose_video, + n_frames=infer_frames * num_repeat, + target_fps=self.fps, + reverse=True, + ) + + resize_opreat = transforms.Resize(min(HEIGHT, WIDTH)) + crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH)) + tensor_trans = transforms.ToTensor() + + cond_tensor = torch.from_numpy(pose_seq) + cond_tensor = cond_tensor.permute(0, 3, 1, 2) / 255.0 * 2 - 1.0 + cond_tensor = ( + crop_opreat(resize_opreat(cond_tensor)).permute(1, 0, 2, 3).unsqueeze(0) + ) + + padding_frame_num = num_repeat * infer_frames - cond_tensor.shape[2] + cond_tensor = torch.cat( + [cond_tensor, -torch.ones([1, 3, padding_frame_num, HEIGHT, WIDTH])], + dim=2, + ) + + cond_tensors = torch.chunk(cond_tensor, num_repeat, dim=2) + else: + cond_tensors = [-torch.ones([1, 3, infer_frames, HEIGHT, WIDTH])] + + COND = [] + for r in range(len(cond_tensors)): + cond = cond_tensors[r] + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2) + cond_lat = torch.stack( + self.vae.encode(cond.to(dtype=self.param_dtype, device=self.device)) + )[ + :, :, 1: + ].cpu() # for mem save + COND.append(cond_lat) + return COND + + def get_gen_size(self, size, max_area, ref_image_path, pre_video_path): + if not size is None: + HEIGHT, WIDTH = size + else: + if pre_video_path: + ref_image = self.read_last_n_frames(pre_video_path, n_frames=1)[0] + else: + ref_image = np.array(Image.open(ref_image_path).convert("RGB")) + HEIGHT, WIDTH = ref_image.shape[:2] + HEIGHT, WIDTH = self.get_size_less_than_area( + HEIGHT, WIDTH, target_area=max_area + ) + return (HEIGHT, WIDTH) + + def generate( + self, + input_prompt, + ref_image_path, + audio_path, + enable_tts, + tts_prompt_audio, + tts_prompt_text, + tts_text, + num_repeat=1, + pose_video=None, + max_area=720 * 1280, + infer_frames=80, + shift=5.0, + sample_solver="unipc", + sampling_steps=40, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True, + init_first_frame=False, + ): + r""" + Generates video frames from input image and text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation. + ref_image_path ('str'): + Input image path + audio_path ('str'): + Audio for video driven + num_repeat ('int'): + Number of clips to generate; will be automatically adjusted based on the audio length + pose_video ('str'): + If provided, uses a sequence of poses to drive the generated video + max_area (`int`, *optional*, defaults to 720*1280): + Maximum pixel area for latent space calculation. Controls video resolution scaling + infer_frames (`int`, *optional*, defaults to 80): + How many frames to generate per clips. The number should be 4n + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity. + If tuple, the first guide_scale will be used for low noise model and + the second guide_scale will be used for high noise model. + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + init_first_frame (`bool`, *optional*, defaults to False): + Whether to use the reference image as the first frame (i.e., standard image-to-video generation) + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from max_area) + - W: Frame width from max_area) + """ + # preprocess + size = self.get_gen_size( + size=None, + max_area=max_area, + ref_image_path=ref_image_path, + pre_video_path=None, + ) + HEIGHT, WIDTH = size + channel = 3 + + resize_opreat = transforms.Resize(min(HEIGHT, WIDTH)) + crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH)) + tensor_trans = transforms.ToTensor() + + ref_image = None + motion_latents = None + + if ref_image is None: + ref_image = np.array(Image.open(ref_image_path).convert("RGB")) + if motion_latents is None: + motion_latents = torch.zeros( + [1, channel, self.motion_frames, HEIGHT, WIDTH], + dtype=self.param_dtype, + device=self.device, + ) + + # extract audio emb + if enable_tts is True: + audio_path = self.tts(tts_prompt_audio, tts_prompt_text, tts_text) + audio_emb, nr = self.encode_audio(audio_path, infer_frames=infer_frames) + if num_repeat is None or num_repeat > nr: + num_repeat = nr + + lat_motion_frames = (self.motion_frames + 3) // 4 + model_pic = crop_opreat(resize_opreat(Image.fromarray(ref_image))) + + ref_pixel_values = tensor_trans(model_pic) + ref_pixel_values = ( + ref_pixel_values.unsqueeze(1).unsqueeze(0) * 2 - 1.0 + ) # b c 1 h w + ref_pixel_values = ref_pixel_values.to( + dtype=self.vae.dtype, device=self.vae.device + ) + ref_latents = torch.stack(self.vae.encode(ref_pixel_values)) + + # encode the motion latents + videos_last_frames = motion_latents.detach() + drop_first_motion = self.drop_first_motion + if init_first_frame: + drop_first_motion = False + motion_latents[:, :, -6:] = ref_pixel_values + motion_latents = torch.stack(self.vae.encode(motion_latents)) + + # get pose cond input if need + COND = self.load_pose_cond( + pose_video=pose_video, + num_repeat=num_repeat, + infer_frames=infer_frames, + size=size, + ) + + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + + # preprocess + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device("cpu")) + context_null = self.text_encoder([n_prompt], torch.device("cpu")) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + out = [] + # evaluation mode + with ( + torch.amp.autocast("cuda", dtype=self.param_dtype), + torch.no_grad(), + ): + for r in range(num_repeat): + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed + r) + + lat_target_frames = ( + infer_frames + 3 + self.motion_frames + ) // 4 - lat_motion_frames + target_shape = [lat_target_frames, HEIGHT // 8, WIDTH // 8] + noise = [ + torch.randn( + 16, + target_shape[0], + target_shape[1], + target_shape[2], + dtype=self.param_dtype, + device=self.device, + generator=seed_g, + ) + ] + max_seq_len = np.prod(target_shape) // 4 + + if sample_solver == "unipc": + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False, + ) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift + ) + timesteps = sample_scheduler.timesteps + elif sample_solver == "dpm++": + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False, + ) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, device=self.device, sigmas=sampling_sigmas + ) + else: + raise NotImplementedError("Unsupported solver.") + + latents = deepcopy(noise) + with torch.no_grad(): + left_idx = r * infer_frames + right_idx = r * infer_frames + infer_frames + cond_latents = COND[r] if pose_video else COND[0] * 0 + cond_latents = cond_latents.to( + dtype=self.param_dtype, device=self.device + ) + audio_input = audio_emb[..., left_idx:right_idx] + input_motion_latents = motion_latents.clone() + + arg_c = { + "context": context[0:1], + "seq_len": max_seq_len, + "cond_states": cond_latents, + "motion_latents": input_motion_latents, + "ref_latents": ref_latents, + "audio_input": audio_input, + "motion_frames": [self.motion_frames, lat_motion_frames], + "drop_motion_frames": drop_first_motion and r == 0, + } + if guide_scale > 1: + arg_null = { + "context": context_null[0:1], + "seq_len": max_seq_len, + "cond_states": cond_latents, + "motion_latents": input_motion_latents, + "ref_latents": ref_latents, + "audio_input": 0.0 * audio_input, + "motion_frames": [self.motion_frames, lat_motion_frames], + "drop_motion_frames": drop_first_motion and r == 0, + } + if offload_model or self.init_on_cpu: + self.noise_model.to(self.device) + torch.cuda.empty_cache() + + for i, t in enumerate(tqdm(timesteps)): + latent_model_input = latents[0:1] + timestep = [t] + + timestep = torch.stack(timestep).to(self.device) + + noise_pred_cond = self.noise_model( + latent_model_input, t=timestep, **arg_c + ) + + if guide_scale > 1: + noise_pred_uncond = self.noise_model( + latent_model_input, t=timestep, **arg_null + ) + noise_pred = [ + u + guide_scale * (c - u) + for c, u in zip(noise_pred_cond, noise_pred_uncond) + ] + else: + noise_pred = noise_pred_cond + + temp_x0 = sample_scheduler.step( + noise_pred[0].unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g, + )[0] + latents[0] = temp_x0.squeeze(0) + + if offload_model: + self.noise_model.cpu() + torch.cuda.synchronize() + torch.cuda.empty_cache() + latents = torch.stack(latents) + if not (drop_first_motion and r == 0): + decode_latents = torch.cat([motion_latents, latents], dim=2) + else: + decode_latents = torch.cat([ref_latents, latents], dim=2) + image = torch.stack(self.vae.decode(decode_latents)) + image = image[:, :, -(infer_frames):] + if drop_first_motion and r == 0: + image = image[:, :, 3:] + + overlap_frames_num = min(self.motion_frames, image.shape[2]) + videos_last_frames = torch.cat( + [ + videos_last_frames[:, :, overlap_frames_num:], + image[:, :, -overlap_frames_num:], + ], + dim=2, + ) + videos_last_frames = videos_last_frames.to( + dtype=motion_latents.dtype, device=motion_latents.device + ) + motion_latents = torch.stack(self.vae.encode(videos_last_frames)) + out.append(image.cpu()) + + videos = torch.cat(out, dim=2) + del noise, latents + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None + + def tts(self, tts_prompt_audio, tts_prompt_text, tts_text): + if not hasattr(self, "cosyvoice"): + self.load_tts() + speech_list = [] + import torchaudio + from cosyvoice.utils.file_utils import load_wav + + prompt_speech_16k = load_wav(tts_prompt_audio, 16000) + if tts_prompt_text is not None: + for i in self.cosyvoice.inference_zero_shot( + tts_text, tts_prompt_text, prompt_speech_16k + ): + speech_list.append(i["tts_speech"]) + else: + for i in self.cosyvoice.inference_cross_lingual( + tts_text, prompt_speech_16k + ): + speech_list.append(i["tts_speech"]) + torchaudio.save( + "tts.wav", torch.concat(speech_list, dim=1), self.cosyvoice.sample_rate + ) + return "tts.wav" + + def load_tts(self): + if not os.path.exists("CosyVoice"): + from wan.utils.utils import download_cosyvoice_repo + + download_cosyvoice_repo("CosyVoice") + if not os.path.exists("CosyVoice2-0.5B"): + from wan.utils.utils import download_cosyvoice_model + + download_cosyvoice_model("CosyVoice2-0.5B", "CosyVoice2-0.5B") + sys.path.append("CosyVoice") + sys.path.append("CosyVoice/third_party/Matcha-TTS") + from cosyvoice.cli.cosyvoice import CosyVoice2 + + self.cosyvoice = CosyVoice2("CosyVoice2-0.5B") diff --git a/wan/text2video.py b/wan/text2video.py index 7c79c667..fc45bb05 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -86,35 +86,45 @@ def __init__( self.text_encoder = T5EncoderModel( text_len=config.text_len, dtype=config.t5_dtype, - device=torch.device('cpu'), + device=torch.device("cpu"), checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn=shard_fn if t5_fsdp else None) + shard_fn=shard_fn if t5_fsdp else None, + ) + # Diffusers-style handle to tokenizer for external use + self.tokenizer = self.text_encoder.tokenizer + # Track if text encoder has been offloaded/removed + self._text_encoder_offloaded = False self.vae_stride = config.vae_stride self.patch_size = config.patch_size self.vae = Wan2_1_VAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) + device=self.device, + ) logging.info(f"Creating WanModel from {checkpoint_dir}") self.low_noise_model = WanModel.from_pretrained( - checkpoint_dir, subfolder=config.low_noise_checkpoint) + checkpoint_dir, subfolder=config.low_noise_checkpoint + ) self.low_noise_model = self._configure_model( model=self.low_noise_model, use_sp=use_sp, dit_fsdp=dit_fsdp, shard_fn=shard_fn, - convert_model_dtype=convert_model_dtype) + convert_model_dtype=convert_model_dtype, + ) self.high_noise_model = WanModel.from_pretrained( - checkpoint_dir, subfolder=config.high_noise_checkpoint) + checkpoint_dir, subfolder=config.high_noise_checkpoint + ) self.high_noise_model = self._configure_model( model=self.high_noise_model, use_sp=use_sp, dit_fsdp=dit_fsdp, shard_fn=shard_fn, - convert_model_dtype=convert_model_dtype) + convert_model_dtype=convert_model_dtype, + ) if use_sp: self.sp_size = get_world_size() else: @@ -122,8 +132,24 @@ def __init__( self.sample_neg_prompt = config.sample_neg_prompt - def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, - convert_model_dtype): + # Optional diffusers-style API for memory optimization + def enable_model_cpu_offload(self): + """Enable CPU offload of the DiT during sampling.""" + self.init_on_cpu = True + + def enable_sequential_cpu_offload(self): + """Alias for compatibility with diffusers API.""" + self.enable_model_cpu_offload() + + def enable_attention_slicing(self, *args, **kwargs): + """No-op placeholder for API compatibility.""" + logging.info("attention slicing not applicable; using custom attention.") + + def enable_xformers_memory_efficient_attention(self): + """No-op placeholder; flash-attn path already optimized in WanModel.""" + logging.info("xFormers toggle ignored; using built-in optimized attention.") + + def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, convert_model_dtype): """ Configures a model object. This includes setting evaluation modes, applying distributed parallel strategy, and handling device placement. @@ -150,7 +176,8 @@ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, if use_sp: for block in model.blocks: block.self_attn.forward = types.MethodType( - sp_attn_forward, block.self_attn) + sp_attn_forward, block.self_attn + ) model.forward = types.MethodType(sp_dit_forward, model) if dist.is_initialized(): @@ -184,33 +211,40 @@ def _prepare_model_for_timestep(self, t, boundary, offload_model): The active model on the target device for the current timestep. """ if t.item() >= boundary: - required_model_name = 'high_noise_model' - offload_model_name = 'low_noise_model' + required_model_name = "high_noise_model" + offload_model_name = "low_noise_model" else: - required_model_name = 'low_noise_model' - offload_model_name = 'high_noise_model' + required_model_name = "low_noise_model" + offload_model_name = "high_noise_model" if offload_model or self.init_on_cpu: - if next(getattr( - self, - offload_model_name).parameters()).device.type == 'cuda': - getattr(self, offload_model_name).to('cpu') - if next(getattr( - self, - required_model_name).parameters()).device.type == 'cpu': + if ( + next(getattr(self, offload_model_name).parameters()).device.type + == "cuda" + ): + getattr(self, offload_model_name).to("cpu") + if ( + next(getattr(self, required_model_name).parameters()).device.type + == "cpu" + ): getattr(self, required_model_name).to(self.device) return getattr(self, required_model_name) - def generate(self, - input_prompt, - size=(1280, 720), - frame_num=81, - shift=5.0, - sample_solver='unipc', - sampling_steps=50, - guide_scale=5.0, - n_prompt="", - seed=-1, - offload_model=True): + def generate( + self, + input_prompt, + size=(1280, 720), + frame_num=81, + shift=5.0, + sample_solver="unipc", + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True, + prompt_embeds=None, + negative_prompt_embeds=None, + precision: str = "fp16", + ): r""" Generates video frames from text prompt using diffusion process. @@ -247,16 +281,28 @@ def generate(self, - W: Frame width from size) """ # preprocess - guide_scale = (guide_scale, guide_scale) if isinstance( - guide_scale, float) else guide_scale + guide_scale = ( + (guide_scale, guide_scale) + if isinstance(guide_scale, float) + else guide_scale + ) F = frame_num - target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, - size[1] // self.vae_stride[1], - size[0] // self.vae_stride[2]) - - seq_len = math.ceil((target_shape[2] * target_shape[3]) / - (self.patch_size[1] * self.patch_size[2]) * - target_shape[1] / self.sp_size) * self.sp_size + target_shape = ( + self.vae.model.z_dim, + (F - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2], + ) + + seq_len = ( + math.ceil( + (target_shape[2] * target_shape[3]) + / (self.patch_size[1] * self.patch_size[2]) + * target_shape[1] + / self.sp_size + ) + * self.sp_size + ) if n_prompt == "": n_prompt = self.sample_neg_prompt @@ -264,17 +310,79 @@ def generate(self, seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) - if not self.t5_cpu: - self.text_encoder.model.to(self.device) - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) - if offload_model: - self.text_encoder.model.cpu() + # Handle text conditioning + if prompt_embeds is not None or negative_prompt_embeds is not None: + # Defensive: ensure both are provided when CFG is used + if guide_scale is not None and (negative_prompt_embeds is None): + raise ValueError( + "negative_prompt_embeds must be provided when using guidance." + ) + context = prompt_embeds + context_null = ( + negative_prompt_embeds if negative_prompt_embeds is not None else [] + ) + # If user provides embeds, ensure encoder weights are off GPU + if self.text_encoder is not None: + logging.warning( + "prompt_embeds provided; preventing pipeline from re-encoding prompts." + ) + if hasattr(self.text_encoder, "model"): + self.text_encoder.model.cpu() + self._text_encoder_offloaded = True + # Cast dtype as requested + target_dtype = torch.float16 if precision == "fp16" else torch.bfloat16 + context = [t.to(dtype=target_dtype, device=self.device) for t in context] + context_null = [ + t.to(dtype=target_dtype, device=self.device) for t in context_null + ] else: - context = self.text_encoder([input_prompt], torch.device('cpu')) - context_null = self.text_encoder([n_prompt], torch.device('cpu')) - context = [t.to(self.device) for t in context] - context_null = [t.to(self.device) for t in context_null] + # Compute on the fly, then offload encoder immediately + if self.text_encoder is None: + raise RuntimeError( + "text_encoder is not available. Provide prompt_embeds to generate()." + ) + # Embedding computation without grads + with torch.inference_mode(): + self.text_encoder.model.eval() + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + else: + context = self.text_encoder([input_prompt], torch.device("cpu")) + context_null = self.text_encoder([n_prompt], torch.device("cpu")) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + # Cast to requested precision for downstream model + target_dtype = torch.float16 if precision == "fp16" else torch.bfloat16 + context = [ + t.to(dtype=target_dtype, device=self.device) for t in context + ] + context_null = [ + t.to(dtype=target_dtype, device=self.device) for t in context_null + ] + # Offload and break references immediately + try: + if hasattr(torch.cuda, "memory_summary") and torch.cuda.is_available(): + logging.info( + "CUDA memory before T5 offload:\n" + + torch.cuda.memory_summary(device=self.device) + ) + except Exception: + pass + if hasattr(self.text_encoder, "model"): + self.text_encoder.model.cpu() + self._text_encoder_offloaded = True + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + try: + logging.info( + "CUDA memory after T5 offload:\n" + + torch.cuda.memory_summary(device=self.device) + ) + except Exception: + pass noise = [ torch.randn( @@ -284,53 +392,54 @@ def generate(self, target_shape[3], dtype=torch.float32, device=self.device, - generator=seed_g) + generator=seed_g, + ) ] @contextmanager def noop_no_sync(): yield - no_sync_low_noise = getattr(self.low_noise_model, 'no_sync', - noop_no_sync) - no_sync_high_noise = getattr(self.high_noise_model, 'no_sync', - noop_no_sync) + no_sync_low_noise = getattr(self.low_noise_model, "no_sync", noop_no_sync) + no_sync_high_noise = getattr(self.high_noise_model, "no_sync", noop_no_sync) # evaluation mode with ( - torch.amp.autocast('cuda', dtype=self.param_dtype), - torch.no_grad(), - no_sync_low_noise(), - no_sync_high_noise(), + torch.amp.autocast("cuda", dtype=self.param_dtype), + torch.no_grad(), + no_sync_low_noise(), + no_sync_high_noise(), ): boundary = self.boundary * self.num_train_timesteps - if sample_solver == 'unipc': + if sample_solver == "unipc": sample_scheduler = FlowUniPCMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, - use_dynamic_shifting=False) + use_dynamic_shifting=False, + ) sample_scheduler.set_timesteps( - sampling_steps, device=self.device, shift=shift) + sampling_steps, device=self.device, shift=shift + ) timesteps = sample_scheduler.timesteps - elif sample_solver == 'dpm++': + elif sample_solver == "dpm++": sample_scheduler = FlowDPMSolverMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, - use_dynamic_shifting=False) + use_dynamic_shifting=False, + ) sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) timesteps, _ = retrieve_timesteps( - sample_scheduler, - device=self.device, - sigmas=sampling_sigmas) + sample_scheduler, device=self.device, sigmas=sampling_sigmas + ) else: raise NotImplementedError("Unsupported solver.") # sample videos latents = noise - arg_c = {'context': context, 'seq_len': seq_len} - arg_null = {'context': context_null, 'seq_len': seq_len} + arg_c = {"context": context, "seq_len": seq_len} + arg_null = {"context": context_null, "seq_len": seq_len} for _, t in enumerate(tqdm(timesteps)): latent_model_input = latents @@ -338,25 +447,25 @@ def noop_no_sync(): timestep = torch.stack(timestep) - model = self._prepare_model_for_timestep( - t, boundary, offload_model) - sample_guide_scale = guide_scale[1] if t.item( - ) >= boundary else guide_scale[0] + model = self._prepare_model_for_timestep(t, boundary, offload_model) + sample_guide_scale = ( + guide_scale[1] if t.item() >= boundary else guide_scale[0] + ) - noise_pred_cond = model( - latent_model_input, t=timestep, **arg_c)[0] - noise_pred_uncond = model( - latent_model_input, t=timestep, **arg_null)[0] + noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0] + noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0] noise_pred = noise_pred_uncond + sample_guide_scale * ( - noise_pred_cond - noise_pred_uncond) + noise_pred_cond - noise_pred_uncond + ) temp_x0 = sample_scheduler.step( noise_pred.unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, - generator=seed_g)[0] + generator=seed_g, + )[0] latents = [temp_x0.squeeze(0)] x0 = latents diff --git a/wan/textimage2video.py b/wan/textimage2video.py index 67e9fd29..9260057e 100644 --- a/wan/textimage2video.py +++ b/wan/textimage2video.py @@ -88,16 +88,21 @@ def __init__( self.text_encoder = T5EncoderModel( text_len=config.text_len, dtype=config.t5_dtype, - device=torch.device('cpu'), + device=torch.device("cpu"), checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn=shard_fn if t5_fsdp else None) + shard_fn=shard_fn if t5_fsdp else None, + ) + # Diffusers-style handle to tokenizer for external use + self.tokenizer = self.text_encoder.tokenizer + self._text_encoder_offloaded = False self.vae_stride = config.vae_stride self.patch_size = config.patch_size self.vae = Wan2_2_VAE( vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) + device=self.device, + ) logging.info(f"Creating WanModel from {checkpoint_dir}") self.model = WanModel.from_pretrained(checkpoint_dir) @@ -106,7 +111,8 @@ def __init__( use_sp=use_sp, dit_fsdp=dit_fsdp, shard_fn=shard_fn, - convert_model_dtype=convert_model_dtype) + convert_model_dtype=convert_model_dtype, + ) if use_sp: self.sp_size = get_world_size() @@ -115,8 +121,20 @@ def __init__( self.sample_neg_prompt = config.sample_neg_prompt - def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, - convert_model_dtype): + # Optional diffusers-style API for memory optimization + def enable_model_cpu_offload(self): + self.init_on_cpu = True + + def enable_sequential_cpu_offload(self): + self.enable_model_cpu_offload() + + def enable_attention_slicing(self, *args, **kwargs): + logging.info("attention slicing not applicable; using custom attention.") + + def enable_xformers_memory_efficient_attention(self): + logging.info("xFormers toggle ignored; using built-in optimized attention.") + + def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, convert_model_dtype): """ Configures a model object. This includes setting evaluation modes, applying distributed parallel strategy, and handling device placement. @@ -143,7 +161,8 @@ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, if use_sp: for block in model.blocks: block.self_attn.forward = types.MethodType( - sp_attn_forward, block.self_attn) + sp_attn_forward, block.self_attn + ) model.forward = types.MethodType(sp_dit_forward, model) if dist.is_initialized(): @@ -159,19 +178,24 @@ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, return model - def generate(self, - input_prompt, - img=None, - size=(1280, 704), - max_area=704 * 1280, - frame_num=81, - shift=5.0, - sample_solver='unipc', - sampling_steps=50, - guide_scale=5.0, - n_prompt="", - seed=-1, - offload_model=True): + def generate( + self, + input_prompt, + img=None, + size=(1280, 704), + max_area=704 * 1280, + frame_num=81, + shift=5.0, + sample_solver="unipc", + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True, + prompt_embeds=None, + negative_prompt_embeds=None, + precision: str = "fp16", + ): r""" Generates video frames from text prompt using diffusion process. @@ -222,7 +246,8 @@ def generate(self, guide_scale=guide_scale, n_prompt=n_prompt, seed=seed, - offload_model=offload_model) + offload_model=offload_model, + ) # t2v return self.t2v( input_prompt=input_prompt, @@ -234,19 +259,28 @@ def generate(self, guide_scale=guide_scale, n_prompt=n_prompt, seed=seed, - offload_model=offload_model) - - def t2v(self, - input_prompt, - size=(1280, 704), - frame_num=121, - shift=5.0, - sample_solver='unipc', - sampling_steps=50, - guide_scale=5.0, - n_prompt="", - seed=-1, - offload_model=True): + offload_model=offload_model, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + precision=precision, + ) + + def t2v( + self, + input_prompt, + size=(1280, 704), + frame_num=121, + shift=5.0, + sample_solver="unipc", + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True, + prompt_embeds=None, + negative_prompt_embeds=None, + precision: str = "fp16", + ): r""" Generates video frames from text prompt using diffusion process. @@ -282,13 +316,22 @@ def t2v(self, """ # preprocess F = frame_num - target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, - size[1] // self.vae_stride[1], - size[0] // self.vae_stride[2]) - - seq_len = math.ceil((target_shape[2] * target_shape[3]) / - (self.patch_size[1] * self.patch_size[2]) * - target_shape[1] / self.sp_size) * self.sp_size + target_shape = ( + self.vae.model.z_dim, + (F - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2], + ) + + seq_len = ( + math.ceil( + (target_shape[2] * target_shape[3]) + / (self.patch_size[1] * self.patch_size[2]) + * target_shape[1] + / self.sp_size + ) + * self.sp_size + ) if n_prompt == "": n_prompt = self.sample_neg_prompt @@ -296,17 +339,72 @@ def t2v(self, seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) - if not self.t5_cpu: - self.text_encoder.model.to(self.device) - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) - if offload_model: - self.text_encoder.model.cpu() + # Handle text conditioning + if prompt_embeds is not None or negative_prompt_embeds is not None: + if guide_scale is not None and (negative_prompt_embeds is None): + raise ValueError( + "negative_prompt_embeds must be provided when using guidance." + ) + context = prompt_embeds + context_null = ( + negative_prompt_embeds if negative_prompt_embeds is not None else [] + ) + if self.text_encoder is not None: + logging.warning( + "prompt_embeds provided; preventing redundant text encoding." + ) + if hasattr(self.text_encoder, "model"): + self.text_encoder.model.cpu() + self._text_encoder_offloaded = True + target_dtype = torch.float16 if precision == "fp16" else torch.bfloat16 + context = [t.to(dtype=target_dtype, device=self.device) for t in context] + context_null = [ + t.to(dtype=target_dtype, device=self.device) for t in context_null + ] else: - context = self.text_encoder([input_prompt], torch.device('cpu')) - context_null = self.text_encoder([n_prompt], torch.device('cpu')) - context = [t.to(self.device) for t in context] - context_null = [t.to(self.device) for t in context_null] + if self.text_encoder is None: + raise RuntimeError( + "text_encoder is not available. Provide prompt_embeds to t2v()." + ) + with torch.inference_mode(): + self.text_encoder.model.eval() + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + else: + context = self.text_encoder([input_prompt], torch.device("cpu")) + context_null = self.text_encoder([n_prompt], torch.device("cpu")) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + target_dtype = torch.float16 if precision == "fp16" else torch.bfloat16 + context = [ + t.to(dtype=target_dtype, device=self.device) for t in context + ] + context_null = [ + t.to(dtype=target_dtype, device=self.device) for t in context_null + ] + try: + if hasattr(torch.cuda, "memory_summary") and torch.cuda.is_available(): + logging.info( + "CUDA memory before T5 offload:\n" + + torch.cuda.memory_summary(device=self.device) + ) + except Exception: + pass + if hasattr(self.text_encoder, "model"): + self.text_encoder.model.cpu() + self._text_encoder_offloaded = True + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + try: + logging.info( + "CUDA memory after T5 offload:\n" + + torch.cuda.memory_summary(device=self.device) + ) + except Exception: + pass noise = [ torch.randn( @@ -316,40 +414,43 @@ def t2v(self, target_shape[3], dtype=torch.float32, device=self.device, - generator=seed_g) + generator=seed_g, + ) ] @contextmanager def noop_no_sync(): yield - no_sync = getattr(self.model, 'no_sync', noop_no_sync) + no_sync = getattr(self.model, "no_sync", noop_no_sync) # evaluation mode with ( - torch.amp.autocast('cuda', dtype=self.param_dtype), - torch.no_grad(), - no_sync(), + torch.amp.autocast("cuda", dtype=self.param_dtype), + torch.no_grad(), + no_sync(), ): - if sample_solver == 'unipc': + if sample_solver == "unipc": sample_scheduler = FlowUniPCMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, - use_dynamic_shifting=False) + use_dynamic_shifting=False, + ) sample_scheduler.set_timesteps( - sampling_steps, device=self.device, shift=shift) + sampling_steps, device=self.device, shift=shift + ) timesteps = sample_scheduler.timesteps - elif sample_solver == 'dpm++': + elif sample_solver == "dpm++": sample_scheduler = FlowDPMSolverMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, - use_dynamic_shifting=False) + use_dynamic_shifting=False, + ) sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) timesteps, _ = retrieve_timesteps( - sample_scheduler, - device=self.device, - sigmas=sampling_sigmas) + sample_scheduler, device=self.device, sigmas=sampling_sigmas + ) else: raise NotImplementedError("Unsupported solver.") @@ -357,8 +458,8 @@ def noop_no_sync(): latents = noise mask1, mask2 = masks_like(noise, zero=False) - arg_c = {'context': context, 'seq_len': seq_len} - arg_null = {'context': context_null, 'seq_len': seq_len} + arg_c = {"context": context, "seq_len": seq_len} + arg_null = {"context": context_null, "seq_len": seq_len} if offload_model or self.init_on_cpu: self.model.to(self.device) @@ -371,26 +472,30 @@ def noop_no_sync(): timestep = torch.stack(timestep) temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten() - temp_ts = torch.cat([ - temp_ts, - temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep - ]) + temp_ts = torch.cat( + [temp_ts, temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep] + ) timestep = temp_ts.unsqueeze(0) - noise_pred_cond = self.model( - latent_model_input, t=timestep, **arg_c)[0] + noise_pred_cond = self.model(latent_model_input, t=timestep, **arg_c)[0] + if offload_model: + # Proactively release cached blocks between cond/uncond passes + torch.cuda.empty_cache() noise_pred_uncond = self.model( - latent_model_input, t=timestep, **arg_null)[0] + latent_model_input, t=timestep, **arg_null + )[0] noise_pred = noise_pred_uncond + guide_scale * ( - noise_pred_cond - noise_pred_uncond) + noise_pred_cond - noise_pred_uncond + ) temp_x0 = sample_scheduler.step( noise_pred.unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, - generator=seed_g)[0] + generator=seed_g, + )[0] latents = [temp_x0.squeeze(0)] x0 = latents if offload_model: @@ -410,18 +515,23 @@ def noop_no_sync(): return videos[0] if self.rank == 0 else None - def i2v(self, - input_prompt, - img, - max_area=704 * 1280, - frame_num=121, - shift=5.0, - sample_solver='unipc', - sampling_steps=40, - guide_scale=5.0, - n_prompt="", - seed=-1, - offload_model=True): + def i2v( + self, + input_prompt, + img, + max_area=704 * 1280, + frame_num=121, + shift=5.0, + sample_solver="unipc", + sampling_steps=40, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True, + prompt_embeds=None, + negative_prompt_embeds=None, + precision: str = "fp16", + ): r""" Generates video frames from input image and text prompt using diffusion process. @@ -460,8 +570,10 @@ def i2v(self, """ # preprocess ih, iw = img.height, img.width - dh, dw = self.patch_size[1] * self.vae_stride[1], self.patch_size[ - 2] * self.vae_stride[2] + dh, dw = ( + self.patch_size[1] * self.vae_stride[1], + self.patch_size[2] * self.vae_stride[2], + ) ow, oh = best_output_size(iw, ih, dw, dh, max_area) scale = max(ow / iw, oh / ih) @@ -477,37 +589,96 @@ def i2v(self, img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1) F = frame_num - seq_len = ((F - 1) // self.vae_stride[0] + 1) * ( - oh // self.vae_stride[1]) * (ow // self.vae_stride[2]) // ( - self.patch_size[1] * self.patch_size[2]) + seq_len = ( + ((F - 1) // self.vae_stride[0] + 1) + * (oh // self.vae_stride[1]) + * (ow // self.vae_stride[2]) + // (self.patch_size[1] * self.patch_size[2]) + ) seq_len = int(math.ceil(seq_len / self.sp_size)) * self.sp_size seed = seed if seed >= 0 else random.randint(0, sys.maxsize) seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) noise = torch.randn( - self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + self.vae.model.z_dim, + (F - 1) // self.vae_stride[0] + 1, oh // self.vae_stride[1], ow // self.vae_stride[2], dtype=torch.float32, generator=seed_g, - device=self.device) + device=self.device, + ) if n_prompt == "": n_prompt = self.sample_neg_prompt - # preprocess - if not self.t5_cpu: - self.text_encoder.model.to(self.device) - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) - if offload_model: - self.text_encoder.model.cpu() + # Text conditioning + if prompt_embeds is not None or negative_prompt_embeds is not None: + if guide_scale is not None and (negative_prompt_embeds is None): + raise ValueError( + "negative_prompt_embeds must be provided when using guidance." + ) + context = prompt_embeds + context_null = ( + negative_prompt_embeds if negative_prompt_embeds is not None else [] + ) + if self.text_encoder is not None: + logging.warning( + "prompt_embeds provided; preventing redundant text encoding." + ) + if hasattr(self.text_encoder, "model"): + self.text_encoder.model.cpu() + self._text_encoder_offloaded = True + target_dtype = torch.float16 if precision == "fp16" else torch.bfloat16 + context = [t.to(dtype=target_dtype, device=self.device) for t in context] + context_null = [ + t.to(dtype=target_dtype, device=self.device) for t in context_null + ] else: - context = self.text_encoder([input_prompt], torch.device('cpu')) - context_null = self.text_encoder([n_prompt], torch.device('cpu')) - context = [t.to(self.device) for t in context] - context_null = [t.to(self.device) for t in context_null] + if self.text_encoder is None: + raise RuntimeError( + "text_encoder is not available. Provide prompt_embeds to i2v()." + ) + with torch.inference_mode(): + self.text_encoder.model.eval() + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + else: + context = self.text_encoder([input_prompt], torch.device("cpu")) + context_null = self.text_encoder([n_prompt], torch.device("cpu")) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + target_dtype = torch.float16 if precision == "fp16" else torch.bfloat16 + context = [ + t.to(dtype=target_dtype, device=self.device) for t in context + ] + context_null = [ + t.to(dtype=target_dtype, device=self.device) for t in context_null + ] + try: + if hasattr(torch.cuda, "memory_summary") and torch.cuda.is_available(): + logging.info( + "CUDA memory before T5 offload:\n" + + torch.cuda.memory_summary(device=self.device) + ) + except Exception: + pass + if hasattr(self.text_encoder, "model"): + self.text_encoder.model.cpu() + self._text_encoder_offloaded = True + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + try: + logging.info( + "CUDA memory after T5 offload:\n" + + torch.cuda.memory_summary(device=self.device) + ) + except Exception: + pass z = self.vae.encode([img]) @@ -515,49 +686,51 @@ def i2v(self, def noop_no_sync(): yield - no_sync = getattr(self.model, 'no_sync', noop_no_sync) + no_sync = getattr(self.model, "no_sync", noop_no_sync) # evaluation mode with ( - torch.amp.autocast('cuda', dtype=self.param_dtype), - torch.no_grad(), - no_sync(), + torch.amp.autocast("cuda", dtype=self.param_dtype), + torch.no_grad(), + no_sync(), ): - if sample_solver == 'unipc': + if sample_solver == "unipc": sample_scheduler = FlowUniPCMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, - use_dynamic_shifting=False) + use_dynamic_shifting=False, + ) sample_scheduler.set_timesteps( - sampling_steps, device=self.device, shift=shift) + sampling_steps, device=self.device, shift=shift + ) timesteps = sample_scheduler.timesteps - elif sample_solver == 'dpm++': + elif sample_solver == "dpm++": sample_scheduler = FlowDPMSolverMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, - use_dynamic_shifting=False) + use_dynamic_shifting=False, + ) sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) timesteps, _ = retrieve_timesteps( - sample_scheduler, - device=self.device, - sigmas=sampling_sigmas) + sample_scheduler, device=self.device, sigmas=sampling_sigmas + ) else: raise NotImplementedError("Unsupported solver.") # sample videos latent = noise mask1, mask2 = masks_like([noise], zero=True) - latent = (1. - mask2[0]) * z[0] + mask2[0] * latent + latent = (1.0 - mask2[0]) * z[0] + mask2[0] * latent arg_c = { - 'context': [context[0]], - 'seq_len': seq_len, + "context": [context[0]], + "seq_len": seq_len, } arg_null = { - 'context': context_null, - 'seq_len': seq_len, + "context": context_null, + "seq_len": seq_len, } if offload_model or self.init_on_cpu: @@ -571,31 +744,32 @@ def noop_no_sync(): timestep = torch.stack(timestep).to(self.device) temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten() - temp_ts = torch.cat([ - temp_ts, - temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep - ]) + temp_ts = torch.cat( + [temp_ts, temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep] + ) timestep = temp_ts.unsqueeze(0) - noise_pred_cond = self.model( - latent_model_input, t=timestep, **arg_c)[0] + noise_pred_cond = self.model(latent_model_input, t=timestep, **arg_c)[0] if offload_model: torch.cuda.empty_cache() noise_pred_uncond = self.model( - latent_model_input, t=timestep, **arg_null)[0] + latent_model_input, t=timestep, **arg_null + )[0] if offload_model: torch.cuda.empty_cache() noise_pred = noise_pred_uncond + guide_scale * ( - noise_pred_cond - noise_pred_uncond) + noise_pred_cond - noise_pred_uncond + ) temp_x0 = sample_scheduler.step( noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, - generator=seed_g)[0] + generator=seed_g, + )[0] latent = temp_x0.squeeze(0) - latent = (1. - mask2[0]) * z[0] + mask2[0] * latent + latent = (1.0 - mask2[0]) * z[0] + mask2[0] * latent x0 = [latent] del latent_model_input, timestep