Skip to content

Commit 12737e5

Browse files
Fix incorrect function call in test_qwen3_grpo.py (#3212)
* Update test_qwen3_grpo.py to correct function call This test file uses the incorrect name for the function, which is gradient_checkpointing_disable(), not disable_gradient_checkpointing(). I copied the line from test_llama32_sft.py - I'm not sure if this actually is required, just wanted it consistent for when other people like me test this and have no clue what they're doing when it throws an exception. * Update blackwell/test_qwen3_grpo.py Co-authored-by: Daniel Han <[email protected]> --------- Co-authored-by: Daniel Han <[email protected]>
1 parent 64b4048 commit 12737e5

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

blackwell/test_qwen3_grpo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,8 @@ def check_numbers(prompts, completions, answer, **kwargs):
415415
top_k=50,
416416
max_tokens=1024,
417417
)
418-
model.disable_gradient_checkpointing()
418+
419+
419420
output = (
420421
model.fast_generate(
421422
[text],

0 commit comments

Comments
 (0)