-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
nsys-jax: re-work to be more pip-install-able (#1165)
The overarching goal of this PR is to get closer to a world where the `nsys-jax` tooling is straightforwardly `pip install`-able. While the diff looks scary, it's mostly re-organisation. Substantive changes: - `nsys-jax` no longer bundles Python code in the output archives, the `install.sh` script provided for users to run on local machines becomes, loosely, `install 'pip nsys-jax[jupyter] @ git+https://github.com/NVIDIA/JAX-Toolbox.git@COMMIT#subdirectory=.github/container/nsys_jax'`, where `COMMIT` corresponds to the `nsys-jax` command that produced the archive. For the `ghcr.io/nvidia/jax` containers, this is the commit of JAX-Toolbox that triggered the container build. Changes included: - Introduce `/opt/pip-tools-post-install.d`, which `pip-finalize.sh` will execute the contents of *after* installing the `pip`-managed world - Migrate `install-protoc` to use this, so `pip-finalize.sh` can forget about that detail. - Install https://github.com/brendangregg/FlameGraph/blob/master/flamegraph.pl via this. - Patch the `nvtx_gpu_proj_trace` Python code in Nsight Systems 2024.5 and 2024.6 via this. - Move `nsys-jax` installation (specifically for the containers) into `install-nsys-jax.sh` and thereby clean up `install-nsight.sh`. The new script has to be told the git commit hash of JAX-Toolbox that is being built, because `nsys-jax` bakes this into an installation script in its output `.zip` archives to ensure the local environment matches the profile-collection environment. - The CLI tools like `nsys-jax`, `nsys-jax-combine` and `install-protoc` are now handled via `[project.scripts]` in `pyproject.toml` instead of being standalone Python scripts. This is "more standard", and also makes it easier to share code between `nsys-jax` and `nsys-jax-combine`. - The Python library is renamed from `jax_nsys` to `nsys_jax` for consistency. - It's now possible to set the default data loading path via the `NSYS_JAX_DEFAULT_PREFIX` environment variable; previously the default was the current working directory, but that can be inconvenient to steer in Jupyter environments.
- Loading branch information
Showing
36 changed files
with
1,497 additions
and
1,283 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ ARG BASE_IMAGE=nvidia/cuda:12.6.2-devel-ubuntu22.04 | |
ARG GIT_USER_NAME="JAX Toolbox" | ||
ARG [email protected] | ||
ARG CLANG_VERSION=18 | ||
ARG JAX_TOOLBOX_REF | ||
|
||
############################################################################### | ||
## Obtain GCP's NCCL TCPx plugin | ||
|
@@ -30,6 +31,7 @@ ARG BASE_IMAGE | |
ARG GIT_USER_EMAIL | ||
ARG GIT_USER_NAME | ||
ARG CLANG_VERSION | ||
ARG JAX_TOOLBOX_REF | ||
ENV CUDA_BASE_IMAGE=${BASE_IMAGE} | ||
|
||
############################################################################### | ||
|
@@ -110,7 +112,7 @@ RUN <<"EOF" bash -ex | |
git config --global user.name "${GIT_USER_NAME}" | ||
git config --global user.email "${GIT_USER_EMAIL}" | ||
EOF | ||
RUN mkdir -p /opt/pip-tools.d | ||
RUN mkdir -p /opt/pip-tools.d /opt/pip-tools-post-install.d | ||
ADD --chmod=777 \ | ||
git-clone.sh \ | ||
pip-finalize.sh \ | ||
|
@@ -141,7 +143,6 @@ COPY --from=tcpx-installer /var/lib/tcpx/lib64 ${TCPX_LIBRARY_PATH} | |
############################################################################### | ||
|
||
ADD install-nsight.sh /usr/local/bin | ||
ADD nsys-2024.5-tid-export.patch /opt/nvidia | ||
RUN install-nsight.sh | ||
|
||
############################################################################### | ||
|
@@ -183,7 +184,7 @@ ENV PATH=/opt/amazon/efa/bin:${PATH} | |
ADD install-nccl-sanity-check.sh /usr/local/bin | ||
ADD nccl-sanity-check.cu /opt | ||
RUN install-nccl-sanity-check.sh | ||
ADD jax-nccl-test parallel-launch /usr/local/bin | ||
ADD jax-nccl-test parallel-launch /usr/local/bin/ | ||
|
||
############################################################################### | ||
## Add the systemcheck to the entrypoint. | ||
|
@@ -199,23 +200,11 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/ | |
# COPY gcp-autoconfig.sh /opt/nvidia/entrypoint.d/ | ||
|
||
############################################################################### | ||
## Add helper scripts for profiling with Nsight Systems | ||
## | ||
## The scripts saved to /opt/jax_nsys are embedded in the output archives | ||
## written by nsys-jax, while the nsys-jax and nsys-jax-combine scripts are | ||
## only used inside the containers. | ||
############################################################################### | ||
ADD nsys-jax nsys-jax-combine /usr/local/bin/ | ||
ADD jax_nsys/ /opt/jax_nsys | ||
# The jax_nsys package should be installed inside the containers, so nsys-jax | ||
# can eagerly execute analysis recipes (--nsys-jax-analysis) in the container | ||
# environment, without an extra layer of virtual environment indirection. | ||
RUN echo "-e /opt/jax_nsys/python/jax_nsys" > /opt/pip-tools.d/requirements-nsys-jax.in | ||
# This should be embedded in output archives and be runnable inside containers | ||
RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/ | ||
# Should be available for execution inside the containers, should not be | ||
# embedded in the output archives. | ||
ADD jax_nsys_tests/ /opt/jax_nsys_tests | ||
## Install the nsys-jax JAX/XLA-aware profiling scripts, patch Nsight Systems | ||
############################################################################### | ||
|
||
ADD install-nsys-jax.sh /usr/local/bin | ||
RUN install-nsys-jax.sh ${JAX_TOOLBOX_REF} | ||
|
||
############################################################################### | ||
## Copy manifest file to the container | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#!/bin/bash | ||
set -exo pipefail | ||
|
||
REF="$1" | ||
if [[ -z "${REF}" ]]; then | ||
echo "$0: <git ref of JAX-Toolbox>" | ||
exit 1 | ||
fi | ||
|
||
# Install extra dependencies needed for `nsys recipe ...` commands. These are | ||
# used by the nsys-jax wrapper script. | ||
NSYS_DIR=$(dirname $(realpath $(command -v nsys))) | ||
ln -s ${NSYS_DIR}/python/packages/nsys_recipe/requirements/common.txt /opt/pip-tools.d/requirements-nsys-recipe.in | ||
|
||
# Install the nsys-jax package, which includes nsys-jax, nsys-jax-combine, | ||
# install-protoc (called from pip-finalize.sh), and nsys-jax-patch-nsys as well as the | ||
# nsys_jax Python library. | ||
URL="git+https://github.com/NVIDIA/JAX-Toolbox.git@${REF}#subdirectory=.github/container/nsys_jax&egg=nsys-jax" | ||
echo "-e '${URL}'" > /opt/pip-tools.d/requirements-nsys-jax.in | ||
|
||
# protobuf will be installed at least as a dependency of nsys_jax in the base | ||
# image, but the installed version is likely to be influenced by other packages. | ||
echo "install-protoc /usr/local" > /opt/pip-tools-post-install.d/protoc | ||
chmod 755 /opt/pip-tools-post-install.d/protoc | ||
|
||
# Make sure flamegraph.pl is available | ||
echo "install-flamegraph /usr/local" > /opt/pip-tools-post-install.d/flamegraph | ||
chmod 755 /opt/pip-tools-post-install.d/flamegraph | ||
|
||
# Make sure Nsight Systems Python patches are installed if needed | ||
echo "nsys-jax-patch-nsys" > /opt/pip-tools-post-install.d/patch-nsys | ||
chmod 755 /opt/pip-tools-post-install.d/patch-nsys |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.