# Page header captured together with the workflow file (not part of the YAML):
# Merge pull request #22 from rdyro/add-matrix-to-ci #9
# Workflow file for this run

name: CI
on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch.
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
# Token scopes granted to every job in this workflow.
permissions:
  contents: read  # to fetch code
  actions: write  # to cancel previous workflows
concurrency:
  # Group runs by workflow name plus the PR head ref, falling back to the
  # pushed ref when head_ref is empty.
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # A newer run in the same group cancels the one still in progress.
  cancel-in-progress: true
jobs:
  # Runs `ruff format --diff` and `ruff check` on the torch2jax package,
  # once per Python version in the matrix.
  lint_and_typecheck:
    strategy:
      matrix:
        python-version: ["3.11", "3.12"]
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      - uses: actions/checkout@v4
      # The step name follows the matrix entry instead of a fixed version.
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - run: |
          python3 -m pip install -U ruff
          ruff format --diff torch2jax
          ruff check torch2jax

  # Installs torch, jax and this package into a cached virtualenv, compiles
  # the torch2jax cpp extension, and runs the test suite for every
  # (python-version, jax-version) combination in the matrix.
  pytest:
    strategy:
      matrix:
        python-version: ["3.11", "3.12"]
        jax-version: ["0.4.33", "0.4.35", "latest"]
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      # Same checkout major version as the lint job.
      - uses: actions/checkout@v4
      - name: Install ninja
        run: |
          sudo DEBIAN_FRONTEND=noninteractive apt update -y
          sudo DEBIAN_FRONTEND=noninteractive apt install -y ninja-build coreutils
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      # Publishes the venv location and the `python3 -V` string as step
      # outputs; the cache step below uses them as its path and key.
      - name: Get uv cache dir
        id: uv-cache
        run: |
          python3 -m pip install -U uv pip
          echo "venv_path=~/.venv" >> "$GITHUB_OUTPUT"
          echo "pyver=$(python3 -V)" >> "$GITHUB_OUTPUT"
      - uses: actions/cache@v4
        name: uv cache
        with:
          path: ${{ steps.uv-cache.outputs.venv_path }}
          key: ${{ runner.os }}-uv-${{ steps.uv-cache.outputs.pyver }}
      - name: Install uv and pytest
        run: |
          # Create the venv only when ~/.venv does not already exist
          # (for example, when the cache step above did not restore it).
          if [[ ! -d ~/.venv ]]; then
            uv venv --seed ~/.venv
          fi
          source ~/.venv/bin/activate
          uv pip install pytest && pip install pytest
          # "latest" installs with -U and no pin; any other matrix value
          # pins jax to exactly that version.
          if [[ ${{ matrix.jax-version }} == "latest" ]]; then
            uv pip install -U torch jax .
          else
            uv pip install torch jax==${{ matrix.jax-version }} .
          fi
      # Asks torch2jax for its extension version and derives the cache
      # location under ~/.cache/torch2jax from it.
      - name: Get extension path
        id: cpp-extension-cache
        run: |
          source ~/.venv/bin/activate
          cpp_ext_ver=$(python3 -c 'from torch2jax.compile \
            import _generate_extension_version; print(_generate_extension_version())')
          # The quoted "~" is not expanded by the shell; the value is only
          # consumed as the `path` input of the cache step below.
          cpp_ext_path="~/.cache/torch2jax/$cpp_ext_ver"
          echo "#################################################"
          echo "cpp_ext_ver=$cpp_ext_ver"
          echo "cpp_ext_path=$cpp_ext_path"
          echo "#################################################"
          echo "cpp_ext_ver=$cpp_ext_ver" >> "$GITHUB_OUTPUT"
          echo "cpp_ext_path=$cpp_ext_path" >> "$GITHUB_OUTPUT"
      - uses: actions/cache@v4
        name: cpp extension cache
        with:
          path: ${{ steps.cpp-extension-cache.outputs.cpp_ext_path }}
          key: ${{ steps.cpp-extension-cache.outputs.cpp_ext_ver }}
      - name: Compile the torch2jax cpp extension
        run: |
          source ~/.venv/bin/activate
          python3 -c 'from torch2jax.compile import compile_extension; compile_extension()'
      - name: Run tests
        run: |
          source ~/.venv/bin/activate
          env JAX_ENABLE_X64=True pytest tests/*