Skip to content

Commit c4bb626

Browse files
authored
fix dataloader deadlock bug (#225)
* update README.md; Signed-off-by: lawrence-cj <[email protected]> * fix dataloader deadlock bug; Signed-off-by: lawrence-cj <[email protected]> * tmp change; Signed-off-by: lawrence-cj <[email protected]> * tmp change; Signed-off-by: lawrence-cj <[email protected]> * tmp change; Signed-off-by: lawrence-cj <[email protected]> * pre-commit; Signed-off-by: lawrence-cj <[email protected]> * Revert "tmp change;" This reverts commit 2da18bc. --------- Signed-off-by: lawrence-cj <[email protected]>
1 parent b628453 commit c4bb626

File tree

3 files changed

+22
-17
lines changed

3 files changed

+22
-17
lines changed

README.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
<a href="https://nvlabs.github.io/Sana/"><img src="https://img.shields.io/static/v1?label=Project&message=Github&color=blue&logo=github-pages"></a> &ensp;
1111
<a href="https://hanlab.mit.edu/projects/sana/"><img src="https://img.shields.io/static/v1?label=Page&message=MIT&color=darkred&logo=github-pages"></a> &ensp;
1212
<a href="https://arxiv.org/abs/2410.10629"><img src="https://img.shields.io/static/v1?label=Arxiv&message=Sana&color=red&logo=arxiv"></a> &ensp;
13-
<a href="https://nv-sana.mit.edu/"><img src="https://img.shields.io/static/v1?label=Demo:5x3090&message=SANA&color=yellow"></a> &ensp;
14-
<a href="https://nv-sana.mit.edu/sprint/"><img src="https://img.shields.io/static/v1?label=Demo:1x3090&message=SANA-Sprint&color=yellow"></a> &ensp;
13+
<a href="https://nv-sana.mit.edu/"><img src="https://img.shields.io/static/v1?label=Demo:6x3090&message=MIT&color=yellow"></a> &ensp;
1514
<a href="https://nv-sana.mit.edu/4bit/"><img src="https://img.shields.io/static/v1?label=Demo:1x3090&message=4bit&color=yellow"></a> &ensp;
1615
<a href="https://nv-sana.mit.edu/ctrlnet/"><img src="https://img.shields.io/static/v1?label=Demo:1x3090&message=ControlNet&color=yellow"></a> &ensp;
1716
<a href="https://replicate.com/chenxwh/sana"><img src="https://img.shields.io/static/v1?label=API:H100&message=Replicate&color=pink"></a> &ensp;
@@ -26,7 +25,7 @@
2625

2726
### 🚶 Basic:
2827

29-
**Demo**: [SANA-1.5](https://nv-sana.mit.edu/) | [SANA-ControlNet](https://nv-sana.mit.edu/ctrlnet/) | [SANA-4bit](https://nv-sana.mit.edu/4bit/) | [SANA-Sprint](https://nv-sana.mit.edu/sprint/) <br>
28+
**Demo**: [SANA-1.5](https://nv-sana.mit.edu/) | [SANA-ControlNet](https://nv-sana.mit.edu/ctrlnet/) | [SANA-4bit](https://nv-sana.mit.edu/4bit/) | [SANA-Sprint (Coming)](<>) <br>
3029
**ComfyUI**: [ComfyUI Guidance](asset/docs/ComfyUI/comfyui.md) <br>
3130
**Model Zoo:** [Model Card Collects All Models](asset/docs/model_zoo.md) <br>
3231
**Env Preparation:** [One-Click Env Install](#-1-dependencies-and-installation) <br>
@@ -183,8 +182,8 @@ cd Sana
183182
DEMO_PORT=15432 \
184183
python app/app_sana.py \
185184
--share \
186-
--config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
187-
--model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth \
185+
--config=hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth \
186+
--model_path=hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth \
188187
--image_size=1024
189188
```
190189

tests/bash/entry.sh

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#/bin/bash
22
set -e
33

4-
for t in tests/bash/test_*.sh; do
5-
echo "========================== Testing $t =================================="
6-
bash $t;
7-
done
4+
5+
#echo "Testing inference"
6+
#bash tests/bash/test_inference.sh
7+
8+
echo "Testing training"
9+
bash tests/bash/test_training_1epoch.sh

train_scripts/train.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,7 @@ def train(
269269
skip_step = max(config.train.skip_step, global_step) % train_dataloader_len
270270
skip_step = skip_step if skip_step < (train_dataloader_len - 20) else 0
271271
loss_nan_timer = 0
272-
273-
if config.train.use_fsdp:
274-
model_instance = model
275-
elif model_ema is not None:
276-
model_instance = model_ema
277-
else:
278-
model_instance = model
272+
model_instance.to(accelerator.device)
279273

280274
# Cache Dataset for BatchSampler
281275
if args.caching and config.model.multi_scale:
@@ -542,9 +536,11 @@ def train(
542536
merged_state_dict = accelerator.get_state_dict(model)
543537

544538
accelerator.wait_for_everyone()
539+
print(rank, 111111)
545540
if accelerator.is_main_process:
546541
if config.train.use_fsdp:
547542
model_instance.load_state_dict(merged_state_dict)
543+
print(rank, 222222)
548544
if validation_noise is not None:
549545
log_validation(
550546
accelerator=accelerator,
@@ -567,6 +563,7 @@ def train(
567563
vae=vae,
568564
)
569565

566+
print(rank, 333333)
570567
# avoid dead-lock of multiscale data batch sampler
571568
if (
572569
config.model.multi_scale
@@ -629,7 +626,7 @@ def main(cfg: SanaConfig) -> None:
629626
global train_dataloader_len, start_epoch, start_step, vae, generator, num_replicas, rank, training_start_time
630627
global load_vae_feat, load_text_feat, validation_noise, text_encoder, tokenizer
631628
global max_length, validation_prompts, latent_size, valid_prompt_embed_suffix, null_embed_path
632-
global image_size, cache_file, total_steps, vae_dtype
629+
global image_size, cache_file, total_steps, vae_dtype, model_instance
633630

634631
config = cfg
635632
args = cfg
@@ -870,6 +867,13 @@ def main(cfg: SanaConfig) -> None:
870867
)
871868
)
872869

870+
if config.train.use_fsdp:
871+
model_instance = deepcopy(model)
872+
elif model_ema is not None:
873+
model_instance = deepcopy(model_ema)
874+
else:
875+
model_instance = model
876+
873877
# 4-1. load model
874878
if args.load_from is not None:
875879
config.model.load_from = args.load_from

0 commit comments

Comments
 (0)