# Run CPU Jax Testsuite (workflow definition as captured from run #153)

# Installs prebuilt PJRT binaries and runs the JAX test suite against the
# IREE CPU plugin. Results are compared with the golden list produced by the
# previous run of this workflow.
name: Run CPU Jax Testsuite

on:
  workflow_dispatch:
  schedule:
    # Run nightly. GitHub evaluates cron schedules in UTC, so this fires at
    # 21:30 UTC, i.e. 2:30 PM PDT.
    - cron: '30 21 * * *'

concurrency:
  # A PR number if a pull request and otherwise the commit hash. This workflow
  # only has manual and scheduled triggers, so in practice the group is keyed
  # by the commit hash: a newer run for the same commit cancels queued and
  # in-progress runs.
  group: run_jax_testsuite_${{ github.event.number || github.sha }}
  cancel-in-progress: true

# Jobs are organized into groups and topologically sorted by dependencies
jobs:
  build:
    # NOTE(review): GitHub has retired the hosted ubuntu-20.04 image; confirm
    # which runner label this job should move to.
    runs-on: ubuntu-20.04
    steps:
      - name: "Checking out repository"
        uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791  # v2.5.0

      - name: "Setting up Python"
        uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa  # v2.3.3
        with:
          python-version: "3.10"

      # NOTE(review): unlike the other actions in this job, the two download
      # steps reference a mutable tag instead of a pinned commit SHA; consider
      # pinning them as well.
      - name: Download last PJRT build
        id: download-pjrt-build
        uses: dawidd6/action-download-artifact@v2
        with:
          workflow: build_packages.yml
          branch: main
          name: artifact

      - name: Download Last Golden result
        id: download-last-golden
        uses: dawidd6/action-download-artifact@v2
        with:
          workflow: run_jaxtests_cpu.yml
          branch: main
          name: artifact
          # A missing golden artifact is reported as a warning, not an error.
          if_no_artifact_found: warn

      - name: Sync and install versions
        run: |
          # TODO: https://github.com/openxla/openxla-pjrt-plugin/issues/30
          # Refresh the package index first so the install does not fail on a
          # stale index baked into the runner image.
          sudo apt-get update
          sudo apt-get install -y lld
          # Since only building the runtime, exclude compiler deps (expensive).
          python ./sync_deps.py --depth 1 --submodules-depth 1 --exclude-dep xla --exclude-submodule "iree:third_party/(.*)"
          pip install -r requirements.txt
          python -m pip install absl-py pytest
          python -m pip install -e ctstools

      - name: "Configure"
        run: |
          python ./configure.py --cc=clang --cxx=clang++

      - name: "Run JAX Testsuite"
        id: jax-testsuite
        env:
          PASSING_ARTIFACT: jaxsuite_passing.txt
          FAILING_ARTIFACT: jaxsuite_failing.txt
          GOLDEN_ARTIFACT: jaxsuite_golden.txt
        run: |
          source .env.sh
          # Publish the artifact file names as step outputs before running any
          # tests, so the upload step can resolve them even if a test fails.
          echo "testsuite-passing-artifact=${PASSING_ARTIFACT}" >> "${GITHUB_OUTPUT}"
          echo "testsuite-failing-artifact=${FAILING_ARTIFACT}" >> "${GITHUB_OUTPUT}"
          echo "testsuite-golden-artifact=${GOLDEN_ARTIFACT}" >> "${GITHUB_OUTPUT}"
          # Point to the downloaded PJRT library. The path is derived from the
          # workspace so the job also works in forks or renamed repositories.
          export PJRT_NAMES_AND_LIBRARY_PATHS="iree_cpu:${GITHUB_WORKSPACE}/pjrt_plugins/pjrt_plugin_iree_cpu.so"
          JAX_PLATFORMS=iree_cpu python test/test_simple.py
          # The JAX checkout lives next to this repository's workspace
          # directory, hence the dirname of the workspace.
          JAX_PLATFORMS=iree_cpu python test/test_jax.py \
            "$(dirname "${GITHUB_WORKSPACE}")/jax/tests/nn_test.py" \
            --passing "${PASSING_ARTIFACT}" \
            --failing "${FAILING_ARTIFACT}" \
            --expected "${GOLDEN_ARTIFACT}" \
            --logdir jax_testsuite
          # If we passed we can update the golden.
          # NOTE(review): the default shell runs with `-e`, so a failing test
          # run normally aborts the script before this check is reached; the
          # check is kept in case .env.sh alters the shell options.
          if [ $? -eq 0 ]; then
            cp "${PASSING_ARTIFACT}" "${GOLDEN_ARTIFACT}"
          fi

      # NOTE(review): this step is skipped when the test-suite step fails;
      # confirm whether triage is also wanted for failing runs.
      - name: "Run JAX Testsuite Triage"
        run: |
          python test/triage_jaxtest.py --logdir jax_testsuite

      # Upload the result lists even when an earlier step failed.
      - uses: actions/upload-artifact@0b7f8abb1508181956e8e162db84b466c27e18ce  # v3.1.2
        if: always()
        with:
          path: |
            ${{ steps.jax-testsuite.outputs.testsuite-passing-artifact }}
            ${{ steps.jax-testsuite.outputs.testsuite-failing-artifact }}
            ${{ steps.jax-testsuite.outputs.testsuite-golden-artifact }}
          retention-days: 5