Adding bf16 Sortformer train and inference #14627
base: main
Conversation
Signed-off-by: taejinp <[email protected]>
LGTM, thanks Taejin!
@@ -82,6 +82,7 @@ class DiarizationConfig:
     no_der: bool = False
     out_rttm_dir: Optional[str] = None
     save_preds_tensors: bool = False
+    precision: str = "bf16"  # 32, bf16
Let's also add the bf16-mixed option and maybe add a small comment about possible gains expected with bf16 training/inference?
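The reviewer's suggestion could be sketched as follows. This is a hypothetical illustration, not the actual NeMo change: it adds "bf16-mixed" to the accepted values of the `precision` field from the diff above and validates the value early, so a typo fails at config construction rather than mid-inference.

```python
from dataclasses import dataclass

# Accepted precision modes (illustrative): "32" keeps full fp32,
# "bf16" casts weights/activations to bfloat16, and "bf16-mixed"
# keeps fp32 master weights while running compute in bf16.
ALLOWED_PRECISIONS = ("32", "bf16", "bf16-mixed")


@dataclass
class DiarizationConfig:
    precision: str = "bf16"  # one of "32", "bf16", "bf16-mixed"

    def __post_init__(self):
        # Fail fast on unsupported precision strings.
        if self.precision not in ALLOWED_PRECISIONS:
            raise ValueError(
                f"precision must be one of {ALLOWED_PRECISIONS}, "
                f"got {self.precision!r}"
            )
```

The expected gain from bf16 over fp32 is roughly halved activation memory and higher matmul throughput on hardware with native bf16 units, at the cost of a smaller mantissa (7 bits vs. 23).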
Signed-off-by: tango4j <[email protected]>
LGTM!
What does this PR do?
This PR adds bf16 precision training and inference of Sortformer diarizer models.
Hardware: Starting from the Ampere architecture (e.g., A100), native bf16 operations are supported.
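For intuition about why bf16 works well here: a bfloat16 value is simply the top 16 bits of an IEEE-754 float32, keeping the full 8-bit exponent (same dynamic range as fp32) but only 7 mantissa bits. The sketch below demonstrates this with pure-Python bit manipulation; it uses truncation for simplicity, whereas hardware conversion typically rounds to nearest.

```python
import struct


def to_bf16_bits(x: float) -> int:
    # bf16 is the high half of a float32: same sign bit and 8-bit
    # exponent, mantissa truncated from 23 bits to 7.
    (f32_bits,) = struct.unpack(">I", struct.pack(">f", x))
    return f32_bits >> 16


def bf16_round_trip(x: float) -> float:
    # Truncate to bf16, then widen back to float32 by zero-filling
    # the dropped mantissa bits, to see the precision loss.
    (y,) = struct.unpack(">f", struct.pack(">I", to_bf16_bits(x) << 16))
    return y
```

For example, `bf16_round_trip(1.0)` is exactly `1.0`, while `bf16_round_trip(3.14159)` truncates to `3.140625`: the dynamic range survives, only fine mantissa precision is lost, which is why fp32 weights can be cast to bf16 for inference without rescaling.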
Collection: ASR/speaker_task
Changelog
NeMo/nemo/collections/asr/losses/bce_loss.py
NeMo/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py
Usage
Although the model weights are FP32, the e2e_diarize_speech.py script automatically converts the precision to bf16 and then performs inference in bf16.
For training, specify the following configuration:
trainer.precision="bf16"
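The fp32-weights-with-bf16-compute pattern described above can be sketched with PyTorch autocast. This is a minimal stand-in, not the actual NeMo conversion path: `torch.nn.Linear` plays the role of the Sortformer model, and the CPU device is used for portability (on Ampere+ GPUs you would pass `device_type="cuda"`).

```python
import torch

# Stand-in for the Sortformer model; its weights stay fp32.
model = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)

# Inside the autocast region, matmul-style ops run in bfloat16
# even though the stored parameters remain float32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

assert model.weight.dtype == torch.float32  # weights untouched
```

Within the autocast region, `y` is produced by bf16 compute, matching the script's behavior of casting at inference time rather than storing bf16 checkpoints.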
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.
Additional Information