Add Pax singleprocess train and eval tests #1001
Workflow file for this run
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Continuous-integration workflow: builds the container image chain and runs
# the per-image test workflows.
name: CI

on:
  # Run on every pull request, except those that only touch Markdown files.
  pull_request:
    paths-ignore:
      - '**.md'
  # Manual runs may override the base image and the source of each component.
  # Every SRC_* input uses the form <repo>#<branch|tag|commit>.
  workflow_dispatch:
    inputs:
      CUDA_IMAGE:
        description: 'Base CUDA image, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04'
        type: string
        required: false
        default: 'latest'
      SRC_JAX:
        description: 'JAX source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google/jax.git#main'
      SRC_XLA:
        description: 'XLA source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/openxla/xla.git#main'
      SRC_TE:
        description: 'TE source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/NVIDIA/TransformerEngine.git#main'
      SRC_T5X:
        description: 'T5X source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google-research/t5x.git#main'
      SRC_PAXML:
        description: 'Paxml source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google/paxml.git#main'
      SRC_PRAXIS:
        description: 'Praxis source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google/praxis.git#main'

# One concurrency group per workflow and PR head branch; events without a head
# ref fall back to the unique run id and therefore never share a group.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  # Cancel a superseded run unless the run is for the main branch.
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

# Token scopes granted to the jobs of this workflow.
permissions:
  contents: read  # to fetch code
  actions: write  # to cancel previous workflows
  packages: write  # to upload container

