Skip to content

Updates rosetta CI to build multiarch (Arm64/amd64) for pax #759

Updates rosetta CI to build multiarch (Arm64/amd64) for pax

Updates rosetta CI to build multiarch (Arm64/amd64) for pax #759

Workflow file for this run

name: CI

# Triggers:
#   * pull_request      - skipped when every changed file matches '**.md'.
#   * workflow_dispatch - manual run; each SRC_* input selects one component's
#                         source in the form '<repo>#<branch|tag|commit>'.
on:
  pull_request:
    paths-ignore:
      - '**.md'
  workflow_dispatch:
    inputs:
      CUDA_IMAGE:
        description: 'Base CUDA image, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04'
        type: string
        required: false
        default: 'latest'
      SRC_JAX:
        description: 'JAX source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google/jax.git#main'
      SRC_XLA:
        description: 'XLA source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/openxla/xla.git#main'
      SRC_TE:
        description: 'TE source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/NVIDIA/TransformerEngine.git#main'
      SRC_T5X:
        description: 'T5X source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google-research/t5x.git#main'
      SRC_PAXML:
        description: 'Paxml source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google/paxml.git#main'
      SRC_PRAXIS:
        description: 'Praxis source: <repo>#<branch|tag|commit>'
        type: string
        required: true
        default: 'https://github.com/google/praxis.git#main'
# Runs are grouped per workflow and PR head branch; when head_ref is empty the
# run_id is used instead, which gives that run a group of its own.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  # An in-progress run of the same group is cancelled unless the ref is main.
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
# Token scopes requested by this workflow.
permissions:
  # to fetch code
  contents: read
  # to cancel previous workflows
  actions: write
  # to upload container
  packages: write
jobs:
# Produces the values shared by the downstream jobs: the build date and, for
# every component, the repository URL and git ref parsed from its SRC_* input.
metadata:
  runs-on: ubuntu-22.04
  outputs:
    BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }}
    REPO_JAX: ${{ steps.parse-inputs.outputs.REPO_JAX }}
    REF_JAX: ${{ steps.parse-inputs.outputs.REF_JAX }}
    REPO_XLA: ${{ steps.parse-inputs.outputs.REPO_XLA }}
    REF_XLA: ${{ steps.parse-inputs.outputs.REF_XLA }}
    REPO_TE: ${{ steps.parse-inputs.outputs.REPO_TE }}
    REF_TE: ${{ steps.parse-inputs.outputs.REF_TE }}
    REPO_T5X: ${{ steps.parse-inputs.outputs.REPO_T5X }}
    REF_T5X: ${{ steps.parse-inputs.outputs.REF_T5X }}
    REPO_PAXML: ${{ steps.parse-inputs.outputs.REPO_PAXML }}
    REF_PAXML: ${{ steps.parse-inputs.outputs.REF_PAXML }}
    REPO_PRAXIS: ${{ steps.parse-inputs.outputs.REPO_PRAXIS }}
    REF_PRAXIS: ${{ steps.parse-inputs.outputs.REF_PRAXIS }}
  steps:
    - name: Set build date
      id: date
      shell: bash -x -e {0}
      run: |
        # The date is evaluated in the US/Los_Angeles time zone.
        BUILD_DATE=$(TZ='US/Los_Angeles' date '+%Y-%m-%d')
        echo "BUILD_DATE=${BUILD_DATE}" >> "$GITHUB_OUTPUT"
    - name: Parse inputs
      id: parse-inputs
      shell: bash -x -e {0}
      # The inputs reach the script as environment variables instead of being
      # expanded into the script text, so their content is never parsed as
      # shell code.
      env:
        SRC_JAX: ${{ inputs.SRC_JAX }}
        SRC_XLA: ${{ inputs.SRC_XLA }}
        SRC_TE: ${{ inputs.SRC_TE }}
        SRC_T5X: ${{ inputs.SRC_T5X }}
        SRC_PAXML: ${{ inputs.SRC_PAXML }}
        SRC_PRAXIS: ${{ inputs.SRC_PRAXIS }}
      run: |
        # Split a source given as repo#ref into its repo and ref parts and
        # publish them as the step outputs REPO_<PACKAGE> and REF_<PACKAGE>.
        #   $1: package name, used as the output suffix
        #   $2: source supplied by the caller (may be empty)
        #   $3: source used when $2 is empty
        parse_git_src() {
          local package="$1"
          local src="${2:-$3}"
          # Without a '#', `cut -f2` would return the whole string and the
          # ref would silently become the repo URL; fail loudly instead.
          if [[ "${src}" != *'#'* ]]; then
            echo "::error::${package} source '${src}' is not in the form <repo>#<ref>"
            return 1
          fi
          {
            echo "REPO_${package}=$(echo "${src}" | cut -f1 -d#)"
            echo "REF_${package}=$(echo "${src}" | cut -f2 -d#)"
          } >> "$GITHUB_OUTPUT"
        }
        # default values are for `pull_request` event types
        parse_git_src JAX "${SRC_JAX}" "https://github.com/google/jax.git#main"
        parse_git_src XLA "${SRC_XLA}" "https://github.com/openxla/xla.git#main"
        parse_git_src TE "${SRC_TE}" "https://github.com/NVIDIA/TransformerEngine.git#main"
        parse_git_src T5X "${SRC_T5X}" "https://github.com/google-research/t5x.git#main"
        parse_git_src PAXML "${SRC_PAXML}" "https://github.com/google/paxml.git#main"
        parse_git_src PRAXIS "${SRC_PRAXIS}" "https://github.com/google/praxis.git#main"
