
Commit 1695d36

linoytsaban authored and Naomi-Ken-Korem committed
push to hub only in the end of the training, push also comfyui compatible file
1 parent f05e60a · commit 1695d36

File tree

7 files changed: +317 / -190 lines changed


README.md

Lines changed: 21 additions & 0 deletions
@@ -302,6 +302,27 @@ The trainer loads your configuration, initializes models, applies optimizations,
For LoRA training, the weights will be saved as `lora_weights.safetensors` in your output directory.
For full model fine-tuning, the weights will be saved as `model_weights.safetensors`.

### 🤗 Pushing Models to Hugging Face Hub

You can automatically push your trained models to the Hugging Face Hub by adding the following to your configuration YAML:

```yaml
hub:
  push_to_hub: true
  hub_model_id: "your-username/your-model-name" # Your HF username and desired repo name
```

Before pushing, make sure you:
1. Have a Hugging Face account
2. Are logged in via `huggingface-cli login` or have set the `HUGGING_FACE_HUB_TOKEN` environment variable
3. Have write access to the specified repository (it will be created if it doesn't exist)

The trainer will:
- Create a model card with training details and sample outputs
- Upload the model weights (both original and ComfyUI-compatible versions)
- Push sample videos as GIFs in the model card
- Include training configuration and prompts

---

## Fast and simple: Running the Complete Pipeline as one command
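The login requirement listed above is worth verifying before launching a long run, especially on a remote machine. A minimal sketch using the `huggingface_hub` client (the repository id is a placeholder, and this helper is illustrative rather than part of the trainer):

```python
from huggingface_hub import HfApi, whoami


def check_hub_access(repo_id: str) -> None:
    """Fail fast if Hub credentials or write access are missing."""
    # whoami() succeeds with either a token cached by `huggingface-cli login`
    # or the HUGGING_FACE_HUB_TOKEN environment variable, and raises otherwise.
    user = whoami()
    print(f"Authenticated as: {user['name']}")

    # Creating the repo up front also confirms write access; exist_ok=True
    # mirrors what the trainer does when it pushes at the end of training.
    HfApi().create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)


if __name__ == "__main__":
    check_hub_access("your-username/your-model-name")  # placeholder repo id
```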

src/ltxv_trainer/config.py

Lines changed: 12 additions & 13 deletions
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import Literal

-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

from ltxv_trainer.model_loader import LtxvModelVersion
from ltxv_trainer.quantization import QuantizationOptions
@@ -249,18 +249,17 @@ class CheckpointsConfig(ConfigBaseModel):
class HubConfig(ConfigBaseModel):
    """Configuration for Hugging Face Hub integration"""

-    push_to_hub: bool = Field(
-        default=False,
-        description="Whether to push the model weights to the Hugging Face Hub"
-    )
-    hub_model_id: str = Field(
-        default=None,
-        description="Hugging Face Hub repository ID (e.g., 'username/repo-name')"
-    )
-    hub_token: str = Field(
-        default=None,
-        description="Hugging Face token. If None, will use the token from the Hugging Face CLI"
-    )
+    push_to_hub: bool = Field(default=False, description="Whether to push the model weights to the Hugging Face Hub")
+    hub_model_id: str = Field(default=None, description="Hugging Face Hub repository ID (e.g., 'username/repo-name')")
+
+    @field_validator("hub_model_id")
+    @classmethod
+    def validate_hub_model_id(cls, v: str | None, info: ValidationInfo) -> str | None:
+        """Validate that hub_model_id is not None when push_to_hub is True."""
+        if info.data.get("push_to_hub", False) and v is None:
+            raise ValueError("hub_model_id must be specified when push_to_hub is True")
+        return v


class FlowMatchingConfig(ConfigBaseModel):
    """Configuration for flow matching training"""

src/ltxv_trainer/hub_utils.py

Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
import tempfile
from pathlib import Path
from typing import List, Union

import imageio
from huggingface_hub import HfApi, create_repo
from loguru import logger

from ltxv_trainer.config import LtxvTrainerConfig
from ltxv_trainer.model_loader import try_parse_version
from scripts.convert_checkpoint import convert_checkpoint


def convert_video_to_gif(video_path: Path, output_path: Path) -> None:
    """Convert a video file to GIF format."""
    try:
        # Read the video file
        reader = imageio.get_reader(str(video_path))
        fps = reader.get_meta_data()["fps"]

        # Write GIF file with infinite loop
        writer = imageio.get_writer(
            str(output_path),
            fps=min(fps, 15),  # Cap FPS at 15 for reasonable file size
            loop=0,  # 0 means infinite loop
        )

        for frame in reader:
            writer.append_data(frame)

        writer.close()
        reader.close()
    except Exception as e:
        logger.error(f"Failed to convert video to GIF: {e}")
        return None

def create_model_card(
    output_dir: Union[str, Path],
    videos: List[Path],
    config: LtxvTrainerConfig,
) -> Path:
    """Generate and save a model card for the trained model."""

    repo_id = config.hub.hub_model_id
    pretrained_model_name_or_path = config.model.model_source
    validation_prompts = config.validation.prompts
    output_dir = Path(output_dir)
    template_path = Path(__file__).parent.parent.parent / "templates" / "model_card.md"

    if not template_path.exists():
        logger.warning("⚠️ Model card template not found, using default template")
        return

    # Read the template
    template = template_path.read_text()

    # Get model name from repo_id
    model_name = repo_id.split("/")[-1]

    # Get base model information
    version = try_parse_version(pretrained_model_name_or_path)
    if version:
        base_model_link = version.safetensors_url
        base_model_name = str(version)
    else:
        base_model_link = f"https://huggingface.co/{pretrained_model_name_or_path}"
        base_model_name = pretrained_model_name_or_path

    # Format validation prompts and create grid layout
    prompts_text = ""
    sample_grid = []

    if validation_prompts and videos:
        prompts_text = "Example prompts used during validation:\n\n"

        # Create samples directory
        samples_dir = output_dir / "samples"
        samples_dir.mkdir(exist_ok=True, parents=True)

        # Process videos and create cells
        cells = []
        for i, (prompt, video) in enumerate(zip(validation_prompts, videos, strict=False)):
            if video.exists():
                # Add prompt to text section
                prompts_text += f"- `{prompt}`\n"

                # Convert video to GIF
                gif_path = samples_dir / f"sample_{i}.gif"
                try:
                    convert_video_to_gif(video, gif_path)

                    # Create grid cell with collapsible description
                    cell = (
                        f"![example{i + 1}](./samples/sample_{i}.gif)"
                        "<br>"
                        '<details style="max-width: 300px; margin: auto;">'
                        f"<summary>Prompt</summary>"
                        f"{prompt}"
                        "</details>"
                    )
                    cells.append(cell)
                except Exception as e:
                    logger.error(f"Failed to process video {video}: {e}")

        # Calculate optimal grid dimensions
        num_cells = len(cells)
        if num_cells > 0:
            # Aim for a roughly square grid, with max 4 columns
            num_cols = min(4, num_cells)
            num_rows = (num_cells + num_cols - 1) // num_cols  # Ceiling division

            # Create grid rows
            for row in range(num_rows):
                start_idx = row * num_cols
                end_idx = min(start_idx + num_cols, num_cells)
                row_cells = cells[start_idx:end_idx]
                # Properly format the row with table markers and exact number of cells
                formatted_row = "| " + " | ".join(row_cells) + " |"
                sample_grid.append(formatted_row)

    # Join grid rows with just the content, no headers needed
    grid_text = "\n".join(sample_grid) if sample_grid else ""

    # Fill in the template
    model_card_content = template.format(
        base_model=base_model_name,
        base_model_link=base_model_link,
        model_name=model_name,
        training_type="LoRA fine-tuning" if config.model.training_mode == "lora" else "Full model fine-tuning",
        training_steps=config.optimization.steps,
        learning_rate=config.optimization.learning_rate,
        batch_size=config.optimization.batch_size,
        validation_prompts=prompts_text,
        sample_grid=grid_text,
    )

    # Save the model card directly
    model_card_path = output_dir / "README.md"
    model_card_path.write_text(model_card_content)

    return model_card_path

def push_to_hub(weights_path: Path, sampled_videos_paths: List[Path], config: LtxvTrainerConfig) -> None:
    """Push the trained LoRA weights to HuggingFace Hub."""
    if not config.hub.push_to_hub:
        return

    if not config.hub.hub_model_id:
        logger.warning("⚠️ HuggingFace hub_model_id not specified, skipping push to hub")
        return

    api = HfApi()

    # Try to create repo if it doesn't exist
    try:
        create_repo(
            repo_id=config.hub.hub_model_id,
            repo_type="model",
            exist_ok=True,  # Don't raise error if repo exists
        )
    except Exception as e:
        logger.error(f"❌ Failed to create repository: {e}")
        return

    # Upload the original weights file
    try:
        api.upload_file(
            path_or_fileobj=str(weights_path),
            path_in_repo=weights_path.name,
            repo_id=config.hub.hub_model_id,
            repo_type="model",
        )
    except Exception as e:
        logger.error(f"❌ Failed to push {weights_path.name} to HuggingFace Hub: {e}")

    # Create a temporary directory for the files we want to upload
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)

        try:
            # Save model card and copy videos to temp directory
            create_model_card(
                output_dir=temp_path,
                videos=sampled_videos_paths,
                config=config,
            )

            # Upload the model card and samples directory
            api.upload_folder(
                folder_path=str(temp_path),  # Convert to string for compatibility
                repo_id=config.hub.hub_model_id,
                repo_type="model",
            )

            logger.info(f"✅ Successfully uploaded model card and sample videos to {config.hub.hub_model_id}")
        except Exception as e:
            logger.error(f"❌ Failed to save/upload model card and videos: {e}")

    logger.info(f"✅ Successfully pushed original LoRA weights to {config.hub.hub_model_id}")

    # Convert and upload ComfyUI version
    try:
        # Create a temporary directory for the converted file
        with tempfile.TemporaryDirectory() as temp_dir:
            # Convert the weights to ComfyUI format
            comfy_path = Path(temp_dir) / f"{weights_path.stem}_comfy{weights_path.suffix}"

            convert_checkpoint(
                input_path=str(weights_path),
                to_comfy=True,
                output_path=str(comfy_path),
            )

            # Find the converted file
            converted_files = list(Path(temp_dir).glob("*.safetensors"))
            if not converted_files:
                logger.warning("⚠️ No converted ComfyUI weights found")
                return

            converted_file = converted_files[0]
            comfy_filename = f"comfyui_{weights_path.name}"

            # Upload the converted file
            api.upload_file(
                path_or_fileobj=str(converted_file),
                path_in_repo=comfy_filename,
                repo_id=config.hub.hub_model_id,
                repo_type="model",
            )
            logger.info(f"✅ Successfully pushed ComfyUI LoRA weights to {config.hub.hub_model_id}")

    except Exception as e:
        logger.error(f"❌ Failed to convert and push ComfyUI version: {e}")

src/ltxv_trainer/model_loader.py

Lines changed: 3 additions & 3 deletions
@@ -160,7 +160,7 @@ def load_vae(
    """
    if isinstance(source, str):  # noqa: SIM102
        # Try to parse as version first
-        if version := _try_parse_version(source):
+        if version := try_parse_version(source):
            source = version

    if isinstance(source, LtxvModelVersion):
@@ -217,7 +217,7 @@ def load_transformer(
    """
    if isinstance(source, str):  # noqa: SIM102
        # Try to parse as version first
-        if version := _try_parse_version(source):
+        if version := try_parse_version(source):
            source = version

    if isinstance(source, LtxvModelVersion):
@@ -285,7 +285,7 @@ def load_ltxv_components(
    )


-def _try_parse_version(source: str | Path) -> LtxvModelVersion | None:
+def try_parse_version(source: str | Path) -> LtxvModelVersion | None:
    """
    Try to parse a string as an LtxvModelVersion.
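The only change here is making `_try_parse_version` public as `try_parse_version`, so `hub_utils` can reuse it to link the base model in the generated model card. A small sketch of the pattern (the source strings are placeholders, not known-valid version names):

```python
from ltxv_trainer.model_loader import try_parse_version

for source in ["LTXV_2B_0.9.5", "some-user/custom-finetune"]:  # placeholder sources
    version = try_parse_version(source)
    if version is not None:
        # Recognized as a named LTXV release with a known safetensors URL.
        print(str(version), version.safetensors_url)
    else:
        # Anything else is treated as a Hugging Face repo id or local path.
        print(f"Unrecognized version string, treating as repo/path: {source}")
```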
