Merge pull request #22 from rdyro/add-matrix-to-ci #9
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: CI

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch.
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

# Least-privilege token: the jobs only need to read the repository.
# Cancelling superseded runs is handled by the `concurrency` block below,
# which does not require any `actions` scope on the token.
permissions:
  contents: read  # to fetch code

# One active run per workflow and branch/PR: `github.head_ref` is set for
# pull requests, otherwise the group falls back to `github.ref`. A newer run
# in the same group cancels the one still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
jobs:
  # Runs ruff's format check and lint check on the torch2jax path,
  # once per Python version in the matrix.
  lint_and_typecheck:
    strategy:
      matrix:
        python-version: ["3.11", "3.12"]
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      - uses: actions/checkout@v4
      # The step name follows the matrix entry so that the 3.12 leg is not
      # mislabelled as 3.11 in the run log.
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - run: |
          python3 -m pip install -U ruff
          ruff format --diff torch2jax
          ruff check torch2jax

  # Runs the test suite for every Python x JAX version combination in the
  # matrix. "latest" means: install whatever jax version the resolver picks
  # with upgrades enabled, instead of pinning an exact version.
  pytest:
    strategy:
      matrix:
        python-version: ["3.11", "3.12"]
        jax-version: ["0.4.33", "0.4.35", "latest"]
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      # Same checkout major version as the lint job.
      - uses: actions/checkout@v4
      - name: Install ninja
        run: |
          sudo DEBIAN_FRONTEND=noninteractive apt-get update -y
          sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ninja-build coreutils
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      # Installs uv and exposes two step outputs consumed by the cache step
      # below: the venv location and the interpreter version string.
      - name: Get uv cache dir
        id: uv-cache
        run: |
          python3 -m pip install -U uv pip
          echo "venv_path=~/.venv" >> "$GITHUB_OUTPUT"
          echo "pyver=$(python3 -V)" >> "$GITHUB_OUTPUT"
      # Caches the whole virtualenv, keyed on runner OS and Python version.
      - uses: actions/cache@v4
        name: uv cache
        with:
          path: ${{ steps.uv-cache.outputs.venv_path }}
          key: ${{ runner.os }}-uv-${{ steps.uv-cache.outputs.pyver }}
      - name: Install uv and pytest
        run: |
          # Create the venv only when ~/.venv does not exist yet (for example
          # because the cache step above did not restore it). An explicit
          # `if` keeps the step's exit status clean under `bash -e`.
          if [[ ! -d ~/.venv ]]; then
            uv venv --seed ~/.venv
          fi
          source ~/.venv/bin/activate
          uv pip install pytest && pip install pytest
          if [[ "${{ matrix.jax-version }}" == "latest" ]]; then
            uv pip install -U torch jax .
          else
            uv pip install torch "jax==${{ matrix.jax-version }}" .
          fi
      # Asks torch2jax for its extension version string and derives the
      # directory under ~/.cache/torch2jax from it; both values are exported
      # as step outputs for the cache step below.
      - name: Get extension path
        id: cpp-extension-cache
        run: |
          source ~/.venv/bin/activate
          cpp_ext_ver=$(python3 -c 'from torch2jax.compile \
          import _generate_extension_version; print(_generate_extension_version())')
          cpp_ext_path="~/.cache/torch2jax/$cpp_ext_ver"
          echo "#################################################"
          echo "cpp_ext_ver=$cpp_ext_ver"
          echo "cpp_ext_path=$cpp_ext_path"
          echo "#################################################"
          echo "cpp_ext_ver=$cpp_ext_ver" >> "$GITHUB_OUTPUT"
          echo "cpp_ext_path=$cpp_ext_path" >> "$GITHUB_OUTPUT"
      # Caches the extension directory, keyed on the version string
      # reported by torch2jax in the previous step.
      - uses: actions/cache@v4
        name: cpp extension cache
        with:
          path: ${{ steps.cpp-extension-cache.outputs.cpp_ext_path }}
          key: ${{ steps.cpp-extension-cache.outputs.cpp_ext_ver }}
      - name: Compile the torch2jax cpp extension
        run: |
          source ~/.venv/bin/activate
          python3 -c 'from torch2jax.compile import compile_extension; compile_extension()'
      # JAX_ENABLE_X64=True is set for the pytest process only.
      - name: Run tests
        run: |
          source ~/.venv/bin/activate
          env JAX_ENABLE_X64=True pytest tests/*