# Calls the reusable _build_base workflow with the CUDA image and build date.
build-base:
  needs: metadata
  uses: ./.github/workflows/_build_base.yaml
  secrets: inherit
  with:
    # Falls back to 'latest' when the CUDA_IMAGE input is empty.
    BASE_IMAGE: ${{ inputs.CUDA_IMAGE || 'latest' }}
    BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
# Calls the reusable _build_jax workflow, using the tags reported by build-base
# as the base image and the JAX/XLA repo and ref resolved by metadata.
build-jax:
  needs:
    - metadata
    - build-base
  uses: ./.github/workflows/_build_jax.yaml
  secrets: inherit
  with:
    BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
    BASE_IMAGE: ${{ needs.build-base.outputs.DOCKER_TAGS }}
    REPO_JAX: ${{ needs.metadata.outputs.REPO_JAX }}
    REF_JAX: ${{ needs.metadata.outputs.REF_JAX }}
    REPO_XLA: ${{ needs.metadata.outputs.REPO_XLA }}
    REF_XLA: ${{ needs.metadata.outputs.REF_XLA }}
# Calls the reusable _build_te workflow, using the tags reported by build-jax
# as the base image and the TE repo and ref resolved by metadata.
build-te:
  needs:
    - metadata
    - build-jax
  uses: ./.github/workflows/_build_te.yaml
  secrets: inherit
  with:
    BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
    BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
    REPO_TE: ${{ needs.metadata.outputs.REPO_TE }}
    REF_TE: ${{ needs.metadata.outputs.REF_TE }}
# Calls the reusable _build_t5x workflow, using the tags reported by build-jax
# as the base image and the T5X repo and ref resolved by metadata.
build-t5x:
  needs:
    - metadata
    - build-jax
  uses: ./.github/workflows/_build_t5x.yaml
  secrets: inherit
  with:
    BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
    BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
    REPO_T5X: ${{ needs.metadata.outputs.REPO_T5X }}
    REF_T5X: ${{ needs.metadata.outputs.REF_T5X }}
# Calls the reusable _build_pax workflow, using the tags reported by build-jax
# as the base image and the Paxml/Praxis repos and refs resolved by metadata.
build-pax:
  needs:
    - metadata
    - build-jax
  uses: ./.github/workflows/_build_pax.yaml
  secrets: inherit
  with:
    BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
    BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
    REPO_PAXML: ${{ needs.metadata.outputs.REPO_PAXML }}
    REF_PAXML: ${{ needs.metadata.outputs.REF_PAXML }}
    REPO_PRAXIS: ${{ needs.metadata.outputs.REPO_PRAXIS }}
    REF_PRAXIS: ${{ needs.metadata.outputs.REF_PRAXIS }}
