Skip to content

Commit 73dd122

Browse files
Update mistral.py, showed flag to not call cut cross entropy (#3233)
* Update mistral.py, showed flag to not call cut cross entropy * Update mistral.py, made it so if its not equal to zero * Update unsloth/models/mistral.py --------- Co-authored-by: Daniel Han <[email protected]>
1 parent 9050d6f commit 73dd122

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

unsloth/models/mistral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def MistralForCausalLM_fast_forward(
300300
# < 1024 Normal Unsloth uses less VRAM!
301301
if bsz * q_len <= 1024: RETURN_LOGITS = True
302302

303-
if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None:
303+
if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and os.environ.get("UNSLOTH_ENABLE_CCE", "1") != "0" and labels is not None:
304304
n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)
305305
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
306306
loss = fused_linear_cross_entropy(

0 commit comments

Comments
 (0)