23 commits
f8a0016  Adding disabled autocast on bce_loss (tango4j, Aug 28, 2025)
80e1f7f  Adding Sortformer BF16 inference (tango4j, Sep 3, 2025)
c5cd66f  Adding BF16 inference and adding a config (tango4j, Sep 3, 2025)
051bef0  Merge remote-tracking branch 'origin/main' into bf16_sortformer_train (tango4j, Sep 3, 2025)
90e251d  Apply isort and black reformatting (tango4j, Sep 3, 2025)
fc32917  Merge branch 'main' into bf16_sortformer_train (tango4j, Sep 3, 2025)
132675c  Adding bf16-mixed option for both training and inference (tango4j, Sep 4, 2025)
b47cc14  Adding bf16-mixed option for both training and inference (tango4j, Sep 4, 2025)
51dac61  Apply isort and black reformatting (tango4j, Sep 4, 2025)
51e2c77  Adding bf16-mixed option for e2e_diarize_speech.py (tango4j, Sep 4, 2025)
e1fd6c1  Resolving conflict (tango4j, Sep 4, 2025)
7d33e38  Apply isort and black reformatting (tango4j, Sep 4, 2025)
d5ef771  Merge branch 'main' into bf16_sortformer_train (tango4j, Sep 4, 2025)
cb6badb  Merge branch 'main' into bf16_sortformer_train (tango4j, Sep 4, 2025)
a647bc9  adding precision item to yaml files (tango4j, Sep 4, 2025)
29e2e18  Merge branch 'bf16_sortformer_train' of https://github.com/tango4j/Ne… (tango4j, Sep 4, 2025)
d1bccee  Merge branch 'main' into bf16_sortformer_train (tango4j, Sep 4, 2025)
73f3510  Adding bf16 description for offline and streaming training (tango4j, Sep 5, 2025)
8e9420c  Merge branch 'bf16_sortformer_train' of https://github.com/tango4j/Ne… (tango4j, Sep 5, 2025)
1e91c0c  Merge branch 'main' into bf16_sortformer_train (tango4j, Sep 5, 2025)
4b4b8b7  Merge branch 'main' into bf16_sortformer_train (chtruong814, Sep 5, 2025)
7585c1a  Merge branch 'main' into bf16_sortformer_train (tango4j, Sep 5, 2025)
7e7fbe9  Merge branch 'main' into bf16_sortformer_train (chtruong814, Sep 6, 2025)
@@ -158,6 +158,7 @@ model:
 trainer:
   devices: 1 # number of gpus (devices)
   accelerator: gpu
+  precision: 32 # 32, bf16, bf16-mixed
   max_epochs: 800
   max_steps: -1 # computed at runtime if not set
   num_nodes: 1
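The new `precision` key is passed through to the PyTorch Lightning `Trainer`. A minimal sketch of what the values select (a toy example; the import path varies across Lightning/NeMo versions):

```python
import pytorch_lightning as pl

# precision="32" (the default) runs everything in float32.
# precision="bf16-mixed" keeps float32 master weights and autocasts
# matmul-heavy ops to bfloat16, which typically frees enough memory
# for a larger batch size.
trainer = pl.Trainer(devices=1, accelerator="gpu", precision="bf16-mixed")
```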
@@ -201,7 +201,8 @@ model:

 trainer:
   devices: 1 # number of gpus (devices)
-  accelerator: gpu
+  accelerator: gpu
+  precision: 32 # 32, bf16, bf16-mixed
   max_epochs: 800
   max_steps: -1 # computed at runtime if not set
   num_nodes: 1
@@ -82,6 +82,7 @@ class DiarizationConfig:
     no_der: bool = False
     out_rttm_dir: Optional[str] = None
     save_preds_tensors: bool = False
+    precision: str = "32" # 32, bf16, bf16-mixed

     # General configs
     session_len_sec: float = -1 # End-to-end diarization session length in seconds
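With the new field, reduced-precision inference can be requested from the command line. A hypothetical invocation (only `model_path` and `precision` appear in this diff; the manifest argument name is an assumption):

```bash
python e2e_diarize_speech.py \
    model_path=/path/to/diar_sortformer.nemo \
    dataset_manifest=/path/to/manifest.json \
    precision=bf16-mixed
```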
@@ -346,10 +347,13 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]:
         raise ValueError("cfg.model_path must end with .ckpt or .nemo!")

     diar_model._cfg.test_ds.session_len_sec = cfg.session_len_sec
-    trainer = pl.Trainer(devices=device, accelerator=accelerator)
+    trainer = pl.Trainer(devices=device, accelerator=accelerator, precision=cfg.precision)
     diar_model.set_trainer(trainer)

-    diar_model = diar_model.eval()
+    if torch.cuda.is_bf16_supported() and cfg.precision.startswith("bf16"):
+        diar_model = diar_model.to(dtype=torch.bfloat16).eval()
+    else:
+        diar_model = diar_model.eval()

     if cfg.presort_manifest:
         audio_key = cfg.get('audio_key', 'audio_filepath')
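The guard above falls back to float32 on GPUs without bfloat16 support, so older hardware keeps working. A standalone sketch of the same pattern (the toy model is an assumption, not the NeMo module):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(80, 256), torch.nn.ReLU(), torch.nn.Linear(256, 4))
if torch.cuda.is_available():
    model = model.cuda()

# Cast weights to bfloat16 only when the hardware supports it (e.g. Ampere
# and newer GPUs); otherwise run evaluation in float32.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    model = model.to(dtype=torch.bfloat16).eval()
else:
    model = model.eval()
```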
@@ -405,7 +409,9 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]:
         diar_model_preds_total_list = torch.load(tensor_path)
     else:
         logging.info("No saved prediction tensors found. Running inference on the dataset...")
-        diar_model.test_batch()
+        with torch.inference_mode(), torch.autocast(device_type=diar_model.device.type, dtype=diar_model.dtype):
+            diar_model.test_batch()
+
         diar_model_preds_total_list = diar_model.preds_total_list
         if cfg.save_preds_tensors:
             torch.save(diar_model.preds_total_list, tensor_path)
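Reading the autocast dtype off the model itself keeps the context manager consistent with whatever dtype the checkpoint was cast to above. A minimal sketch of the pattern (toy model and input; assumes a CUDA device):

```python
import torch

model = torch.nn.Linear(80, 4).cuda().to(torch.bfloat16).eval()
x = torch.randn(8, 80, device="cuda")  # float32 input

# inference_mode() drops autograd bookkeeping; autocast runs eligible ops in
# the model's own dtype, so the float32 input is handled without manual casts.
dtype = next(model.parameters()).dtype
with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=dtype):
    preds = model(x)
print(preds.dtype)  # torch.bfloat16
```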
@@ -23,9 +23,12 @@

 """
 Example training session (single node training)
+For training, you can use the following precisions: 32, bf16, and bf16-mixed.
+BF16 mixed precision typically allows training with a larger batch size.

 python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' \
     --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \
+    trainer.precision='bf16' \
     trainer.devices=1 \
     model.train_ds.manifest_filepath="<train_manifest_path>" \
     model.validation_ds.manifest_filepath="<dev_manifest_path>" \
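Since bf16 roughly halves activation memory, the freed budget can go into a larger batch. A hypothetical variant of the command above (the `batch_size` key and its value are assumptions, not part of this diff):

```bash
python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' \
    --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \
    trainer.precision='bf16-mixed' \
    trainer.devices=1 \
    model.train_ds.batch_size=8 \
    model.train_ds.manifest_filepath="<train_manifest_path>" \
    model.validation_ds.manifest_filepath="<dev_manifest_path>"
```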
@@ -23,9 +23,12 @@

 """
 Example training session (single node training)
+For training, you can use the following precisions: 32, bf16, and bf16-mixed.
+BF16 mixed precision typically allows training with a larger batch size.

 python ./streaming_sortformer_diar_train.py --config-path='../conf/neural_diarizer' \
     --config-name='streaming_sortformer_diarizer_4spk-v2.yaml' \
+    trainer.precision='bf16' \
     trainer.devices=1 \
     model.train_ds.manifest_filepath="<train_manifest_path>" \
     model.validation_ds.manifest_filepath="<dev_manifest_path>" \
nemo/collections/asr/losses/bce_loss.py (19 changes: 10 additions & 9 deletions)
@@ -77,7 +77,7 @@ def __init__(
         self.eps = 1e-6

     @typecheck()
-    def forward(self, probs, labels, target_lens):
+    def forward(self, probs, labels, target_lens, enable_autocast=False):
         """
         Calculate binary cross entropy loss based on probs, labels and target_lens variables.

@@ -123,13 +123,14 @@ def forward(self, probs, labels, target_lens):
         binary_weight = torch.ones_like(labels).detach().clone()
         norm_weight = torch.ones_like(labels).detach().clone()

-        if self.reduction == 'sum':
-            loss = self.loss_f(probs, labels)
-        elif self.reduction == 'mean':
-            loss = self.loss_f(probs, labels).mean()
-        elif self.reduction == 'none':
-            if self.class_normalization in ['class', 'class_binary', 'binary']:
-                loss = (binary_weight * norm_weight * self.loss_f(probs, labels)).sum()
-            else:
+        with torch.cuda.amp.autocast(enabled=enable_autocast):
+            if self.reduction == 'sum':
+                loss = self.loss_f(probs, labels)
+            elif self.reduction == 'mean':
+                loss = self.loss_f(probs, labels).mean()
+            elif self.reduction == 'none':
+                if self.class_normalization in ['class', 'class_binary', 'binary']:
+                    loss = (binary_weight * norm_weight * self.loss_f(probs, labels)).sum()
+                else:
                     loss = self.loss_f(probs, labels)
         return loss
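Here `enable_autocast` defaults to `False` because `torch.nn.BCELoss` on probabilities is not autocast-safe: PyTorch raises an error when `binary_cross_entropy` is reached inside a CUDA autocast region and recommends `binary_cross_entropy_with_logits` instead, so computing this loss in float32 is the robust default. A standalone sketch of the disabled-autocast pattern (toy tensors; assumes a CUDA device):

```python
import torch

probs = torch.rand(4, 8, device="cuda", requires_grad=True)  # post-sigmoid outputs
labels = torch.randint(0, 2, (4, 8), device="cuda").float()
loss_f = torch.nn.BCELoss(reduction="mean")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # ... the bf16 forward pass would run here ...
    with torch.cuda.amp.autocast(enabled=False):
        # Autocast is switched off so BCE runs in float32; .float() undoes
        # any bfloat16 cast the outer region may have produced.
        loss = loss_f(probs.float(), labels)
loss.backward()
```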