Skip to content

Commit 70f5366

Browse files
jberchtold-nvidiaKshitijLakhani
authored andcommitted
[JAX] Ensure JAX reference impl uses an accurate backend in our tests (#2322)
Ensure JAX reference impl uses an accurate backend Signed-off-by: Jeremy Berchtold <[email protected]>
1 parent 9cc089a commit 70f5366

File tree

2 files changed

+4
-2
lines changed
  • qa

2 files changed

+4
-2
lines changed

qa/L1_jax_distributed_unittest/test.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ set -xe
88
: ${XML_LOG_DIR:=/logs}
99
mkdir -p "$XML_LOG_DIR"
1010

11-
NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
11+
# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate.
12+
XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
1213
SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh

qa/L2_jax_distributed_unittest/test.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ set -xe
88
: ${XML_LOG_DIR:=/logs}
99
mkdir -p "$XML_LOG_DIR"
1010

11-
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
11+
# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate.
12+
XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*

0 commit comments

Comments
 (0)