fix t5x tests #524
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Workflow covering the pure-Python parts of nsys-jax.
name: nsys-jax pure-Python CI

# One concurrency group per workflow and PR head ref (or per run when there is
# no head ref); an in-progress run is cancelled by a newer one unless the ref
# is main.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

on:
  pull_request:
    types: [opened, reopened, ready_for_review, synchronize]
    # Pull requests that only touch Markdown files do not trigger the workflow
    paths-ignore: ['**.md']
  push:
    branches: [main]

defaults:
  run:
    # -x traces each command; -e / pipefail abort the script on the first failure
    shell: 'bash -x -eo pipefail {0}'

env:
  # Newline-separated paths consumed (unquoted, so word-split) by the jobs below
  NSYS_JAX_PYTHON_FILES: |
    JAX-Toolbox/.github/container/nsys_jax
    JAX-Toolbox/.github/container/jax-nccl-test

jobs:
mypy:
  # Type-checks the paths in NSYS_JAX_PYTHON_FILES with mypy, after compiling the
  # XLA .proto files to Python stubs and converting notebooks to plain scripts.
  runs-on: ubuntu-24.04
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4
      with:
        path: JAX-Toolbox
        sparse-checkout: |
          .github/container
    - name: "Setup Python 3.10"
      uses: actions/setup-python@v5
      with:
        python-version: '3.10'
    # jax is just a CPU-only build of the latest release for type-checking purposes
    - name: "Install jax / nsys_jax / mypy"
      run: pip install jax -e JAX-Toolbox/.github/container/nsys_jax matplotlib mypy nbconvert types-protobuf types-requests
    - name: "Install protoc"
      # TODO: this could install into the pip prefix as a default
      run: |
        install-protoc local/
        echo "$PWD/local/bin" >> "${GITHUB_PATH}"
    - name: "Fetch XLA .proto files"
      uses: actions/checkout@v4
      with:
        path: xla
        repository: openxla/xla
        # Non-cone mode so that the pattern below is treated as a file glob
        sparse-checkout: |
          *.proto
        sparse-checkout-cone-mode: false
    - name: "Compile .proto files"
      run: |
        mkdir compiled_protos compiled_stubs protos
        mv -v xla/third_party/tsl/tsl protos/
        mv -v xla/xla protos/
        python -c "from nsys_jax import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')"
        touch compiled_stubs/py.typed
    - name: "Convert .ipynb to .py"
      run: |
        # Let find hand the notebooks to nbconvert directly instead of looping
        # over $(find ...): paths are never word-split, and a failure of find
        # itself or of nbconvert makes the step fail (with "-exec ... +", find
        # exits non-zero when an invocation fails).
        # NSYS_JAX_PYTHON_FILES is deliberately unquoted: it holds several paths.
        find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb' -exec jupyter nbconvert --to script {} +
    - name: "Run mypy checks"
      run: |
        # Make the generated protobuf stubs visible to mypy
        export MYPYPATH="${PWD}/compiled_stubs"
        mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES}
# Test nsys-jax-combine and notebook execution; in future perhaps upload the rendered | |
# notebook from here too. These input files were generated with something like | |
# srun -n 4 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06 | |
# env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\ | |
# --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-4proc-proc%q{SLURM_PROCID} | |
# -- test-pax.sh --steps=5 --fsdp=4 --multiprocess | |
# with newer nsys-jax components bind-mounted in. | |
combine:
  # Merges the multi-process test profiles with nsys-jax-combine, unpacks the
  # merged archive, and runs the analysis notebook against it.
  runs-on: ubuntu-24.04
  steps:
    - name: 'Check out the repository under ${GITHUB_WORKSPACE}'
      uses: actions/checkout@v4
    - name: Setup Python 3.12
      uses: actions/setup-python@v5
      with:
        python-version: "3.12"
    - name: Install nsys-jax and dependencies
      run: |
        # Editable install (provides nsys-jax-combine) for better coverage
        pip install -e .github/container/nsys_jax[jupyter]
        # TODO: this could install into the pip prefix as a default
        install-flamegraph local/
        install-protoc local/
        echo "$PWD/local/bin" >> "${GITHUB_PATH}"
    - name: Use nsys-jax-combine to merge profiles from multiple nsys processes
      run: |
        nsys-jax-combine \
          --analysis summary \
          --analysis communication \
          -o pax_fsdp4_4proc.zip \
          .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip
    - name: Extract the output .zip file
      run: |
        mkdir combined/
        unzip -d combined/ pax_fsdp4_4proc.zip
    - name: Execute the notebook
      run: |
        # Locate the notebook shipped inside the installed nsys_jax package
        NOTEBOOK=$(python -c 'from importlib.resources import files; print(files("nsys_jax") / "analyses" / "Analysis.ipynb")')
        # Point to the extracted nsys-jax-combine output
        export NSYS_JAX_DEFAULT_PREFIX="${PWD}/combined"
        # Run with ipython for the sake of getting a clear error message
        ipython "${NOTEBOOK}"
