configs/ltxv_2b_full.yaml (1 addition, 0 deletions)
@@ -38,6 +38,7 @@ validation:
- "a professional portrait video of a person with blurry bokeh background"
- "a video of a person wearing a nice suit"
negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
images: null # Set to a list of image paths to use first-frame conditioning, or null to disable
video_dims: [768, 448, 89] # [width, height, frames]
seed: 42
inference_steps: 50
configs/ltxv_2b_lora.yaml (1 addition, 0 deletions)
@@ -49,6 +49,7 @@ validation:
- "a professional portrait video of a person with blurry bokeh background"
- "a video of a person wearing a nice suit"
negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
images: null # Set to a list of image paths to use first-frame conditioning, or null to disable
video_dims: [768, 448, 89] # [width, height, frames]
seed: 42
inference_steps: 50
configs/ltxv_2b_lora_low_vram.yaml (1 addition, 0 deletions)
@@ -49,6 +49,7 @@ validation:
- "a professional portrait video of a person with blurry bokeh background"
- "a video of a person wearing a nice suit"
negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
images: null # Set to a list of image paths to use first-frame conditioning, or null to disable
video_dims: [768, 448, 89] # [width, height, frames]
seed: 42
inference_steps: 50
configs/ltxv_2b_lora_template.yaml (1 addition, 0 deletions)
@@ -48,6 +48,7 @@ validation:
- "a professional portrait video of a person with blurry bokeh background"
- "a video of a person wearing a nice suit"
negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
images: null # Set to a list of image paths to use first-frame conditioning, or null to disable
video_dims: [768, 448, 89] # [width, height, frames]
seed: 42
inference_steps: 30
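All four configs gain the same optional images key under validation: null (the default) keeps validation in plain text-to-video mode, while a list of image paths enables first-frame conditioning, one path per validation prompt, matched by index. A minimal sketch of loading and checking that pairing, not part of the diff (PyYAML and the chosen config path are assumptions here; the trainer performs the same length check itself, as shown in the trainer.py diff below):

import yaml

# Load one of the training configs above; the path is illustrative.
with open("configs/ltxv_2b_lora.yaml") as f:
    config = yaml.safe_load(f)

validation = config["validation"]
images = validation.get("images")  # None (YAML null) disables first-frame conditioning

if images is not None and len(images) != len(validation["prompts"]):
    raise ValueError("one image path is required per validation prompt")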
src/ltxv_trainer/config.py (6 additions, 0 deletions)
@@ -189,6 +189,12 @@ class ValidationConfig(ConfigBaseModel):
description="Negative prompt to use for validation examples",
)

images: list[str] | None = Field(
default=None,
description="List of image paths to use for validation. "
"One image path must be provided for each validation prompt",
)

video_dims: tuple[int, int, int] = Field(
default=(704, 480, 161),
description="Dimensions of validation videos (width, height, frames)",
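For reference, here is how the extended ValidationConfig might be constructed directly. A sketch under assumptions: field names other than images are taken from this diff and from the trainer.py usage below (validation.prompts, validation.seed), and any other required fields are omitted:

from ltxv_trainer.config import ValidationConfig

# images must pair one-to-one with prompts; leaving it as None (the default)
# keeps validation in plain text-to-video mode.
config = ValidationConfig(
    prompts=["a video of a person wearing a nice suit"],
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    images=["assets/person_first_frame.png"],  # hypothetical path
    video_dims=(768, 448, 89),
)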
src/ltxv_trainer/trainer.py (45 additions, 20 deletions)
@@ -11,7 +11,7 @@
 import torch
 from accelerate import Accelerator
 from accelerate.utils import set_seed
-from diffusers import LTXPipeline
+from diffusers import LTXImageToVideoPipeline, LTXPipeline
 from diffusers.utils import export_to_video
 from loguru import logger
 from peft import LoraConfig, get_peft_model_state_dict
@@ -49,7 +49,7 @@
 from ltxv_trainer.model_loader import load_ltxv_components
 from ltxv_trainer.quantization import quantize_model
 from ltxv_trainer.timestep_samplers import SAMPLERS
-from ltxv_trainer.utils import get_gpu_memory_gb
+from ltxv_trainer.utils import get_gpu_memory_gb, open_image_as_srgb
 
 # Disable irrelevant warnings from transformers
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
@@ -645,13 +645,30 @@ def _sample_videos(self, progress: Progress) -> list[Path] | None:
         if not self._config.acceleration.load_text_encoder_in_8bit:
             self._text_encoder.to(self._accelerator.device)
 
-        pipeline = LTXPipeline(
-            scheduler=deepcopy(self._scheduler),
-            vae=self._vae,
-            text_encoder=self._text_encoder,
-            tokenizer=self._tokenizer,
-            transformer=self._transformer,
-        )
+        use_images = self._config.validation.images is not None
+
+        if use_images:
+            if len(self._config.validation.images) != len(self._config.validation.prompts):
+                raise ValueError(
+                    f"Number of images ({len(self._config.validation.images)}) must match "
+                    f"number of prompts ({len(self._config.validation.prompts)})"
+                )
+
+            pipeline = LTXImageToVideoPipeline(
+                scheduler=deepcopy(self._scheduler),
+                vae=self._vae,
+                text_encoder=self._text_encoder,
+                tokenizer=self._tokenizer,
+                transformer=self._transformer,
+            )
+        else:
+            pipeline = LTXPipeline(
+                scheduler=deepcopy(self._scheduler),
+                vae=self._vae,
+                text_encoder=self._text_encoder,
+                tokenizer=self._tokenizer,
+                transformer=self._transformer,
+            )
         pipeline.set_progress_bar_config(disable=True)
 
         # Create a task in the sampling progress
@@ -665,25 +682,33 @@ def _sample_videos(self, progress: Progress) -> list[Path] | None:

         video_paths = []
         i = 0
-        for prompt in self._config.validation.prompts:
+        for j, prompt in enumerate(self._config.validation.prompts):
             generator = torch.Generator(device=self._accelerator.device).manual_seed(self._config.validation.seed)
 
             # Generate video
             width, height, frames = self._config.validation.video_dims
 
+            pipeline_inputs = {
+                "prompt": prompt,
+                "negative_prompt": self._config.validation.negative_prompt,
+                "width": width,
+                "height": height,
+                "num_frames": frames,
+                "num_inference_steps": self._config.validation.inference_steps,
+                "generator": generator,
+            }
+
+            if use_images:
+                image_path = self._config.validation.images[j]
+                pipeline_inputs["image"] = open_image_as_srgb(image_path)
+
             with autocast(self._accelerator.device.type, dtype=torch.bfloat16):
-                videos = pipeline(
-                    prompt=prompt,
-                    negative_prompt=self._config.validation.negative_prompt,
-                    width=width,
-                    height=height,
-                    num_frames=frames,
-                    num_inference_steps=self._config.validation.inference_steps,
-                    generator=generator,
-                ).frames
+                result = pipeline(**pipeline_inputs)
+                videos = result.frames
 
             for video in videos:
                 video_path = output_dir / f"step_{self._global_step:06d}_{i}.mp4"
-                export_to_video(video, str(video_path), fps=24)
+                export_to_video(video, str(video_path), fps=25)
                 video_paths.append(video_path)
                 i += 1
                 progress.update(task, advance=1)
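open_image_as_srgb is imported from ltxv_trainer.utils, but its body is not part of this diff. A hypothetical sketch of what such a helper plausibly does (the actual implementation may differ): open the file with Pillow, convert through an embedded ICC profile when one exists, and return an RGB image in sRGB so conditioning frames are color-consistent:

import io

from PIL import Image, ImageCms


# Hypothetical sketch of the helper imported above; the real implementation
# in ltxv_trainer.utils may differ.
def open_image_as_srgb(image_path: str) -> Image.Image:
    image = Image.open(image_path)
    icc_profile = image.info.get("icc_profile")
    if icc_profile:
        # Convert from the embedded ICC profile to standard sRGB.
        source_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc_profile))
        srgb_profile = ImageCms.createProfile("sRGB")
        image = ImageCms.profileToProfile(image, source_profile, srgb_profile, outputMode="RGB")
    else:
        # No profile embedded: assume the pixels are already sRGB.
        image = image.convert("RGB")
    return image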
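Outside the trainer, the image-to-video path added here can be exercised directly with the public diffusers pipeline this PR switches to. A standalone usage sketch (the checkpoint id follows the diffusers LTX-Video documentation; file paths are placeholders, and in practice fine-tuned weights or a trained LoRA would be loaded on top):

import torch
from diffusers import LTXImageToVideoPipeline
from diffusers.utils import export_to_video
from PIL import Image

pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# First-frame conditioning image; placeholder path.
image = Image.open("first_frame.png").convert("RGB")

video = pipe(
    image=image,
    prompt="a video of a person wearing a nice suit",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=768,
    height=448,
    num_frames=89,
    num_inference_steps=50,
).frames[0]

export_to_video(video, "validation_sample.mp4", fps=25)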