
Commit 64f0faa

[automodel] fallback FP8 + LCE -> FP8 + CE (#13349)

Authored by Alexandros Koumparoulis; co-authored by oliver könig.

* fix
* make fp8 tests non-optional
* switch to gemma

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: oliver könig <[email protected]>

1 parent 28db904 commit 64f0faa

File tree

3 files changed: +9 −3 lines changed

.github/workflows/cicd-main-automodel.yml

Lines changed: 0 additions & 2 deletions

```diff
@@ -84,7 +84,6 @@ jobs:
           script: L2_HF_Transformer_PEFT_2gpu_FSDP2_liger
         - runner: azure-gpu-vm-runner1-h100
           script: L2_HF_Transformer_PEFT_2gpu_FSDP2_fp8
-          is_optional: true
         - runner: self-hosted-azure
           script: L2_HF_Transformer_PEFT_2gpu_FSDP2
         - runner: self-hosted-azure
@@ -95,7 +94,6 @@ jobs:
           script: L2_HF_Transformer_SFT_2gpu_FSDP2
         - runner: azure-gpu-vm-runner1-h100
           script: L2_HF_Transformer_SFT_2gpu_FSDP2_fp8
-          is_optional: true
         - runner: self-hosted-azure
           script: L2_HF_Transformer_SFT_2gpu_nemorun
         - runner: self-hosted-azure
```

nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -265,6 +265,14 @@ def configure_model(self):

             te_accelerate(self.model, self.model_accelerator.fp8_autocast)

+        if self.use_linear_ce_loss:
+            # scan the model for fp8 layers, if found disable lce
+            for module in self.model.modules():
+                if hasattr(module, 'fp8'):
+                    logging.warning("LCE does not support FP8, switching to regular CE.")
+                    self.use_linear_ce_loss = False
+                    break
+
         if self.enable_grad_ckpt:
             if getattr(self.model, 'supports_gradient_checkpointing', False):
                 self.model.gradient_checkpointing_enable()
```
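The fallback logic in this hunk can be exercised in isolation. The sketch below is a minimal, self-contained reproduction of the same scan — note that `DummyModel` and `DummyFP8Linear` are hypothetical stand-ins for illustration, not NeMo or Transformer Engine classes; the only assumption carried over from the commit is that FP8-enabled modules expose an `fp8` attribute, which is what `hasattr` keys on.

```python
import logging


class DummyFP8Linear:
    """Hypothetical stand-in for an FP8-enabled layer exposing an `fp8` attribute."""
    fp8 = True


class DummyModel:
    """Hypothetical stand-in exposing a `modules()` iterator like torch.nn.Module."""
    def __init__(self, layers):
        self._layers = layers

    def modules(self):
        return iter(self._layers)


def maybe_disable_lce(model, use_linear_ce_loss):
    """Mirror the commit's fallback: if any module looks FP8-enabled, drop LCE for CE."""
    if use_linear_ce_loss:
        for module in model.modules():
            if hasattr(module, 'fp8'):
                logging.warning("LCE does not support FP8, switching to regular CE.")
                return False
    return use_linear_ce_loss


# An FP8 layer is present, so linear CE loss is disabled.
assert maybe_disable_lce(DummyModel([DummyFP8Linear()]), True) is False
# No FP8 layers, so the original setting is kept.
assert maybe_disable_lce(DummyModel([object()]), True) is True
```

Returning early on the first match mirrors the `break` in the committed code: one FP8 module anywhere in the tree is enough to force the fallback.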
Lines changed: 1 addition & 1 deletion

```diff
@@ -1,7 +1,7 @@
 export TRANSFORMERS_OFFLINE=1
 export HF_HOME=/home/TestData/automodel/hf_home
 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/llm/peft/automodel.py \
-    --model /home/TestData/akoumparouli/hf_mixtral_2l/ \
+    --model /home/TestData/akoumparouli/hf_gemma_38m/ \
     --max-steps 3 \
     --devices 2 \
     --strategy fsdp2 --fp8
```
