File tree Expand file tree Collapse file tree 2 files changed +4
-2
lines changed
L1_jax_distributed_unittest
L2_jax_distributed_unittest Expand file tree Collapse file tree 2 files changed +4
-2
lines changed Original file line number Diff line number Diff line change 88: ${XML_LOG_DIR:=/ logs}
99mkdir -p " $XML_LOG_DIR "
1010
11- NVTE_JAX_UNITTEST_LEVEL=" L1" python3 -m pytest -c $TE_PATH /tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR /pytest.xml $TE_PATH /tests/jax/test_distributed_*
11+ # Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate.
12+ XLA_FLAGS=" $XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL=" L1" python3 -m pytest -c $TE_PATH /tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR /pytest.xml $TE_PATH /tests/jax/test_distributed_*
1213SCRIPT_NAME=$TE_PATH /tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH /tests/jax/multi_process_launch.sh
Original file line number Diff line number Diff line change 88: ${XML_LOG_DIR:=/ logs}
99mkdir -p " $XML_LOG_DIR "
1010
11- NVTE_JAX_UNITTEST_LEVEL=" L2" python3 -m pytest -c $TE_PATH /tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR /pytest.xml $TE_PATH /tests/jax/test_distributed_*
11+ # Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate.
12+ XLA_FLAGS=" $XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL=" L2" python3 -m pytest -c $TE_PATH /tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR /pytest.xml $TE_PATH /tests/jax/test_distributed_*
You can’t perform that action at this time.
0 commit comments