Skip to content

Commit

Permalink
Merge branch 'main' of github.com:NVIDIA/JAX-Toolbox into pax-add-eva…
Browse files Browse the repository at this point in the history
…l-test
  • Loading branch information
ashors1 committed Nov 5, 2023
2 parents 25ab5bb + f6aff16 commit 8600d40
Show file tree
Hide file tree
Showing 32 changed files with 701 additions and 112 deletions.
2 changes: 1 addition & 1 deletion .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ ENV BUILD_DATE=${BUILD_DATE}
ENV XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false"
ENV CUDA_DEVICE_MAX_CONNECTIONS=1
ENV NCCL_IB_SL=1
ENV NCCL_NVLS_ENABLE=0
ENV CUDA_MODULE_LOADING=EAGER

COPY --from=jax-builder ${SRC_PATH_JAX}-no-git ${SRC_PATH_JAX}
COPY --from=jax-builder ${SRC_PATH_XLA}-no-git ${SRC_PATH_XLA}
Expand Down
11 changes: 6 additions & 5 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,25 @@ case "${BATTERY}" in
large)
JOBS_PER_GPU=1
JOBS=$((NGPUS * JOBS_PER_GPU))
EXTRA_FLAGS="--jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
BAZEL_TARGET="${BAZEL_TARGET} //tests:image_test_gpu //tests:scipy_stats_test_gpu"
;;
gpu)
JOBS_PER_GPU=8
JOBS=$((NGPUS * JOBS_PER_GPU))
EXTRA_FLAGS="--jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests"
;;
backend-independent)
JOBS=$NCPUS
EXTRA_FLAGS="--jobs=${JOBS} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
JOBS_PER_GPU=4
JOBS=$(($NGPUS * JOBS_PER_GPU))
EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
BAZEL_TARGET="${BAZEL_TARGET} //tests:backend_independent_tests"
;;
"")
JOBS_PER_GPU=4
JOBS=$((NGPUS * JOBS_PER_GPU))
EXTRA_FLAGS="--jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU}"
EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU}"
;;
*)
echo "Unknown battery ${BATTERY}"
Expand Down
31 changes: 24 additions & 7 deletions .github/container/test-t5x.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@ usage() {
echo " OPTIONS DESCRIPTION"
echo " -a, --additional-args Additional gin args to pass to t5x/train.py"
echo " -b, --batch-size Global batch size (REQUIRED)"
echo " -c --use-contrib-configs If provided uses contrib/gpu configs instead of top-level configs. Notably, gpu configs use adamw instead of adafactor"
echo " -c, --use-contrib-configs If provided uses contrib/gpu configs instead of top-level configs. Notably, gpu configs use adamw instead of adafactor"
echo " -d, --dtype Data type, defaults to bfloat16."
echo " --enable-te {0,1} 1 to enable, 0 to disable; defaults to ENABLE_TE in env or 0 if unset"
echo " -e, --epochs Number of epochs to run, defaults to 7."
echo " --multiprocess Enable the multiprocess GPU mode."
echo " -o, --output NAME Name for the output folder, a temporary folder will be created if none specified."
echo " --seed INT Random seed for deterministim. Defaults to 42."
echo " -s, --steps-per-epoch INT Steps per epoch. Detauls to 100"
echo " -h, --help Print usage."
exit $1
}

args=$(getopt -o a:b:cd:e:o:s:h --long additional-args:,batch-size:,use-contrib-configs,dtype:,epochs:,help,multiprocess,output:,steps-per-epoch: -- "$@")
args=$(getopt -o a:b:cd:e:ho:s: --long additional-args:,batch-size:,use-contrib-configs,dtype:,enable-te:,epochs:,help,multiprocess,output:,seed:,steps-per-epoch: -- "$@")
if [[ $? -ne 0 ]]; then
exit 1
fi
Expand All @@ -37,7 +40,9 @@ DTYPE=bfloat16
EPOCHS=7
MULTIPROCESS=0
OUTPUT=$(mktemp -d)
SEED=42
STEPS_PER_EPOCH=100
ENABLE_TE=${ENABLE_TE:-0}

eval set -- "$args"
while [ : ]; do
Expand All @@ -58,10 +63,17 @@ while [ : ]; do
DTYPE="$2"
shift 2
;;
--enable-te)
ENABLE_TE="$2"
shift 2
;;
-e | --epochs)
EPOCHS="$2"
shift 2
;;
-h | --help)
usage 1
;;
--multiprocess)
MULTIPROCESS=1
shift 1
Expand All @@ -70,13 +82,14 @@ while [ : ]; do
OUTPUT="$2"
shift 2
;;
--seed)
SEED="$2"
shift 2
;;
-s | --steps-per-epoch)
STEPS_PER_EPOCH="$2"
shift 2
;;
-h | --help)
usage 1
;;
--)
shift;
break
Expand All @@ -100,6 +113,7 @@ print_var ADDITIONAL_ARGS
print_var BATCH_SIZE
print_var USE_CONTRIB_CONFIGS
print_var DTYPE
print_var ENABLE_TE
print_var EPOCHS
print_var OUTPUT
print_var MULTIPROCESS
Expand Down Expand Up @@ -176,7 +190,8 @@ EOF