# build-pax-aarch64:
# needs: [metadata, build-jax]
# uses: ./.github/workflows/_build_pax_aarch64.yaml
# with:
# # BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
# BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
# # REPO_PAXML: ${{ needs.metadata.outputs.REPO_PAXML }}
# # REF_PAXML: ${{ needs.metadata.outputs.REF_PAXML }}
# # REPO_PRAXIS: ${{ needs.metadata.outputs.REPO_PRAXIS }}
# # REF_PRAXIS: ${{ needs.metadata.outputs.REF_PRAXIS }}
# secrets: inherit
# Calls the reusable _build_rosetta workflow for the t5x library, using the
# tags reported by build-t5x as the base image.
build-rosetta-t5x:
  needs:
    - metadata
    - build-t5x
  uses: ./.github/workflows/_build_rosetta.yaml
  secrets: inherit
  with:
    BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
    BASE_IMAGE: ${{ needs.build-t5x.outputs.DOCKER_TAGS }}
    BASE_LIBRARY: t5x
# Calls the reusable _build_rosetta workflow for the pax library, using the
# tags reported by build-pax as the base image.
build-rosetta-pax:
  needs:
    - metadata
    - build-pax
  uses: ./.github/workflows/_build_rosetta.yaml
  secrets: inherit
  with:
    BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
    BASE_IMAGE: ${{ needs.build-pax.outputs.DOCKER_TAGS }}
    BASE_LIBRARY: pax
# Writes a Markdown table with the DOCKER_TAGS output of every build job to the
# step summary. `if: always()` keeps it running when a build job did not
# succeed.
build-summary:
  if: always()
  needs:
    - build-base
    - build-jax
    - build-te
    - build-t5x
    - build-pax
    - build-rosetta-t5x
    - build-rosetta-pax
  # needs: [build-base, build-jax, build-te, build-t5x, build-pax, build-pax-aarch64, build-rosetta-t5x, build-rosetta-pax]
  runs-on: ubuntu-22.04
  steps:
    - name: Generate job summary for container build
      shell: bash -x -e {0}
      run: |
        cat > $GITHUB_STEP_SUMMARY << EOF
        # Images created
        | Image | Link |
        | ------------ | -------------------------------------------------- |
        | Base | ${{ needs.build-base.outputs.DOCKER_TAGS }} |
        | JAX | ${{ needs.build-jax.outputs.DOCKER_TAGS }} |
        | JAX-TE | ${{ needs.build-te.outputs.DOCKER_TAGS }} |
        | T5X | ${{ needs.build-t5x.outputs.DOCKER_TAGS }} |
        | PAX | ${{ needs.build-pax.outputs.DOCKER_TAGS }} |
        | ROSETTA(t5x) | ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAGS }} |
        | ROSETTA(pax) | ${{ needs.build-rosetta-pax.outputs.DOCKER_TAGS }} |
        EOF
# Calls the reusable _test_distribution workflow once metadata has finished;
# no inputs are passed to it.
test-distribution:
  needs: metadata
  uses: ./.github/workflows/_test_distribution.yaml
  secrets: inherit
# Calls the reusable _test_jax workflow with the tags reported by build-jax.
test-jax:
  needs: build-jax
  uses: ./.github/workflows/_test_jax.yaml
  secrets: inherit
  with:
    JAX_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAGS }}
# Calls the reusable _test_te workflow with the tags reported by build-te.
test-te:
  needs: build-te
  uses: ./.github/workflows/_test_te.yaml
  secrets: inherit
  with:
    JAX_TE_IMAGE: ${{ needs.build-te.outputs.DOCKER_TAGS }}
# Calls the reusable _test_t5x workflow with the tags reported by build-t5x.
test-t5x:
  needs: build-t5x
  uses: ./.github/workflows/_test_t5x.yaml
  secrets: inherit
  with:
    T5X_IMAGE: ${{ needs.build-t5x.outputs.DOCKER_TAGS }}
# Calls the reusable _test_pax workflow with the tags reported by build-pax.
test-pax:
  needs: build-pax
  uses: ./.github/workflows/_test_pax.yaml
  secrets: inherit
  with:
    PAX_IMAGE: ${{ needs.build-pax.outputs.DOCKER_TAGS }}
# Calls the reusable _finalize workflow after the summary and all test jobs,
# with badge publishing switched off. `if: always()` keeps it running when an
# upstream job did not succeed.
finalize:
  if: always()
  # TODO: use dynamic matrix to make dependencies self-updating
  needs:
    - build-summary
    - test-distribution
    - test-jax
    - test-te
    - test-t5x
    - test-pax
  uses: ./.github/workflows/_finalize.yaml
  secrets: inherit
  with:
    PUBLISH_BADGE: false