Skip to content

fix t5x tests

fix t5x tests #524

Workflow file for this run

# Workflow name as displayed in the Actions UI.
name: nsys-jax pure-Python CI

# Runs are grouped by workflow name plus the PR head branch; when head_ref is
# empty the run id is used instead, which gives that run a group of its own.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  # An in-progress run in the same group is cancelled unless the ref is main.
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

on:
  pull_request:
    types: [opened, reopened, ready_for_review, synchronize]
    # Changes that only touch Markdown files do not trigger the workflow.
    paths-ignore: ["**.md"]
  push:
    branches: [main]

defaults:
  run:
    # -x traces every command; -e and pipefail stop a step at the first failure.
    shell: bash -x -eo pipefail {0}

env:
  # One path per line. Steps expand this variable unquoted so that every path
  # is passed as a separate argument.
  NSYS_JAX_PYTHON_FILES: |
    JAX-Toolbox/.github/container/nsys_jax
    JAX-Toolbox/.github/container/jax-nccl-test
jobs:
# Type-checks the paths listed in NSYS_JAX_PYTHON_FILES with mypy. Notebooks
# found under those paths are first converted to scripts, and stubs compiled
# from the XLA .proto files are put on MYPYPATH.
mypy:
  runs-on: ubuntu-24.04
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4
      with:
        path: JAX-Toolbox
        sparse-checkout: |
          .github/container
    - name: "Setup Python 3.10"
      uses: actions/setup-python@v5
      with:
        # Quoted so that YAML does not read the version as the float 3.1
        python-version: '3.10'
    # jax is just a CPU-only build of the latest release for type-checking purposes
    - name: "Install jax / nsys_jax / mypy"
      run: pip install jax -e JAX-Toolbox/.github/container/nsys_jax matplotlib mypy nbconvert types-protobuf types-requests
    - name: "Install protoc"
      # TODO: this could install into the pip prefix as a default
      run: |
        install-protoc local/
        # Make the tools installed under local/bin visible to later steps
        echo "$PWD/local/bin" >> "${GITHUB_PATH}"
    - name: "Fetch XLA .proto files"
      uses: actions/checkout@v4
      with:
        path: xla
        repository: openxla/xla
        # Non-cone mode so that the pattern below is treated as a file glob
        sparse-checkout: |
          *.proto
        sparse-checkout-cone-mode: false
    - name: "Compile .proto files"
      run: |
        mkdir compiled_protos compiled_stubs protos
        mv -v xla/third_party/tsl/tsl protos/
        mv -v xla/xla protos/
        python -c "from nsys_jax import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')"
        # Marker file so that mypy accepts the generated stubs as typed
        touch compiled_stubs/py.typed
    - name: "Convert .ipynb to .py"
      run: |
        # Let find pass the notebooks to nbconvert directly instead of looping
        # over $(find ...): file names are never word-split, and a failure of
        # either find or nbconvert fails the step (the exit status of a command
        # substitution in a for-loop word list is silently discarded).
        # NSYS_JAX_PYTHON_FILES is intentionally unquoted: one path per word.
        find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb' -exec jupyter nbconvert --to script {} +
    - name: "Run mypy checks"
      run: |
        # Expose the stubs generated in the "Compile .proto files" step
        export MYPYPATH="${PWD}/compiled_stubs"
        mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES}