# This input file was generated with something like | |
# srun -n 1 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06 | |
# env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\ | |
# --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-1proc -- test-pax.sh | |
# --steps=5 --fsdp=4 | |
notebook:
  # Executes and renders the analysis notebook against single-process test data,
  # uploads the result to a fresh private Gist, then mirrors the selected files
  # into a Gist with a well-known ID.
  env:
    # TODO: these could/should be saved in the repository settings instead
    # Quoted so that the all-hex IDs are always read as strings
    RENDERED_NOTEBOOK_GIST_ID: 'e2cd3520201caab6b67385ed36fad3c1'
    MOCK_RENDERED_NOTEBOOK_GIST_ID: '16698d9e9e52320243165d61b5bb3975'
    # Regex (compiled with JavaScript's RegExp in the copy step) selecting which
    # files of the uploaded Gist are published to the well-known Gist
    PUBLISH_NOTEBOOK_FILES: '(.*\.ipynb|.*\.svg)'
  runs-on: ubuntu-24.04
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4
    - name: Extract the post-processed profile data from a real .zip file (no .nsys-rep)
      run: |
        # Get the actual test data from a real archive, minus the .nsys-rep file
        mkdir profile_data/
        unzip -d profile_data/ .github/workflows/nsys-jax/test_data/pax_fsdp4_1proc.zip
    - name: "Setup Python 3.10"
      uses: actions/setup-python@v5
      with:
        python-version: '3.10'
    - name: Install nsys-jax and dependencies
      run: |
        # Do *not* use an editable install (covered above) for better coverage
        pip install .github/container/nsys_jax[jupyter]
        # TODO: this could install into the pip prefix as a default
        install-flamegraph local/
        install-protoc local/
        echo "$PWD/local/bin" >> "${GITHUB_PATH}"
    - name: Execute the notebook
      id: exec
      run: |
        NOTEBOOK=$(python -c 'from importlib.resources import files; print(files("nsys_jax") / "analyses" / "Analysis.ipynb")')
        # Point to the extracted profile data
        export NSYS_JAX_DEFAULT_PREFIX="${PWD}/profile_data"
        # Run with ipython for the sake of getting a clear error message
        ipython "${NOTEBOOK}"
        # Expose the notebook path to the later steps
        echo "NOTEBOOK=${NOTEBOOK}" >> "${GITHUB_OUTPUT}"
    - name: Render the notebook
      id: render
      run: |
        workdir=$(mktemp -d)
        export NSYS_JAX_DEFAULT_PREFIX="${PWD}/profile_data"
        # Re-execute in place so the notebook file itself contains the outputs
        jupyter nbconvert --execute --inplace '${{ steps.exec.outputs.NOTEBOOK }}'
        # Collect the rendered notebook and the .svg files from the current directory
        cp '${{ steps.exec.outputs.NOTEBOOK }}' *.svg "${workdir}"
        echo "WORKDIR=${workdir}" >> "${GITHUB_OUTPUT}"
    - name: Upload rendered notebook to Gist
      id: upload
      uses: actions/github-script@v7
      with:
        github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
        script: |
          const currentDateTime = new Date().toISOString();
          const gistDescription =
            `Rendered IPython notebook from workflow: ${{ github.workflow }}, ` +
            `Run ID: ${{ github.run_id }}, ` +
            `Repository: ${{ github.repository }}, ` +
            `Event: ${{ github.event_name }}, ` +
            `Created: ${currentDateTime}`;
          const fs = require('fs').promises;
          const workdir = '${{ steps.render.outputs.WORKDIR }}';
          const files = await fs.readdir(workdir);
          // Upload every file of the work directory as one new private gist.
          // Declared with const: the bare assignment used before created an
          // implicit global.
          const gist = await github.rest.gists.create({
            description: gistDescription,
            public: false,
            files: Object.fromEntries(
              await Promise.all(
                files.map(
                  async filename => {
                    const content = await fs.readFile(`${workdir}/${filename}`, 'utf8');
                    return [filename, { content }];
                  }
                )
              )
            )
          });
          console.log(gist);
          // The returned ID becomes this step's `result` output
          return gist.data.id;
    - name: Copy rendered notebook to Gist with well-known ID
      uses: actions/github-script@v7
      with:
        github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
        script: |
          // The `result` output is JSON-encoded by github-script by default, so
          // the expression below expands to a quoted JavaScript string literal.
          const srcId = ${{ steps.upload.outputs.result }};
          // Runs on main publish to the real gist; everything else uses the mock
          const dstId = "${{ github.ref == 'refs/heads/main' && env.RENDERED_NOTEBOOK_GIST_ID || env.MOCK_RENDERED_NOTEBOOK_GIST_ID }}";
          const { PUBLISH_NOTEBOOK_FILES } = process.env;
          // Fetch existing files from destination gist
          const { data: dstData } = await github.rest.gists.get({
            gist_id: dstId
          });
          // Mark existing files in destination gist for deletion
          let filesToUpdate = {};
          for (const filename of Object.keys(dstData.files)) {
            filesToUpdate[filename] = null;
          }
          // Fetch files from source gist
          const { data: srcData } = await github.rest.gists.get({
            gist_id: srcId
          });
          // Add or update files based on the pattern
          const pattern = new RegExp(`${PUBLISH_NOTEBOOK_FILES}`);
          for (const [filename, fileObj] of Object.entries(srcData.files)) {
            if (filename.match(pattern)) {
              // If the total gist size is too large, not all the content will have
              // been returned and we need some extra requests.
              if (fileObj.truncated) {
                const { data } = await github.request(fileObj.raw_url);
                filesToUpdate[filename] = {
                  content: new TextDecoder().decode(data)
                };
              } else {
                filesToUpdate[filename] = {
                  content: fileObj.content
                };
              }
            }
          }
          // Update files in destination gist
          await github.rest.gists.update({
            gist_id: dstId,
            files: filesToUpdate
          });
          console.log("Files copied successfully.");
          console.log(Object.keys(filesToUpdate));
ruff:
  # Runs the ruff linter and the ruff formatter check over the paths in
  # NSYS_JAX_PYTHON_FILES and fails if either of them reports a problem.
  runs-on: ubuntu-24.04
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4
      with:
        path: JAX-Toolbox
        sparse-checkout: |
          .github/container
    - name: "Setup Python 3.10"
      uses: actions/setup-python@v5
      with:
        python-version: '3.10'
    - name: "Install ruff"
      run: pip install ruff
    - name: "Run ruff checks"
      run: |
        # The workflow's default shell runs with -e, so a bare failing command
        # would abort the script before its status could be recorded. Capture
        # each status with `||` so that both reports are always produced.
        check_status=0
        format_status=0
        ruff check ${NSYS_JAX_PYTHON_FILES} || check_status=$?
        ruff format --check ${NSYS_JAX_PYTHON_FILES} || format_status=$?
        # Fail the step if either the lint or the format check failed
        if [[ $format_status != 0 || $check_status != 0 ]]; then
          exit 1
        fi