# Run CPU Jax Testsuite #166
# Viewer notice: this file was flagged as containing bidirectional Unicode text
# that may be interpreted or compiled differently than it appears. To review,
# open the file in an editor that reveals hidden Unicode characters.
# Installs prebuilt binaries and runs the JAX test suite against them.
name: Run CPU Jax Testsuite
on:
  # Allow the workflow to be started manually.
  workflow_dispatch:
  schedule:
    # Daily run. GitHub evaluates cron expressions in UTC, so this fires at
    # 21:30 UTC, i.e. 2:30 PM PDT (1:30 PM PST).
    - cron: '30 21 * * *'
concurrency:
  # The group key is the PR number when the triggering event carries one and
  # the commit SHA otherwise. A newly started run cancels queued and
  # in-progress runs that share its key. The triggers visible in this file
  # (workflow_dispatch, schedule) carry no PR number, so the key falls back to
  # the commit SHA.
  group: run_jax_testsuite_${{ github.event.number || github.sha }}
  cancel-in-progress: true
# Jobs are organized into groups and topologically sorted by dependencies
jobs:
  # Single job: fetch the last prebuilt PJRT plugin and the last golden result,
  # run one JAX test file against the plugin, then upload the result lists.
  build:
    # NOTE(review): GitHub has retired the hosted ubuntu-20.04 image; confirm
    # this label still resolves before relying on the schedule. It is left
    # unchanged here because the downloaded prebuilt plugin may depend on the
    # image it was built for.
    runs-on: ubuntu-20.04
    steps:
      - name: "Checking out repository"
        uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791  # v2.5.0

      - name: "Setting up Python"
        uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa  # v2.3.3
        with:
          # Quoted so YAML does not read the version as the float 3.1.
          python-version: "3.10"

      # NOTE(review): unlike the actions above, this third-party action is
      # referenced by a mutable tag rather than a commit SHA - consider pinning.
      - name: Download last PJRT build
        id: download-pjrt-build
        uses: dawidd6/action-download-artifact@v2
        with:
          workflow: build_packages.yml
          branch: main
          name: artifact

      # Fetches the artifact uploaded by the last run of this workflow on main.
      # A missing artifact (e.g. on the very first run) only produces a warning.
      - name: Download Last Golden result
        id: download-last-golden
        uses: dawidd6/action-download-artifact@v2
        with:
          workflow: run_jaxtests_cpu.yml
          branch: main
          name: artifact
          if_no_artifact_found: warn

      - name: Sync and install versions
        run: |
          # TODO: https://github.com/openxla/openxla-pjrt-plugin/issues/30
          # Refresh the package index first so the install does not fail on a
          # stale index; apt-get is the script-stable front end.
          sudo apt-get update
          sudo apt-get install -y lld
          # Since only building the runtime, exclude compiler deps (expensive).
          python ./sync_deps.py \
            --depth 1 \
            --submodules-depth 1 \
            --exclude-dep xla \
            --exclude-submodule "iree:third_party/(.*)"
          # Always go through `python -m pip` so packages land in the
          # interpreter selected by the setup-python step.
          python -m pip install -r requirements.txt
          python -m pip install absl-py pytest
          python -m pip install -e ctstools

      - name: "Configure"
        run: |
          python ./configure.py --cc=clang --cxx=clang++

      - name: "Run JAX Testsuite"
        id: jax-testsuite
        env:
          PASSING_ARTIFACT: jaxsuite_passing.txt
          FAILING_ARTIFACT: jaxsuite_failing.txt
          GOLDEN_ARTIFACT: jaxsuite_golden.txt
        run: |
          source .env.sh
          # Publish the result file names before running anything that can
          # fail, so the upload step can still read them afterwards.
          echo "testsuite-passing-artifact=${PASSING_ARTIFACT}" >> "${GITHUB_OUTPUT}"
          echo "testsuite-failing-artifact=${FAILING_ARTIFACT}" >> "${GITHUB_OUTPUT}"
          echo "testsuite-golden-artifact=${GOLDEN_ARTIFACT}" >> "${GITHUB_OUTPUT}"
          # Point to the downloaded PJRT library. GITHUB_WORKSPACE is the
          # checkout directory (/home/runner/work/<repo>/<repo> on hosted
          # runners), so this matches the previously hard-coded path.
          export PJRT_NAMES_AND_LIBRARY_PATHS="iree_cpu:${GITHUB_WORKSPACE}/pjrt_plugins/pjrt_plugin_iree_cpu.so"
          JAX_PLATFORMS=iree_cpu python test/test_simple.py
          # The jax checkout is a sibling of the workspace directory.
          # Capture the exit status explicitly: under the default `bash -e`
          # shell a bare failing command would abort the script before any
          # later `$?` check could run.
          status=0
          JAX_PLATFORMS=iree_cpu python test/test_jax.py \
            "$(dirname "${GITHUB_WORKSPACE}")/jax/tests/nn_test.py" \
            --passing "${PASSING_ARTIFACT}" \
            --failing "${FAILING_ARTIFACT}" \
            --expected "${GOLDEN_ARTIFACT}" \
            --logdir jax_testsuite || status=$?
          # If we passed we can update the golden.
          if [ "${status}" -eq 0 ]; then
            cp "${PASSING_ARTIFACT}" "${GOLDEN_ARTIFACT}"
          fi
          # Fail the step with the test suite's own exit code.
          exit "${status}"

      # NOTE(review): with no `if:` condition this step is skipped whenever an
      # earlier step (including the test suite) fails - confirm that is the
      # intended behavior for triage.
      - name: "Run JAX Testsuite Triage"
        run: |
          python test/triage_jaxtest.py --logdir jax_testsuite

      # Upload the result lists even when an earlier step failed.
      - uses: actions/upload-artifact@0b7f8abb1508181956e8e162db84b466c27e18ce  # v3.1.2
        if: always()
        with:
          # Spelled out (it is also the action's default) because the download
          # steps above look the artifact up by this exact name.
          name: artifact
          path: |
            ${{ steps.jax-testsuite.outputs.testsuite-passing-artifact }}
            ${{ steps.jax-testsuite.outputs.testsuite-failing-artifact }}
            ${{ steps.jax-testsuite.outputs.testsuite-golden-artifact }}
          retention-days: 5