@@ -11,7 +11,7 @@
 import torch
 from accelerate import Accelerator
 from accelerate.utils import set_seed
-from diffusers import LTXPipeline
+from diffusers import LTXImageToVideoPipeline, LTXPipeline
 from diffusers.utils import export_to_video
 from loguru import logger
 from peft import LoraConfig, get_peft_model_state_dict
@@ -49,7 +49,7 @@
 from ltxv_trainer.model_loader import load_ltxv_components
 from ltxv_trainer.quantization import quantize_model
 from ltxv_trainer.timestep_samplers import SAMPLERS
-from ltxv_trainer.utils import get_gpu_memory_gb
+from ltxv_trainer.utils import get_gpu_memory_gb, open_image_as_srgb
 
 # Disable irrelevant warnings from transformers
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
@@ -645,13 +645,30 @@ def _sample_videos(self, progress: Progress) -> list[Path] | None:
         if not self._config.acceleration.load_text_encoder_in_8bit:
             self._text_encoder.to(self._accelerator.device)
 
-        pipeline = LTXPipeline(
-            scheduler=deepcopy(self._scheduler),
-            vae=self._vae,
-            text_encoder=self._text_encoder,
-            tokenizer=self._tokenizer,
-            transformer=self._transformer,
-        )
+        use_images = self._config.validation.images is not None
+
+        if use_images:
+            if len(self._config.validation.images) != len(self._config.validation.prompts):
+                raise ValueError(
+                    f"Number of images ({len(self._config.validation.images)}) must match "
+                    f"number of prompts ({len(self._config.validation.prompts)})"
+                )
+
+            pipeline = LTXImageToVideoPipeline(
+                scheduler=deepcopy(self._scheduler),
+                vae=self._vae,
+                text_encoder=self._text_encoder,
+                tokenizer=self._tokenizer,
+                transformer=self._transformer,
+            )
+        else:
+            pipeline = LTXPipeline(
+                scheduler=deepcopy(self._scheduler),
+                vae=self._vae,
+                text_encoder=self._text_encoder,
+                tokenizer=self._tokenizer,
+                transformer=self._transformer,
+            )
         pipeline.set_progress_bar_config(disable=True)
 
         # Create a task in the sampling progress
@@ -665,25 +682,33 @@ def _sample_videos(self, progress: Progress) -> list[Path] | None:
 
         video_paths = []
         i = 0
-        for prompt in self._config.validation.prompts:
+        for j, prompt in enumerate(self._config.validation.prompts):
             generator = torch.Generator(device=self._accelerator.device).manual_seed(self._config.validation.seed)
 
             # Generate video
             width, height, frames = self._config.validation.video_dims
+
+            pipeline_inputs = {
+                "prompt": prompt,
+                "negative_prompt": self._config.validation.negative_prompt,
+                "width": width,
+                "height": height,
+                "num_frames": frames,
+                "num_inference_steps": self._config.validation.inference_steps,
+                "generator": generator,
+            }
+
+            if use_images:
+                image_path = self._config.validation.images[j]
+                pipeline_inputs["image"] = open_image_as_srgb(image_path)
+
             with autocast(self._accelerator.device.type, dtype=torch.bfloat16):
-                videos = pipeline(
-                    prompt=prompt,
-                    negative_prompt=self._config.validation.negative_prompt,
-                    width=width,
-                    height=height,
-                    num_frames=frames,
-                    num_inference_steps=self._config.validation.inference_steps,
-                    generator=generator,
-                ).frames
+                result = pipeline(**pipeline_inputs)
+                videos = result.frames
 
             for video in videos:
                 video_path = output_dir / f"step_{self._global_step:06d}_{i}.mp4"
-                export_to_video(video, str(video_path), fps=24)
+                export_to_video(video, str(video_path), fps=25)
                 video_paths.append(video_path)
                 i += 1
             progress.update(task, advance=1)
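
Note: the diff imports a new `open_image_as_srgb` helper from `ltxv_trainer.utils`, but its body is not shown in this hunk. Below is a minimal sketch of what such a helper could look like, assuming it uses Pillow and converts any embedded ICC profile to sRGB; the actual implementation in the repo may differ:

```python
import io
from pathlib import Path

from PIL import Image, ImageCms


def open_image_as_srgb(image_path: str | Path) -> Image.Image:
    """Hypothetical sketch: open an image and return it as an sRGB RGB image."""
    image = Image.open(image_path)
    icc_bytes = image.info.get("icc_profile")
    if icc_bytes:
        # Convert from the embedded ICC profile to standard sRGB.
        src_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc_bytes))
        srgb_profile = ImageCms.createProfile("sRGB")
        image = ImageCms.profileToProfile(image, src_profile, srgb_profile, outputMode="RGB")
    else:
        # No profile embedded; assume the data is already sRGB and normalize the mode.
        image = image.convert("RGB")
    return image
```

Normalizing conditioning images to sRGB keeps the first-frame colors consistent regardless of the color profile of the source file.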
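For context, here is a standalone usage sketch of the image-to-video validation path this change enables, using the public `LTXImageToVideoPipeline` API from `diffusers`. The checkpoint id, resolution, and file paths below are illustrative only and are not taken from the trainer config:

```python
import torch
from diffusers import LTXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Illustrative checkpoint id; the trainer instead assembles the pipeline
# from its in-memory scheduler/VAE/text encoder/transformer components.
pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Condition the video on a first frame plus a text prompt (example path/dims).
image = load_image("first_frame.png")
video = pipe(
    image=image,
    prompt="a cat walking across a sunlit garden",
    negative_prompt="worst quality, inconsistent motion, blurry",
    width=704,
    height=480,
    num_frames=121,  # LTX expects num_frames of the form 8k + 1
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(42),
).frames[0]

export_to_video(video, "sample.mp4", fps=25)  # matches the fps used in the diff
```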