## Launch
set -exou pipefail
python -m t5x.train \

ENABLE_TE=$ENABLE_TE python -m t5x.train \
--gin_file benchmark.gin \
--gin.MODEL_DIR=\"${OUTPUT}\" \
--gin.network.T5Config.dtype=\"${DTYPE}\" \
Expand All @@ -185,7 +200,9 @@ python -m t5x.train \
--gin.train.eval_steps=0 \
--gin.train.eval_period=${STEPS_PER_EPOCH} \
--gin.CheckpointConfig.save=None \
--gin.train/utils.DatasetConfig.seed=${SEED} \
--gin.train_eval/utils.DatasetConfig.seed=${SEED} \
--gin.train.random_seed=${SEED} \
$ADDITIONAL_ARGS \
$([[ $MULTIPROCESS != 0 ]] && echo --multiprocess_gpu)
set +x
echo "Output at ${OUTPUT}"
4 changes: 0 additions & 4 deletions .github/workflows/_publish_container.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,6 @@ jobs:
docker buildx imagetools create --tag $tag ${{ steps.get-manifests.outputs.manifests }}
done
- name: Skopeo Login to GitHub Container Registry
run: |
echo ${{ secrets.GITHUB_TOKEN }} | skopeo login --authfile - ghcr.io
- name: Create single-arch images
if: ${{ inputs.EXPOSE_SINGLE_ARCH_IMAGES }}
shell: bash -x -e {0}
Expand Down
98 changes: 98 additions & 0 deletions .github/workflows/_retrofit_container.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
name: ~split multi-arch OCI manifests into Docker Image Manifest V2, Schema 2

on:
workflow_call:
inputs:
SOURCE_IMAGE:
type: string
description: 'Source docker image:'
required: true
TARGET_TAGS:
type: string
description: 'Target docker tags in docker/metadata-action format:'
required: true
EXPOSE_SINGLE_ARCH_IMAGES:
type: boolean
description: 'Also expose single-arch images:'
required: false
default: true
outputs:
# MULTIARCH_TAG:
# description: "Tags of the multi-arch image published"
# value: ${{ jobs.publish.outputs.MULTIARCH_TAG }}
SINGLEARCH_TAGS:
description: "Tags of the single-arch images published"
value: ${{ jobs.publish.outputs.SINGLEARCH_TAGS }}

env:
DOCKER_REPOSITORY: 'ghcr.io/nvidia/jax-toolbox-retrofit'

jobs:
publish:
runs-on: ubuntu-22.04
outputs:
# MULTIARCH_TAG: ${{ steps.meta.outputs.tags }}
SINGLEARCH_TAGS: ${{ steps.single-arch.outputs.tags }}
steps:
- name: Login to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Set docker metadata
id: meta
uses: docker/metadata-action@v4
with:
images: ${{ env.DOCKER_REPOSITORY }}
flavor: latest=false
tags: ${{ inputs.TARGET_TAGS }}

- name: Extract manifests from the source manifest list
id: get-manifests
shell: bash -x -e {0}
run: |
SOURCE_REPO=$(echo ${{ inputs.SOURCE_IMAGE }} | cut -d: -f1)
MEDIA_TYPE=$(docker manifest inspect ${{ inputs.SOURCE_IMAGE }} | jq -r '.mediaType')
if [[ ${MEDIA_TYPE} != "application/vnd.oci.image.index.v1+json" ]]; then
echo "This workflow only work with OCI manifest lists"
exit 1
fi
MANIFESTS=$(
docker manifest inspect ${{ inputs.SOURCE_IMAGE }} |\
jq -r '.manifests[] | select(.platform.os != "unknown") | .digest' |\
xargs -I{} echo ${SOURCE_REPO}@{} |\
tr '\n' ' '
)
echo "manifests=$MANIFESTS" >> $GITHUB_OUTPUT
## Requires skopeo >= v1.6.0, but Actions only has v1.4.0
# - name: Create Docker v2s2 multi-arch manifest list
# id: multi-arch
# shell: bash -x -e {0}
# run: |
# for tag in $(echo "${{ steps.meta.outputs.tags }}"); do
# skopeo copy --multi-arch all --format v2s2 docker://${{ inputs.SOURCE_IMAGE }} docker://$tag
# done

- name: Create Docker v2s2 single-arch manifests
id: single-arch
if: ${{ inputs.EXPOSE_SINGLE_ARCH_IMAGES }}
shell: bash -x -e {0}
run: |
output_tags=""
# Create new manifest list from extracted manifests
for manifest in ${{ steps.get-manifests.outputs.manifests }}; do
os=$(docker manifest inspect -v $manifest | jq -r '.Descriptor.platform.os')
arch=$(docker manifest inspect -v $manifest | jq -r '.Descriptor.platform.architecture')
for tag in $(echo "${{ steps.meta.outputs.tags }}"); do
single_arch_tag="${tag}-${os}-${arch}"
skopeo copy --format v2s2 docker://$manifest docker://${single_arch_tag}
output_tags="${output_tags} ${single_arch_tag}"
done
done
echo "tags=${output_tags}" >> $GITHUB_OUTPUT
Loading

0 comments on commit 8600d40

Please sign in to comment.