diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml
index 2e85c5dc73b7..a1efc77df5d8 100644
--- a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml
+++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml
@@ -158,6 +158,7 @@ model:
 trainer:
   devices: 1 # number of gpus (devices)
   accelerator: gpu
+  precision: 32 # 32, bf16, bf16-mixed
   max_epochs: 800
   max_steps: -1 # computed at runtime if not set
   num_nodes: 1
diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/streaming_sortformer_diarizer_4spk-v2.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/streaming_sortformer_diarizer_4spk-v2.yaml
index dfd3534dadf6..969739ab63d5 100644
--- a/examples/speaker_tasks/diarization/conf/neural_diarizer/streaming_sortformer_diarizer_4spk-v2.yaml
+++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/streaming_sortformer_diarizer_4spk-v2.yaml
@@ -201,7 +201,8 @@ model:
 trainer:
   devices: 1 # number of gpus (devices)
-  accelerator: gpu
+  accelerator: gpu
+  precision: 32 # 32, bf16, bf16-mixed
   max_epochs: 800
   max_steps: -1 # computed at runtime if not set
   num_nodes: 1
diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py
index 1fec49bf186b..5ad368d92b0d 100644
--- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py
+++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py
@@ -82,6 +82,7 @@ class DiarizationConfig:
     no_der: bool = False
     out_rttm_dir: Optional[str] = None
     save_preds_tensors: bool = False
+    precision: str = "32" # 32, bf16, bf16-mixed

     # General configs
     session_len_sec: float = -1 # End-to-end diarization session length in seconds
@@ -346,10 +347,13 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]:
         raise ValueError("cfg.model_path must end with.ckpt or.nemo!")

     diar_model._cfg.test_ds.session_len_sec = cfg.session_len_sec
-    trainer = pl.Trainer(devices=device, accelerator=accelerator)
+    trainer = pl.Trainer(devices=device, accelerator=accelerator, precision=cfg.precision)
     diar_model.set_trainer(trainer)

-    diar_model = diar_model.eval()
+    if torch.cuda.is_bf16_supported() and cfg.precision.startswith("bf16"):
+        diar_model = diar_model.to(dtype=torch.bfloat16).eval()
+    else:
+        diar_model = diar_model.eval()

     if cfg.presort_manifest:
         audio_key = cfg.get('audio_key', 'audio_filepath')
@@ -405,7 +409,9 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]:
             diar_model_preds_total_list = torch.load(tensor_path)
         else:
             logging.info("No saved prediction tensors found. Running inference on the dataset...")
-            diar_model.test_batch()
+            with torch.inference_mode(), torch.autocast(device_type=diar_model.device.type, dtype=diar_model.dtype):
+                diar_model.test_batch()
+
             diar_model_preds_total_list = diar_model.preds_total_list
             if cfg.save_preds_tensors:
                 torch.save(diar_model.preds_total_list, tensor_path)
diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py
index 13d3f1dab37f..9cce02022c6b 100644
--- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py
+++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py
@@ -23,9 +23,12 @@
 """
 Example training session (single node training)

+For training, you can use the following precisions: 32, bf16 and bf16-mixed.
+You can train with a larger batch size using BF16 mixed precision.
 python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' \
     --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \
+    trainer.precision='bf16' \
     trainer.devices=1 \
     model.train_ds.manifest_filepath="" \
     model.validation_ds.manifest_filepath="" \
diff --git a/examples/speaker_tasks/diarization/neural_diarizer/streaming_sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/streaming_sortformer_diar_train.py
index eaada67bd262..736d9a7c9453 100644
--- a/examples/speaker_tasks/diarization/neural_diarizer/streaming_sortformer_diar_train.py
+++ b/examples/speaker_tasks/diarization/neural_diarizer/streaming_sortformer_diar_train.py
@@ -23,9 +23,12 @@
 """
 Example training session (single node training)

+For training, you can use the following precisions: 32, bf16 and bf16-mixed.
+You can train with a larger batch size using BF16 mixed precision.
 python ./streaming_sortformer_diar_train.py --config-path='../conf/neural_diarizer' \
     --config-name='streaming_sortformer_diarizer_4spk-v2.yaml' \
+    trainer.precision='bf16' \
     trainer.devices=1 \
     model.train_ds.manifest_filepath="" \
     model.validation_ds.manifest_filepath="" \
diff --git a/nemo/collections/asr/losses/bce_loss.py b/nemo/collections/asr/losses/bce_loss.py
index 36a7a0166f26..61a7b3b2946b 100644
--- a/nemo/collections/asr/losses/bce_loss.py
+++ b/nemo/collections/asr/losses/bce_loss.py
@@ -77,7 +77,7 @@ def __init__(
         self.eps = 1e-6

     @typecheck()
-    def forward(self, probs, labels, target_lens):
+    def forward(self, probs, labels, target_lens, enable_autocast=False):
         """
         Calculate binary cross entropy loss based on probs, labels and target_lens variables.

@@ -123,13 +123,14 @@ def forward(self, probs, labels, target_lens):
         binary_weight = torch.ones_like(labels).detach().clone()
         norm_weight = torch.ones_like(labels).detach().clone()

-        if self.reduction == 'sum':
-            loss = self.loss_f(probs, labels)
-        elif self.reduction == 'mean':
-            loss = self.loss_f(probs, labels).mean()
-        elif self.reduction == 'none':
-            if self.class_normalization in ['class', 'class_binary', 'binary']:
-                loss = (binary_weight * norm_weight * self.loss_f(probs, labels)).sum()
-            else:
+        with torch.cuda.amp.autocast(enabled=enable_autocast):
+            if self.reduction == 'sum':
                 loss = self.loss_f(probs, labels)
+            elif self.reduction == 'mean':
+                loss = self.loss_f(probs, labels).mean()
+            elif self.reduction == 'none':
+                if self.class_normalization in ['class', 'class_binary', 'binary']:
+                    loss = (binary_weight * norm_weight * self.loss_f(probs, labels)).sum()
+                else:
+                    loss = self.loss_f(probs, labels)
         return loss
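
Taken together, the e2e_diarize_speech.py hunks above reduce to the following inference pattern: build the Trainer with the requested precision, cast the model to bfloat16 only when the GPU supports it, and run the test loop under autocast. Below is a minimal sketch assembled from those hunks, not code taken from this patch; the checkpoint path, the restore_from() call, and the SortformerEncLabelModel import are illustrative assumptions.

# Minimal sketch of the bf16 inference path (assumption: the checkpoint path and
# the SortformerEncLabelModel.restore_from() usage are illustrative, not from this patch).
import torch
import pytorch_lightning as pl
from nemo.collections.asr.models import SortformerEncLabelModel

precision = "bf16"  # one of "32", "bf16", "bf16-mixed"

diar_model = SortformerEncLabelModel.restore_from(
    restore_path="/path/to/diar_sortformer_4spk.nemo", map_location="cuda"
)
trainer = pl.Trainer(devices=1, accelerator="gpu", precision=precision)
diar_model.set_trainer(trainer)

# Cast to bfloat16 only when both the requested precision and the hardware allow it.
if torch.cuda.is_bf16_supported() and precision.startswith("bf16"):
    diar_model = diar_model.to(dtype=torch.bfloat16).eval()
else:
    diar_model = diar_model.eval()

# Run the test loop under autocast so reduced-precision kernels are used where available.
with torch.inference_mode(), torch.autocast(device_type=diar_model.device.type, dtype=diar_model.dtype):
    diar_model.test_batch()

preds = diar_model.preds_total_list

The same opt-in idea appears in bce_loss.py: autocast around the loss computation is controlled by the new enable_autocast argument and stays disabled by default.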