# Workflow file for CI run #1001:
# "Add Pax singleprocess train and eval tests"

name: CI

# Triggers: pull requests (except those that only touch Markdown files) and
# manual dispatch, where the base image and every source location can be
# overridden.
on:
  pull_request:
    paths-ignore:
      - '**.md'
  workflow_dispatch:
    inputs:
      CUDA_IMAGE:
        description: 'Base CUDA image, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04'
        type: string
        required: false
        default: 'latest'
      # Every SRC_* input is a single '<repo>#<ref>' string; the `metadata`
      # job splits it into separate repo and ref outputs.
      SRC_JAX:
        description: 'JAX source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google/jax.git#main'
      SRC_XLA:
        description: 'XLA source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/openxla/xla.git#main'
      SRC_TE:
        description: 'TE source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/NVIDIA/TransformerEngine.git#main'
      SRC_T5X:
        description: 'T5X source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google-research/t5x.git#main'
      SRC_PAXML:
        description: 'Paxml source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google/paxml.git#main'
      SRC_PRAXIS:
        description: 'Praxis source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google/praxis.git#main'
# Runs are grouped per workflow and per pull-request head branch; when
# `github.head_ref` is empty the run id is used instead, which puts that run
# in a group of its own.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  # An in-progress run of the same group is cancelled unless the ref is main.
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
# Token scopes requested by this workflow.
permissions:
  # to fetch code
  contents: read
  # to cancel previous workflows
  actions: write
  # to upload container
  packages: write