# Exercises nsys-jax-combine and then executes the analysis notebook on the
# combined output; in future the rendered notebook could perhaps be uploaded
# from here as well. The input files were generated with something like
#   srun -n 4 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
#     env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
#     --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-4proc-proc%q{SLURM_PROCID}
#     -- test-pax.sh --steps=5 --fsdp=4 --multiprocess
# with newer nsys-jax components bind-mounted in.
combine:
  runs-on: ubuntu-24.04
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4

    - name: Setup Python 3.12
      uses: actions/setup-python@v5
      with:
        python-version: "3.12"

    - name: Install nsys-jax and dependencies
      run: |
        # Editable install for better coverage; this installs nsys-jax-combine
        pip install -e .github/container/nsys_jax[jupyter]
        # TODO: this could install into the pip prefix as a default
        install-flamegraph local/
        install-protoc local/
        # Later steps find the tools installed above via PATH
        echo "$PWD/local/bin" >> "${GITHUB_PATH}"

    - name: Use nsys-jax-combine to merge profiles from multiple nsys processes
      run: |
        # Merge the per-process archives and run two analyses on the result
        nsys-jax-combine --analysis summary --analysis communication \
          -o pax_fsdp4_4proc.zip \
          .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip

    - name: Extract the output .zip file
      run: |
        mkdir combined/
        unzip -d combined/ pax_fsdp4_4proc.zip

    - name: Execute the notebook
      run: |
        # Locate the notebook inside the installed nsys_jax package
        NOTEBOOK=$(python -c 'from importlib.resources import files; print(files("nsys_jax") / "analyses" / "Analysis.ipynb")')
        # The notebook reads its input from the extracted nsys-jax-combine output
        export NSYS_JAX_DEFAULT_PREFIX="${PWD}/combined"
        # ipython is used so that a failure produces a clear error message
        ipython "${NOTEBOOK}"
