13 changes: 8 additions & 5 deletions README.md
@@ -506,15 +506,17 @@ We evaluate video generation on **a single RTX 5090 GPU**. The E2E Time refers t

In this repo, we provide training code based on Wan2.1 and its synthetic data. The training builds on the rCM codebase (https://github.com/NVlabs/rcm), with infrastructure support including FSDP2, Ulysses CP, and selective activation checkpointing (SAC). For rCM training instructions, please refer to the original rCM repository; [SLA (Sparse-Linear Attention)](https://github.com/thu-ml/SLA) training guidance is provided here.

#### Additional Installation
### Additional Installation

For rCM/SLA training, additionally run:

```bash
pip install megatron-core hydra-core wandb webdataset
pip install --no-build-isolation transformer_engine[pytorch]
```

#### Checkpoints Downloading
### Checkpoints Downloading

Download the Wan2.1 pretrained checkpoints in `.pth` format, together with the VAE and text encoder, to `assets/checkpoints`:

```bash
@@ -530,7 +532,7 @@ python -m torch.distributed.checkpoint.format_utils torch_to_dcp assets/checkpoi

After training, the saved `.dcp` checkpoints can be converted to `.pth` using the script `scripts/dcp_to_pth.py`.
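
For reference, PyTorch also exposes this conversion as a Python API in `torch.distributed.checkpoint.format_utils`. The sketch below (with hypothetical paths) illustrates what a `.dcp`-to-`.pth` conversion can look like; the actual `scripts/dcp_to_pth.py` may differ in details such as key remapping.

```python
# Minimal sketch of a DCP -> .pth conversion using PyTorch's format utilities
# (available in PyTorch >= 2.2). Paths are hypothetical placeholders.
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_to_torch_save(
    "assets/checkpoints/sla_run/iter_000010000",      # saved .dcp checkpoint directory
    "assets/checkpoints/sla_run_iter_000010000.pth",  # output .pth file
)
```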

#### Dataset Downloading
### Dataset Downloading

We provide Wan2.1-14B-synthesized datasets. Download to `assets/datasets` using:

@@ -539,7 +541,8 @@ We provide Wan2.1-14B-synthesized datasets. Download to `assets/datasets` using:
git clone https://huggingface.co/datasets/worstcoder/Wan_datasets assets/datasets
```

#### Start Training
### Start Training

We implement white-box SLA training by aligning the predictions of the SLA-enabled model with those of the full-attention pretrained model. Unlike black-box training in the original paper, which tunes the pretrained model using diffusion loss, white-box training mitigates distribution shift and is less sensitive to the training data.
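
Conceptually, this is a distillation-style objective: the SLA-enabled student and the frozen full-attention teacher receive the same noisy input, and the student is trained to match the teacher's prediction. The sketch below is illustrative only; the model wrappers, argument names, and the plain MSE loss are assumptions rather than the exact training code.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch of the white-box alignment objective (not the actual
# training loop). `teacher` is the frozen full-attention pretrained model and
# `student` is the same architecture with SLA enabled; both are hypothetical
# placeholders for the real model wrappers.
def white_box_sla_loss(student, teacher, x_noisy, t, text_emb):
    with torch.no_grad():
        target = teacher(x_noisy, t, text_emb)  # full-attention prediction
    pred = student(x_noisy, t, text_emb)        # SLA prediction
    # Align the two predictions; plain MSE here, the actual loss may differ.
    return F.mse_loss(pred, target)
```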

Single-node training example:
@@ -572,7 +575,7 @@ torchrun --nproc_per_node=8 \

Please refer to `turbodiffusion/rcm/configs/experiments/sla/wan2pt1_t2v.py` for the 14B config, or modify it as needed.

#### Model Merging
### Model Merging

The parameter updates from SLA training can be merged into rCM checkpoints using `turbodiffusion/scripts/merge_models.py`, enabling rCM to perform sparse attention inference. Specify `--base` as the rCM model, `--diff_base` as the pretrained model, and `--diff_target` as the SLA-tuned model.
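
In effect, the merge applies the SLA parameter delta on top of the rCM weights, roughly `base + (diff_target - diff_base)` per tensor. The sketch below illustrates that delta merge under the assumption of flat `.pth` state dicts and hypothetical paths; the actual `merge_models.py` may handle key filtering, nesting, and dtypes differently.

```python
import torch

# Illustrative delta merge: add the SLA fine-tuning update (diff_target - diff_base)
# onto the rCM base weights. Paths and the flat state_dict layout are assumptions.
base = torch.load("rcm.pth", map_location="cpu")               # --base: rCM model
diff_base = torch.load("pretrained.pth", map_location="cpu")   # --diff_base: pretrained model
diff_target = torch.load("sla_tuned.pth", map_location="cpu")  # --diff_target: SLA-tuned model

merged = {}
for name, weight in base.items():
    if name in diff_base and name in diff_target:
        merged[name] = weight + (diff_target[name] - diff_base[name])
    else:
        merged[name] = weight

torch.save(merged, "rcm_sla_merged.pth")
```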

1 change: 1 addition & 0 deletions turbodiffusion/imaginaire/trainer.py
@@ -237,6 +237,7 @@ def train(
self.callbacks.on_train_end(model, iteration=iteration)
self.checkpointer.finalize()
distributed.barrier()
distributed.destroy_process_group()
self.callbacks.on_app_end()

def training_step(
6 changes: 6 additions & 0 deletions turbodiffusion/imaginaire/utils/distributed.py
@@ -156,6 +156,12 @@ def barrier() -> None:
dist.barrier()


def destroy_process_group() -> None:
"""Destroy the distributed process group."""
if dist.is_available() and dist.is_initialized():
dist.destroy_process_group()


def rank0_first(func: Callable) -> Callable:
"""run the function on rank 0 first, then on other ranks."""
