
Commit 190d97e

Add support for I2V during validation (#11)
* add support I2V during validation
1 parent e3f7455 commit 190d97e

File tree

configs/ltxv_2b_full.yaml
configs/ltxv_2b_lora.yaml
configs/ltxv_2b_lora_low_vram.yaml
configs/ltxv_2b_lora_template.yaml
src/ltxv_trainer/config.py
src/ltxv_trainer/trainer.py

6 files changed: +55 −20 lines

configs/ltxv_2b_full.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -38,6 +38,7 @@ validation:
   - "a professional portrait video of a person with blurry bokeh background"
   - "a video of a person wearing a nice suit"
   negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
+  images: null # Set to a list of image paths to use first-frame conditioning, or null to disable
   video_dims: [768, 448, 89] # [width, height, frames]
   seed: 42
   inference_steps: 50
```

configs/ltxv_2b_lora.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -49,6 +49,7 @@ validation:
   - "a professional portrait video of a person with blurry bokeh background"
   - "a video of a person wearing a nice suit"
   negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
+  images: null # Set to a list of image paths to use first-frame conditioning, or null to disable
   video_dims: [768, 448, 89] # [width, height, frames]
   seed: 42
   inference_steps: 50
```

configs/ltxv_2b_lora_low_vram.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -49,6 +49,7 @@ validation:
   - "a professional portrait video of a person with blurry bokeh background"
   - "a video of a person wearing a nice suit"
   negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
+  images: null # Set to a list of image paths to use first-frame conditioning, or null to disable
   video_dims: [768, 448, 89] # [width, height, frames]
   seed: 42
   inference_steps: 50
```

configs/ltxv_2b_lora_template.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -48,6 +48,7 @@ validation:
   - "a professional portrait video of a person with blurry bokeh background"
   - "a video of a person wearing a nice suit"
   negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
+  images: null # Set to a list of image paths to use first-frame conditioning, or null to disable
   video_dims: [768, 448, 89] # [width, height, frames]
   seed: 42
   inference_steps: 30
```
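
All four configs get the same one-line change: a new `images` key under `validation:`, off by default. For illustration only (not from the repo), this is what a `validation` block with first-frame conditioning enabled would look like after parsing, written as the Python dict a YAML loader would produce; the image paths are hypothetical placeholders:

```python
# Hypothetical parsed `validation` block with I2V conditioning enabled.
# The shipped configs set `images: null`; the paths below are placeholders.
validation = {
    "prompts": [
        "a professional portrait video of a person with blurry bokeh background",
        "a video of a person wearing a nice suit",
    ],
    "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
    "images": [  # one first-frame image per prompt
        "assets/portrait_first_frame.png",
        "assets/suit_first_frame.png",
    ],
    "video_dims": [768, 448, 89],  # [width, height, frames]
    "seed": 42,
    "inference_steps": 30,
}

# trainer.py enforces this pairing before sampling (see below):
assert len(validation["images"]) == len(validation["prompts"])
```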

src/ltxv_trainer/config.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -189,6 +189,12 @@ class ValidationConfig(ConfigBaseModel):
         description="Negative prompt to use for validation examples",
     )
 
+    images: list[str] | None = Field(
+        default=None,
+        description="List of image paths to use for validation. "
+        "One image path must be provided for each validation prompt",
+    )
+
     video_dims: tuple[int, int, int] = Field(
         default=(704, 480, 161),
         description="Dimensions of validation videos (width, height, frames)",
```

src/ltxv_trainer/trainer.py

Lines changed: 45 additions & 20 deletions
```diff
@@ -11,7 +11,7 @@
 import torch
 from accelerate import Accelerator
 from accelerate.utils import set_seed
-from diffusers import LTXPipeline
+from diffusers import LTXImageToVideoPipeline, LTXPipeline
 from diffusers.utils import export_to_video
 from loguru import logger
 from peft import LoraConfig, get_peft_model_state_dict
@@ -49,7 +49,7 @@
 from ltxv_trainer.model_loader import load_ltxv_components
 from ltxv_trainer.quantization import quantize_model
 from ltxv_trainer.timestep_samplers import SAMPLERS
-from ltxv_trainer.utils import get_gpu_memory_gb
+from ltxv_trainer.utils import get_gpu_memory_gb, open_image_as_srgb
 
 # Disable irrelevant warnings from transformers
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
```
```diff
@@ -645,13 +645,30 @@ def _sample_videos(self, progress: Progress) -> list[Path] | None:
         if not self._config.acceleration.load_text_encoder_in_8bit:
             self._text_encoder.to(self._accelerator.device)
 
-        pipeline = LTXPipeline(
-            scheduler=deepcopy(self._scheduler),
-            vae=self._vae,
-            text_encoder=self._text_encoder,
-            tokenizer=self._tokenizer,
-            transformer=self._transformer,
-        )
+        use_images = self._config.validation.images is not None
+
+        if use_images:
+            if len(self._config.validation.images) != len(self._config.validation.prompts):
+                raise ValueError(
+                    f"Number of images ({len(self._config.validation.images)}) must match "
+                    f"number of prompts ({len(self._config.validation.prompts)})"
+                )
+
+            pipeline = LTXImageToVideoPipeline(
+                scheduler=deepcopy(self._scheduler),
+                vae=self._vae,
+                text_encoder=self._text_encoder,
+                tokenizer=self._tokenizer,
+                transformer=self._transformer,
+            )
+        else:
+            pipeline = LTXPipeline(
+                scheduler=deepcopy(self._scheduler),
+                vae=self._vae,
+                text_encoder=self._text_encoder,
+                tokenizer=self._tokenizer,
+                transformer=self._transformer,
+            )
         pipeline.set_progress_bar_config(disable=True)
 
         # Create a task in the sampling progress
```
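
The same T2V/I2V split can be reproduced outside the trainer with diffusers' pretrained checkpoint. Below is a minimal sketch following diffusers' published LTX-Video examples; the `Lightricks/LTX-Video` model ID and the image path are assumptions drawn from those docs, not from this commit, and the trainer instead assembles the pipeline from its already-loaded components, as in the hunk above:

```python
import torch
from diffusers import LTXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Image-to-video: condition generation on a first frame.
pipe = LTXImageToVideoPipeline.from_pretrained(
    "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("assets/suit_first_frame.png")  # placeholder path
video = pipe(
    image=image,
    prompt="a video of a person wearing a nice suit",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=768,
    height=448,
    num_frames=89,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "validation_i2v.mp4", fps=25)
```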
```diff
@@ -665,25 +682,33 @@ def _sample_videos(self, progress: Progress) -> list[Path] | None:
 
         video_paths = []
         i = 0
-        for prompt in self._config.validation.prompts:
+        for j, prompt in enumerate(self._config.validation.prompts):
             generator = torch.Generator(device=self._accelerator.device).manual_seed(self._config.validation.seed)
 
             # Generate video
             width, height, frames = self._config.validation.video_dims
+
+            pipeline_inputs = {
+                "prompt": prompt,
+                "negative_prompt": self._config.validation.negative_prompt,
+                "width": width,
+                "height": height,
+                "num_frames": frames,
+                "num_inference_steps": self._config.validation.inference_steps,
+                "generator": generator,
+            }
+
+            if use_images:
+                image_path = self._config.validation.images[j]
+                pipeline_inputs["image"] = open_image_as_srgb(image_path)
+
             with autocast(self._accelerator.device.type, dtype=torch.bfloat16):
-                videos = pipeline(
-                    prompt=prompt,
-                    negative_prompt=self._config.validation.negative_prompt,
-                    width=width,
-                    height=height,
-                    num_frames=frames,
-                    num_inference_steps=self._config.validation.inference_steps,
-                    generator=generator,
-                ).frames
+                result = pipeline(**pipeline_inputs)
+                videos = result.frames
 
             for video in videos:
                 video_path = output_dir / f"step_{self._global_step:06d}_{i}.mp4"
-                export_to_video(video, str(video_path), fps=24)
+                export_to_video(video, str(video_path), fps=25)
                 video_paths.append(video_path)
                 i += 1
                 progress.update(task, advance=1)
```
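
A detail worth calling out in the hunk above: collecting the arguments in a `pipeline_inputs` dict and splatting it into the call lets one call site serve both pipelines, with `image` added only in the I2V case. A stripped-down sketch of the pattern (the values are placeholders, and `pipeline` stands for whichever pipeline was built earlier):

```python
prompts = ["a video of a person wearing a nice suit"]
images = ["assets/suit_first_frame.png"]  # or None for plain T2V validation

for j, prompt in enumerate(prompts):
    pipeline_inputs = {
        "prompt": prompt,
        "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
        "width": 768,
        "height": 448,
        "num_frames": 89,
        "num_inference_steps": 50,
    }
    if images is not None:
        # Indexing by j is safe: the trainer checked len(images) == len(prompts).
        pipeline_inputs["image"] = images[j]

    # videos = pipeline(**pipeline_inputs).frames
```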