jobs:

  # Computes values shared by the build jobs: the build date and the
  # repository/ref pair of every component source.
  metadata:
    runs-on: ubuntu-22.04
    outputs:
      BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }}
      REPO_JAX: ${{ steps.parse-inputs.outputs.REPO_JAX }}
      REF_JAX: ${{ steps.parse-inputs.outputs.REF_JAX }}
      REPO_XLA: ${{ steps.parse-inputs.outputs.REPO_XLA }}
      REF_XLA: ${{ steps.parse-inputs.outputs.REF_XLA }}
      REPO_TE: ${{ steps.parse-inputs.outputs.REPO_TE }}
      REF_TE: ${{ steps.parse-inputs.outputs.REF_TE }}
      REPO_T5X: ${{ steps.parse-inputs.outputs.REPO_T5X }}
      REF_T5X: ${{ steps.parse-inputs.outputs.REF_T5X }}
      REPO_PAXML: ${{ steps.parse-inputs.outputs.REPO_PAXML }}
      REF_PAXML: ${{ steps.parse-inputs.outputs.REF_PAXML }}
      REPO_PRAXIS: ${{ steps.parse-inputs.outputs.REPO_PRAXIS }}
      REF_PRAXIS: ${{ steps.parse-inputs.outputs.REF_PRAXIS }}
    steps:
      - name: Set build date
        id: date
        shell: bash -x -e {0}
        run: |
          # The date is evaluated in the US/Los_Angeles time zone.
          BUILD_DATE=$(TZ='US/Los_Angeles' date '+%Y-%m-%d')
          echo "BUILD_DATE=${BUILD_DATE}" >> "$GITHUB_OUTPUT"

      - name: Parse inputs
        id: parse-inputs
        shell: bash -x -e {0}
        # The inputs are handed over as environment variables instead of being
        # expanded into the script text, so their content is never parsed as
        # shell code.
        env:
          SRC_JAX: ${{ inputs.SRC_JAX }}
          SRC_XLA: ${{ inputs.SRC_XLA }}
          SRC_TE: ${{ inputs.SRC_TE }}
          SRC_T5X: ${{ inputs.SRC_T5X }}
          SRC_PAXML: ${{ inputs.SRC_PAXML }}
          SRC_PRAXIS: ${{ inputs.SRC_PRAXIS }}
        run: |
          # split input in the format of repo#ref into repo and ref parts
          #   $1: package suffix used in the output names (REPO_<pkg>, REF_<pkg>)
          #   $2: user-supplied source, possibly empty
          #   $3: fallback source used when $2 is empty
          parse_git_src() {
            local package="$1"
            local input="$2"
            local default="$3"
            local src="${input:-${default}}"
            # Without a '#', `cut -f2` echoes the whole string back and the ref
            # would silently become the repository URL; fail early instead.
            if [[ "${src}" != *'#'* ]]; then
              echo "::error::SRC_${package} must look like <repo>#<ref>, got '${src}'"
              exit 1
            fi
            echo "REPO_${package}=$(echo "${src}" | cut -f1 -d#)" >> "$GITHUB_OUTPUT"
            echo "REF_${package}=$(echo "${src}" | cut -f2 -d#)" >> "$GITHUB_OUTPUT"
          }

          # default values are for `pull_request` event types
          parse_git_src JAX    "${SRC_JAX}"    "https://github.com/google/jax.git#main"
          parse_git_src XLA    "${SRC_XLA}"    "https://github.com/openxla/xla.git#main"
          parse_git_src TE     "${SRC_TE}"     "https://github.com/NVIDIA/TransformerEngine.git#main"
          parse_git_src T5X    "${SRC_T5X}"    "https://github.com/google-research/t5x.git#main"
          parse_git_src PAXML  "${SRC_PAXML}"  "https://github.com/google/paxml.git#main"
          parse_git_src PRAXIS "${SRC_PRAXIS}" "https://github.com/google/praxis.git#main"

  # ---------------------------------------------------------------------------
  # Image builds. Each job calls a reusable workflow and feeds its DOCKER_TAGS
  # output to the next stage as BASE_IMAGE: base -> jax -> {te, t5x, pax},
  # then t5x -> rosetta-t5x and pax -> rosetta-pax.
  # ---------------------------------------------------------------------------

  build-base:
    needs: metadata
    uses: ./.github/workflows/_build_base.yaml
    with:
      # Falls back to 'latest' when no CUDA_IMAGE input is present.
      BASE_IMAGE: ${{ inputs.CUDA_IMAGE || 'latest' }}
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
    secrets: inherit

  build-jax:
    needs: [metadata, build-base]
    uses: ./.github/workflows/_build_jax.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-base.outputs.DOCKER_TAGS }}
      REPO_JAX: ${{ needs.metadata.outputs.REPO_JAX }}
      REF_JAX: ${{ needs.metadata.outputs.REF_JAX }}
      REPO_XLA: ${{ needs.metadata.outputs.REPO_XLA }}
      REF_XLA: ${{ needs.metadata.outputs.REF_XLA }}
    secrets: inherit

  build-te:
    needs: [metadata, build-jax]
    uses: ./.github/workflows/_build_te.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
      REPO_TE: ${{ needs.metadata.outputs.REPO_TE }}
      REF_TE: ${{ needs.metadata.outputs.REF_TE }}
    secrets: inherit

  build-t5x:
    needs: [metadata, build-jax]
    uses: ./.github/workflows/_build_t5x.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
      REPO_T5X: ${{ needs.metadata.outputs.REPO_T5X }}
      REF_T5X: ${{ needs.metadata.outputs.REF_T5X }}
    secrets: inherit

  build-pax:
    needs: [metadata, build-jax]
    uses: ./.github/workflows/_build_pax.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
      REPO_PAXML: ${{ needs.metadata.outputs.REPO_PAXML }}
      REF_PAXML: ${{ needs.metadata.outputs.REF_PAXML }}
      REPO_PRAXIS: ${{ needs.metadata.outputs.REPO_PRAXIS }}
      REF_PRAXIS: ${{ needs.metadata.outputs.REF_PRAXIS }}
    secrets: inherit

  build-rosetta-t5x:
    uses: ./.github/workflows/_build_rosetta.yaml
    needs: [metadata, build-t5x]
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-t5x.outputs.DOCKER_TAGS }}
      BASE_LIBRARY: t5x
      # JSON list passed as a string; only amd64 is requested for this image.
      PLATFORMS: '["amd64"]'
    secrets: inherit

  build-rosetta-pax:
    uses: ./.github/workflows/_build_rosetta.yaml
    needs: [metadata, build-pax]
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-pax.outputs.DOCKER_TAGS }}
      BASE_LIBRARY: pax
    secrets: inherit

  # Writes a Markdown table of the produced image tags to the job summary.
  # `always()` makes it run even when some builds failed; the cells of failed
  # builds are then simply empty.
  build-summary:
    needs: [build-base, build-jax, build-te, build-t5x, build-pax, build-rosetta-t5x, build-rosetta-pax]
    # needs: [build-base, build-jax, build-te, build-t5x, build-pax, build-pax-aarch64, build-rosetta-t5x, build-rosetta-pax]
    if: always()
    runs-on: ubuntu-22.04
    steps:
      - name: Generate job summary for container build
        shell: bash -x -e {0}
        run: |
          cat > "$GITHUB_STEP_SUMMARY" << EOF
          # Images created
          | Image        | Link                                               |
          | ------------ | -------------------------------------------------- |
          | Base         | ${{ needs.build-base.outputs.DOCKER_TAGS }}        |
          | JAX          | ${{ needs.build-jax.outputs.DOCKER_TAGS }}         |
          | JAX-TE       | ${{ needs.build-te.outputs.DOCKER_TAGS }}          |
          | T5X          | ${{ needs.build-t5x.outputs.DOCKER_TAGS }}         |
          | PAX          | ${{ needs.build-pax.outputs.DOCKER_TAGS }}         |
          | ROSETTA(t5x) | ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAGS }} |
          | ROSETTA(pax) | ${{ needs.build-rosetta-pax.outputs.DOCKER_TAGS }} |
          EOF

  # Re-publishes every built image under DOCKER_REPO in Docker v2s2 manifest
  # format, one tag per platform for multi-arch (OCI index) images.
  retrofit-containers:
    needs: [build-base, build-jax, build-te, build-t5x, build-pax, build-rosetta-t5x, build-rosetta-pax]
    if: always()
    runs-on: ubuntu-22.04
    env:
      DOCKER_REPO: 'ghcr.io/nvidia/jax-toolbox-retrofit'
    steps:
      - name: Login to GitHub Container Registry
        uses: docker/login-action@v2
        with:
          registry: ghcr.io
          username: ${{ github.repository_owner }}
          password: ${{ secrets.GITHUB_TOKEN }}

      ## Requires skopeo >= v1.6.0, but Actions only has v1.4.0
      # - name: Create Docker v2s2 multi-arch manifest list
      #   id: multi-arch
      #   shell: bash -x -e {0}
      #   run: |
      #     for tag in $(echo "${{ steps.meta.outputs.tags }}"); do
      #       skopeo copy --multi-arch all --format v2s2 docker://${{ inputs.SOURCE_IMAGE }} docker://$tag
      #     done

      - name: Create Docker v2s2 single-arch manifests
        id: single-arch
        shell: bash -x -e {0}
        run: |
          # Outputs of builds that did not run expand to nothing and are
          # therefore skipped by the loop.
          for source in \
            ${{ needs.build-base.outputs.DOCKER_TAGS }} \
            ${{ needs.build-jax.outputs.DOCKER_TAGS }} \
            ${{ needs.build-te.outputs.DOCKER_TAGS }} \
            ${{ needs.build-t5x.outputs.DOCKER_TAGS }} \
            ${{ needs.build-pax.outputs.DOCKER_TAGS }} \
            ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAGS }} \
            ${{ needs.build-rosetta-pax.outputs.DOCKER_TAGS }} \
          ; do
            source_repo=$(echo "${source}" | cut -d: -f1)
            media_type=$(docker manifest inspect "${source}" | jq -r '.mediaType')
            if [[ ${media_type} != "application/vnd.oci.image.index.v1+json" ]]; then
              # Not an OCI image index: copy the image as-is under its own tag.
              echo "Image ${source} is already in Docker format v2s2"
              dest=${DOCKER_REPO}:$(echo "${source}" | cut -d: -f2)
              skopeo copy --format v2s2 "docker://${source}" "docker://${dest}"
              echo "${dest}" >> "$GITHUB_STEP_SUMMARY"
            else
              # OCI image index: collect repo@digest for every entry whose
              # platform OS is not "unknown".
              manifests=$(
                docker manifest inspect "${source}" |\
                jq -r '.manifests[] | select(.platform.os != "unknown") | .digest' |\
                xargs -I{} echo "${source_repo}@{}" |\
                tr '\n' ' '
              )
              ## registry/org/repo:tag -> repo-tag
              # dest_tag=$(echo ${source} | cut -d: -f1 | cut -d/ -f3)-$(echo ${source} | cut -d: -f2)
              ## registry/org/repo:tag -> tag
              dest_tag=$(echo "${source}" | cut -d: -f2)
              # Intentionally unquoted: ${manifests} is a space-separated list.
              for manifest in ${manifests}; do
                # Query the registry once and reuse the answer for both fields.
                descriptor=$(docker manifest inspect -v "${manifest}")
                os=$(jq -r '.Descriptor.platform.os' <<< "${descriptor}")
                arch=$(jq -r '.Descriptor.platform.architecture' <<< "${descriptor}")
                # single_arch_tag="ghcr.io/nvidia/jax-toolbox-retrofit:${{ github.run_id }}-${dest_tag}-${os}-${arch}"
                single_arch_tag="${DOCKER_REPO}:${dest_tag}-${os}-${arch}"
                skopeo copy --format v2s2 "docker://${manifest}" "docker://${single_arch_tag}"
                echo "${single_arch_tag}" >> "$GITHUB_STEP_SUMMARY"
              done
            fi
          done

  # ---------------------------------------------------------------------------
  # Tests. Each job calls a reusable test workflow with the image built by the
  # corresponding build job.
  # ---------------------------------------------------------------------------

  test-distribution:
    needs: metadata
    uses: ./.github/workflows/_test_distribution.yaml
    secrets: inherit

  test-jax:
    needs: build-jax
    uses: ./.github/workflows/_test_jax.yaml
    with:
      JAX_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
    secrets: inherit

  test-te:
    needs: build-te
    uses: ./.github/workflows/_test_te.yaml
    with:
      JAX_TE_IMAGE: ${{ needs.build-te.outputs.DOCKER_TAGS }}
    secrets: inherit

  test-t5x:
    needs: build-t5x
    uses: ./.github/workflows/_test_t5x.yaml
    with:
      T5X_IMAGE: ${{ needs.build-t5x.outputs.DOCKER_TAGS }}
    secrets: inherit

  test-pax:
    needs: build-pax
    uses: ./.github/workflows/_test_pax.yaml
    with:
      PAX_IMAGE: ${{ needs.build-pax.outputs.DOCKER_TAGS }}
    secrets: inherit

  test-vit:
    needs: build-rosetta-t5x
    uses: ./.github/workflows/_test_vit.yaml
    with:
      ROSETTA_T5X_IMAGE: ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAGS }}
    secrets: inherit

  # Runs after the summary and test jobs regardless of their outcome.
  finalize:
    if: always()
    # TODO: use dynamic matrix to make dependencies self-updating
    # NOTE(review): test-vit is not listed here, so finalize does not wait for
    # it - confirm whether that is intended.
    needs: [build-summary, test-distribution, test-jax, test-te, test-t5x, test-pax]
    uses: ./.github/workflows/_finalize.yaml
    with:
      PUBLISH_BADGE: false
    secrets: inherit