jobs:
  # Produces the values shared by the downstream jobs: the build date and the
  # repo/ref halves of every '<repo>#<ref>' source specification.
  metadata:
    runs-on: ubuntu-22.04
    outputs:
      BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }}
      REPO_JAX: ${{ steps.parse-inputs.outputs.REPO_JAX }}
      REF_JAX: ${{ steps.parse-inputs.outputs.REF_JAX }}
      REPO_XLA: ${{ steps.parse-inputs.outputs.REPO_XLA }}
      REF_XLA: ${{ steps.parse-inputs.outputs.REF_XLA }}
      REPO_TE: ${{ steps.parse-inputs.outputs.REPO_TE }}
      REF_TE: ${{ steps.parse-inputs.outputs.REF_TE }}
      REPO_T5X: ${{ steps.parse-inputs.outputs.REPO_T5X }}
      REF_T5X: ${{ steps.parse-inputs.outputs.REF_T5X }}
      REPO_PAXML: ${{ steps.parse-inputs.outputs.REPO_PAXML }}
      REF_PAXML: ${{ steps.parse-inputs.outputs.REF_PAXML }}
      REPO_PRAXIS: ${{ steps.parse-inputs.outputs.REPO_PRAXIS }}
      REF_PRAXIS: ${{ steps.parse-inputs.outputs.REF_PRAXIS }}
    steps:
      - name: Set build date
        id: date
        shell: bash -x -e {0}
        run: |
          BUILD_DATE=$(TZ='US/Los_Angeles' date '+%Y-%m-%d')
          echo "BUILD_DATE=${BUILD_DATE}" >> "$GITHUB_OUTPUT"
      - name: Parse inputs
        id: parse-inputs
        shell: bash -x -e {0}
        # The dispatch inputs are handed to the script as environment variables
        # instead of being expanded inline, so their content is only ever
        # treated as data by the shell and never as script text.
        env:
          SRC_JAX: ${{ inputs.SRC_JAX }}
          SRC_XLA: ${{ inputs.SRC_XLA }}
          SRC_TE: ${{ inputs.SRC_TE }}
          SRC_T5X: ${{ inputs.SRC_T5X }}
          SRC_PAXML: ${{ inputs.SRC_PAXML }}
          SRC_PRAXIS: ${{ inputs.SRC_PRAXIS }}
        run: |
          # split input in the format of repo#ref into repo and ref parts
          #   $1: package name used as the output suffix (REPO_<pkg>, REF_<pkg>)
          #   $2: user-supplied source, may be empty
          #   $3: fallback used when $2 is empty
          parse_git_src() {
            local PACKAGE="$1"
            local INPUT="$2"
            local DEFAULT="$3"
            local SRC="${INPUT:-${DEFAULT}}"
            # Without a '#', `cut -f2` would echo the whole string back and the
            # repository URL would silently be used as the ref.
            if [[ "${SRC}" != *"#"* ]]; then
              echo "::error::${PACKAGE} source '${SRC}' is not in the form <repo>#<ref>"
              exit 1
            fi
            echo "REPO_${PACKAGE}=$(echo "${SRC}" | cut -f1 -d#)" >> "$GITHUB_OUTPUT"
            echo "REF_${PACKAGE}=$(echo "${SRC}" | cut -f2 -d#)" >> "$GITHUB_OUTPUT"
          }
          # default values are for `pull_request` event types
          parse_git_src JAX "${SRC_JAX}" "https://github.com/google/jax.git#main"
          parse_git_src XLA "${SRC_XLA}" "https://github.com/openxla/xla.git#main"
          parse_git_src TE "${SRC_TE}" "https://github.com/NVIDIA/TransformerEngine.git#main"
          parse_git_src T5X "${SRC_T5X}" "https://github.com/google-research/t5x.git#main"
          parse_git_src PAXML "${SRC_PAXML}" "https://github.com/google/paxml.git#main"
          parse_git_src PRAXIS "${SRC_PRAXIS}" "https://github.com/google/praxis.git#main"

  # Image build chain: base -> jax -> {te, t5x, pax} -> rosetta-{t5x, pax}.
  # Each stage receives the DOCKER_TAGS output of its parent as BASE_IMAGE.
  build-base:
    needs: metadata
    uses: ./.github/workflows/_build_base.yaml
    with:
      # Falls back to 'latest' when the input is empty (e.g. pull_request runs).
      BASE_IMAGE: ${{ inputs.CUDA_IMAGE || 'latest' }}
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
    secrets: inherit

  build-jax:
    needs: [metadata, build-base]
    uses: ./.github/workflows/_build_jax.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-base.outputs.DOCKER_TAGS }}
      REPO_JAX: ${{ needs.metadata.outputs.REPO_JAX }}
      REF_JAX: ${{ needs.metadata.outputs.REF_JAX }}
      REPO_XLA: ${{ needs.metadata.outputs.REPO_XLA }}
      REF_XLA: ${{ needs.metadata.outputs.REF_XLA }}
    secrets: inherit

  build-te:
    needs: [metadata, build-jax]
    uses: ./.github/workflows/_build_te.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
      REPO_TE: ${{ needs.metadata.outputs.REPO_TE }}
      REF_TE: ${{ needs.metadata.outputs.REF_TE }}
    secrets: inherit

  build-t5x:
    needs: [metadata, build-jax]
    uses: ./.github/workflows/_build_t5x.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
      REPO_T5X: ${{ needs.metadata.outputs.REPO_T5X }}
      REF_T5X: ${{ needs.metadata.outputs.REF_T5X }}
    secrets: inherit

  build-pax:
    needs: [metadata, build-jax]
    uses: ./.github/workflows/_build_pax.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
      REPO_PAXML: ${{ needs.metadata.outputs.REPO_PAXML }}
      REF_PAXML: ${{ needs.metadata.outputs.REF_PAXML }}
      REPO_PRAXIS: ${{ needs.metadata.outputs.REPO_PRAXIS }}
      REF_PRAXIS: ${{ needs.metadata.outputs.REF_PRAXIS }}
    secrets: inherit

  build-rosetta-t5x:
    needs: [metadata, build-t5x]
    uses: ./.github/workflows/_build_rosetta.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-t5x.outputs.DOCKER_TAGS }}
      BASE_LIBRARY: t5x
      PLATFORMS: '["amd64"]'
    secrets: inherit

  build-rosetta-pax:
    needs: [metadata, build-pax]
    uses: ./.github/workflows/_build_rosetta.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-pax.outputs.DOCKER_TAGS }}
      BASE_LIBRARY: pax
    secrets: inherit

  # Writes a table of every image tag to the run summary. `always()` keeps the
  # job running when an upstream build failed; the affected cells are then
  # simply empty.
  build-summary:
    needs: [build-base, build-jax, build-te, build-t5x, build-pax, build-rosetta-t5x, build-rosetta-pax]
    # needs: [build-base, build-jax, build-te, build-t5x, build-pax, build-pax-aarch64, build-rosetta-t5x, build-rosetta-pax]
    if: always()
    runs-on: ubuntu-22.04
    steps:
      - name: Generate job summary for container build
        shell: bash -x -e {0}
        run: |
          cat > "$GITHUB_STEP_SUMMARY" << EOF
          # Images created
          | Image | Link |
          | ------------ | -------------------------------------------------- |
          | Base | ${{ needs.build-base.outputs.DOCKER_TAGS }} |
          | JAX | ${{ needs.build-jax.outputs.DOCKER_TAGS }} |
          | JAX-TE | ${{ needs.build-te.outputs.DOCKER_TAGS }} |
          | T5X | ${{ needs.build-t5x.outputs.DOCKER_TAGS }} |
          | PAX | ${{ needs.build-pax.outputs.DOCKER_TAGS }} |
          | ROSETTA(t5x) | ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAGS }} |
          | ROSETTA(pax) | ${{ needs.build-rosetta-pax.outputs.DOCKER_TAGS }} |
          EOF

  # Re-publishes every built image under DOCKER_REPO in Docker v2s2 format.
  # An image that is an OCI image index is expanded into one
  # '<tag>-<os>-<arch>' copy per listed manifest; anything else is copied as-is
  # under its original tag.
  retrofit-containers:
    needs: [build-base, build-jax, build-te, build-t5x, build-pax, build-rosetta-t5x, build-rosetta-pax]
    if: always()
    runs-on: ubuntu-22.04
    env:
      DOCKER_REPO: 'ghcr.io/nvidia/jax-toolbox-retrofit'
    steps:
      - name: Login to GitHub Container Registry
        uses: docker/login-action@v2
        with:
          registry: ghcr.io
          username: ${{ github.repository_owner }}
          password: ${{ secrets.GITHUB_TOKEN }}
      ## Requires skopeo >= v1.6.0, but Actions only has v1.4.0
      # - name: Create Docker v2s2 multi-arch manifest list
      #   id: multi-arch
      #   shell: bash -x -e {0}
      #   run: |
      #     for tag in $(echo "${{ steps.meta.outputs.tags }}"); do
      #       skopeo copy --multi-arch all --format v2s2 docker://${{ inputs.SOURCE_IMAGE }} docker://$tag
      #     done
      - name: Create Docker v2s2 single-arch manifests
        id: single-arch
        shell: bash -x -e {0}
        run: |
          for source in \
            ${{ needs.build-base.outputs.DOCKER_TAGS }} \
            ${{ needs.build-jax.outputs.DOCKER_TAGS }} \
            ${{ needs.build-te.outputs.DOCKER_TAGS }} \
            ${{ needs.build-t5x.outputs.DOCKER_TAGS }} \
            ${{ needs.build-pax.outputs.DOCKER_TAGS }} \
            ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAGS }} \
            ${{ needs.build-rosetta-pax.outputs.DOCKER_TAGS }} \
          ; do
            source_repo=$(echo ${source} | cut -d: -f1)
            media_type=$(docker manifest inspect ${source} | jq -r '.mediaType')
            if [[ ${media_type} != "application/vnd.oci.image.index.v1+json" ]]; then
              # Not an OCI index: copy the image under its original tag.
              echo "Image ${source} is already in Docker format v2s2"
              dest=${DOCKER_REPO}:$(echo ${source} | cut -d: -f2)
              skopeo copy --format v2s2 docker://${source} docker://${dest}
              echo "${dest}" >> $GITHUB_STEP_SUMMARY
            else
              # OCI index: collect '<repo>@<digest>' for every entry whose
              # platform os is not "unknown".
              manifests=$(
                docker manifest inspect ${source} |\
                jq -r '.manifests[] | select(.platform.os != "unknown") | .digest' |\
                xargs -I{} echo ${source_repo}@{} |\
                tr '\n' ' '
              )
              ## registry/org/repo:tag -> repo-tag
              # dest_tag=$(echo ${source} | cut -d: -f1 | cut -d/ -f3)-$(echo ${source} | cut -d: -f2)
              ## registry/org/repo:tag -> tag
              dest_tag=$(echo ${source} | cut -d: -f2)
              for manifest in ${manifests}; do
                os=$(docker manifest inspect -v $manifest | jq -r '.Descriptor.platform.os')
                arch=$(docker manifest inspect -v $manifest | jq -r '.Descriptor.platform.architecture')
                # single_arch_tag="ghcr.io/nvidia/jax-toolbox-retrofit:${{ github.run_id }}-${dest_tag}-${os}-${arch}"
                single_arch_tag="${DOCKER_REPO}:${dest_tag}-${os}-${arch}"
                skopeo copy --format v2s2 docker://$manifest docker://${single_arch_tag}
                echo "${single_arch_tag}" >> $GITHUB_STEP_SUMMARY
              done
            fi
          done

  # Test jobs: each one receives the image produced by the matching build job.
  test-distribution:
    needs: metadata
    uses: ./.github/workflows/_test_distribution.yaml
    secrets: inherit

  test-jax:
    needs: build-jax
    uses: ./.github/workflows/_test_jax.yaml
    with:
      JAX_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
    secrets: inherit

  test-te:
    needs: build-te
    uses: ./.github/workflows/_test_te.yaml
    with:
      JAX_TE_IMAGE: ${{ needs.build-te.outputs.DOCKER_TAGS }}
    secrets: inherit

  test-t5x:
    needs: build-t5x
    uses: ./.github/workflows/_test_t5x.yaml
    with:
      T5X_IMAGE: ${{ needs.build-t5x.outputs.DOCKER_TAGS }}
    secrets: inherit

  test-pax:
    needs: build-pax
    uses: ./.github/workflows/_test_pax.yaml
    with:
      PAX_IMAGE: ${{ needs.build-pax.outputs.DOCKER_TAGS }}
    secrets: inherit

  test-vit:
    needs: build-rosetta-t5x
    uses: ./.github/workflows/_test_vit.yaml
    with:
      ROSETTA_T5X_IMAGE: ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAGS }}
    secrets: inherit

  # Runs after the summary and every test job, whatever their outcome.
  finalize:
    if: always()
    # TODO: use dynamic matrix to make dependencies self-updating
    # test-vit is listed so that finalize also waits for the ViT test job.
    needs: [build-summary, test-distribution, test-jax, test-te, test-t5x, test-pax, test-vit]
    uses: ./.github/workflows/_finalize.yaml
    with:
      PUBLISH_BADGE: false
    secrets: inherit