# Executes and renders the analysis notebook against the data of a single
# profile archive, uploads the rendered files to a fresh Gist and then copies
# the selected files into a Gist with a well-known ID.
# This input file was generated with something like
#   srun -n 1 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
#     env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
#     --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-1proc -- test-pax.sh
#     --steps=5 --fsdp=4
# NOTE(review): the Gist steps need secrets.NVJAX_GIST_TOKEN; GitHub does not
# expose secrets to pull_request runs from forks -- confirm that is acceptable.
notebook:
  env:
    # TODO: these could/should be saved in the repository settings instead
    # (quoted so that the hex IDs can never be mistaken for numbers)
    RENDERED_NOTEBOOK_GIST_ID: 'e2cd3520201caab6b67385ed36fad3c1'
    MOCK_RENDERED_NOTEBOOK_GIST_ID: '16698d9e9e52320243165d61b5bb3975'
    # JavaScript regular expression selecting the files of the uploaded Gist
    # (rendered notebook and SVG images) that are copied to the well-known Gist
    PUBLISH_NOTEBOOK_FILES: '(.*\.ipynb|.*\.svg)'
  runs-on: ubuntu-24.04
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4
    - name: Extract the post-processed profile data from a real .zip file (no .nsys-rep)
      run: |
        # Get the actual test data from a real archive, minus the .nsys-rep file
        mkdir profile_data/
        unzip -d profile_data/ .github/workflows/nsys-jax/test_data/pax_fsdp4_1proc.zip
    - name: "Setup Python 3.10"
      uses: actions/setup-python@v5
      with:
        # Quoted so that YAML does not read the version as the float 3.1
        python-version: '3.10'
    - name: Install nsys-jax and dependencies
      run: |
        # Do *not* use an editable install (covered above) for better coverage
        pip install .github/container/nsys_jax[jupyter]
        # TODO: this could install into the pip prefix as a default
        install-flamegraph local/
        install-protoc local/
        echo "$PWD/local/bin" >> "${GITHUB_PATH}"
    - name: Execute the notebook
      id: exec
      run: |
        NOTEBOOK=$(python -c 'from importlib.resources import files; print(files("nsys_jax") / "analyses" / "Analysis.ipynb")')
        # Point to the extracted profile data
        export NSYS_JAX_DEFAULT_PREFIX="${PWD}/profile_data"
        # Run with ipython for the sake of getting a clear error message
        ipython "${NOTEBOOK}"
        # Hand the notebook path to the render step
        echo "NOTEBOOK=${NOTEBOOK}" >> "${GITHUB_OUTPUT}"
    - name: Render the notebook
      id: render
      run: |
        workdir=$(mktemp -d)
        export NSYS_JAX_DEFAULT_PREFIX="${PWD}/profile_data"
        # Execute the notebook in place so that it contains its outputs
        jupyter nbconvert --execute --inplace '${{ steps.exec.outputs.NOTEBOOK }}'
        # Collect the rendered notebook and the SVG files from the working
        # directory in one place for the upload step
        cp '${{ steps.exec.outputs.NOTEBOOK }}' *.svg "${workdir}"
        echo "WORKDIR=${workdir}" >> "${GITHUB_OUTPUT}"
    - name: Upload rendered notebook to Gist
      id: upload
      uses: actions/github-script@v7
      with:
        github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
        script: |
          const currentDateTime = new Date().toISOString();
          const gistDescription =
            `Rendered IPython notebook from workflow: ${{ github.workflow }}, ` +
            `Run ID: ${{ github.run_id }}, ` +
            `Repository: ${{ github.repository }}, ` +
            `Event: ${{ github.event_name }}, ` +
            `Created: ${currentDateTime}`;
          const fs = require('fs').promises;
          const workdir = '${{ steps.render.outputs.WORKDIR }}'
          const files = await fs.readdir(workdir);
          // Every file in the render directory becomes one file of a new
          // secret Gist. Declared with const instead of leaking a global.
          const gist = await github.rest.gists.create({
            description: gistDescription,
            public: false,
            files: Object.fromEntries(
              await Promise.all(
                files.map(
                  async filename => {
                    const content = await fs.readFile(`${workdir}/${filename}`, 'utf8');
                    return [filename, { content }];
                  }
                )
              )
            )
          });
          console.log(gist)
          // The returned ID is available to later steps as outputs.result
          return gist.data.id;
    - name: Copy rendered notebook to Gist with well-known ID
      uses: actions/github-script@v7
      with:
        github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
        script: |
          // The step result is JSON-encoded by default, so this expands to a
          // quoted string literal.
          const srcId = ${{ steps.upload.outputs.result }};
          // Runs on main publish to the real Gist, everything else to the mock
          const dstId = "${{ github.ref == 'refs/heads/main' && env.RENDERED_NOTEBOOK_GIST_ID || env.MOCK_RENDERED_NOTEBOOK_GIST_ID }}";
          const { PUBLISH_NOTEBOOK_FILES } = process.env;
          // Fetch existing files from destination gist
          const { data: dstData } = await github.rest.gists.get({
            gist_id: dstId
          });
          // Mark existing files in destination gist for deletion
          let filesToUpdate = {};
          for (const filename of Object.keys(dstData.files)) {
            filesToUpdate[filename] = null;
          }
          // Fetch files from source gist
          const { data: srcData } = await github.rest.gists.get({
            gist_id: srcId
          });
          // Add or update files based on the pattern
          const pattern = new RegExp(`${PUBLISH_NOTEBOOK_FILES}`);
          for (const [filename, fileObj] of Object.entries(srcData.files)) {
            if (filename.match(pattern)) {
              // If the total gist size is too large, not all the content will have
              // been returned and we need some extra requests.
              if (fileObj.truncated) {
                const { data } = await github.request(fileObj.raw_url)
                // Depending on the response content type the body may arrive
                // already decoded as a string or as raw bytes; only bytes can
                // (and need to) go through TextDecoder.
                filesToUpdate[filename] = {
                  content: typeof data === 'string' ? data : new TextDecoder().decode(data)
                };
              } else {
                filesToUpdate[filename] = {
                  content: fileObj.content
                };
              }
            }
          }
          // Update files in destination gist
          await github.rest.gists.update({
            gist_id: dstId,
            files: filesToUpdate
          });
          console.log("Files copied successfully.");
          console.log(Object.keys(filesToUpdate));
# Lints (ruff check) and verifies the formatting (ruff format --check) of the
# paths listed in NSYS_JAX_PYTHON_FILES; the job fails if either check fails.
ruff:
  runs-on: ubuntu-24.04
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4
      with:
        path: JAX-Toolbox
        sparse-checkout: |
          .github/container
    - name: "Setup Python 3.10"
      uses: actions/setup-python@v5
      with:
        # Quoted so that YAML does not read the version as the float 3.1
        python-version: '3.10'
    - name: "Install ruff"
      run: pip install ruff
    - name: "Run ruff checks"
      run: |
        # The workflow's default shell runs with -e, so a bare failing command
        # would end the step immediately and the second check would never run.
        # "cmd || status=$?" records the failure without tripping -e, which
        # lets both reports appear in the log before the step fails.
        check_status=0
        format_status=0
        ruff check ${NSYS_JAX_PYTHON_FILES} || check_status=$?
        ruff format --check ${NSYS_JAX_PYTHON_FILES} || format_status=$?
        if [[ $format_status != 0 || $check_status != 0 ]]; then
          exit 1
        fi