Publishes rosetta CUDA 12.1 and containers for pax and t5x AND containers for nightlies based on jax-pinned images for CUDA 12.1 and 12.2 #286

Merged 24 commits on Oct 14, 2023
Commits
4a5d83a
Enables multiarch builds of rosetta images
terrykong Oct 5, 2023
877e755
update buildkit to maybe address .manifests[] key being missing
terrykong Oct 5, 2023
62df68d
Update so t5x only build amd64
terrykong Oct 6, 2023
de739a7
Remove step to validate library's in base image in rosetta build
terrykong Oct 9, 2023
7f8cd4b
Merge branch 'main' into multiarch-rosetta
yhtang Oct 9, 2023
a514468
No need for rmdir /opt/*-mirrors and exit early in rosetta dockerfile…
terrykong Oct 9, 2023
cf07de5
Disabling extra rosetta installs in rosetta-pax image since they are
terrykong Oct 10, 2023
0e76cf7
Removing extra installs instead of commenting out
terrykong Oct 10, 2023
eb14bed
Publishes rosetta 12.1 containers for pax and t5x
terrykong Oct 6, 2023
6eaf361
Add PLATFORMS to 12.1 workflow
terrykong Oct 6, 2023
1b0f92b
Merge branch 'main' into rosetta-cuda-121
yhtang Oct 10, 2023
ade02fb
Add a second workflow that will build from a pinned jax
terrykong Oct 10, 2023
5822e48
Rename for clarity
terrykong Oct 10, 2023
18f7aa4
experiment: disable bazel cache
yhtang Oct 10, 2023
1ab545b
add cuda12.2 jax pin workflow and remove tags for the jax pinned
terrykong Oct 10, 2023
e7e2f63
Change rosetta-t5x to only build for x86 since arm isn't ready yet
terrykong Oct 11, 2023
1ac8541
Merge branch 'main' into rosetta-cuda-121
terrykong Oct 12, 2023
d6ec8ae
Add custom t5x ref/repo due to build failure
terrykong Oct 12, 2023
292dd1a
Switch to mirror repo for t5x patches
terrykong Oct 12, 2023
a37c1ae
Add mechanism to pin TE in t5x
terrykong Oct 12, 2023
b8841d1
fix
terrykong Oct 12, 2023
0680206
use base image digest as Bazel cache key
yhtang Oct 13, 2023
05078d9
Merge branch 'main' into rosetta-cuda-121
yhtang Oct 13, 2023
5596eb0
bazel automatically includes platform and toolchain as cache key
yhtang Oct 14, 2023
4 changes: 3 additions & 1 deletion .github/container/Dockerfile.t5x
@@ -13,9 +13,11 @@ ADD install-te.sh /usr/local/bin
 ENV NVTE_FRAMEWORK=jax
 ARG REPO_T5X=https://github.com/google-research/t5x.git
 ARG REF_T5X=main
+ARG REPO_TE=https://github.com/NVIDIA/TransformerEngine.git
+ARG REF_TE=main
 RUN <<"EOF" bash -ex
 install-t5x.sh --defer --from ${REPO_T5X} --ref ${REF_T5X}
-install-te.sh --defer
+install-te.sh --defer --from ${REPO_TE} --ref ${REF_TE}

 if [[ -f /opt/requirements-defer.txt ]]; then
   pip install -r /opt/requirements-defer.txt
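
For local experimentation, the two new build arguments can be overridden on the `docker build` command line. The sketch below only assembles the `--build-arg` flags. The helper name `te_build_args` is hypothetical and not part of this PR, and its fallback values simply mirror the `ARG` defaults shown in the diff.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of this PR): prints the --build-arg flags
# that select a TransformerEngine repo and ref, falling back to the
# Dockerfile defaults when an argument is empty or missing.
te_build_args() {
  local repo="${1:-https://github.com/NVIDIA/TransformerEngine.git}"
  local ref="${2:-main}"
  printf -- '--build-arg REPO_TE=%s --build-arg REF_TE=%s\n' "$repo" "$ref"
}

# Example: keep the default repo but pin TE to the v0.13 ref that the
# JAX-pinned nightly workflow uses as its default.
te_build_args "" v0.13
```

The printed flags could then be passed to a `docker build -f .github/container/Dockerfile.t5x ...` invocation, together with whatever build context and remaining arguments the Dockerfile expects.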
2 changes: 1 addition & 1 deletion .github/workflows/_build_jax.yaml
@@ -129,7 +129,7 @@ jobs:
             "SSH_KNOWN_HOSTS=${{ steps.ssh-known-hosts.outputs.FILE }}"
           build-args: |
             BASE_IMAGE=${{ inputs.BASE_IMAGE }}
-            BAZEL_CACHE=${{ vars.BAZEL_REMOTE_CACHE_URL }}/${{ matrix.PLATFORM }}
+            BAZEL_CACHE=${{ vars.BAZEL_REMOTE_CACHE_URL }}
             BUILD_DATE=${{ inputs.BUILD_DATE }}
             REPO_JAX=${{ inputs.REPO_JAX }}
             REPO_XLA=${{ inputs.REPO_XLA }}
12 changes: 12 additions & 0 deletions .github/workflows/_build_t5x.yaml
@@ -23,6 +23,16 @@ on:
         description: Git commit, tag, or branch for T5X
         required: false
         default: main
+      REPO_TE:
+        type: string
+        description: URL of TE repository to check out
+        required: false
+        default: "https://github.com/NVIDIA/TransformerEngine.git"
+      REF_TE:
+        type: string
+        description: Git commit, tag, or branch for TE
+        required: false
+        default: main
     outputs:
       DOCKER_TAGS:
         description: "Tags of the image built"
@@ -88,3 +98,5 @@ jobs:
             BUILD_DATE=${{ inputs.BUILD_DATE }}
             REPO_T5X=${{ inputs.REPO_T5X }}
             REF_T5X=${{ inputs.REF_T5X }}
+            REPO_TE=${{ inputs.REPO_TE }}
+            REF_TE=${{ inputs.REF_TE }}
190 changes: 190 additions & 0 deletions .github/workflows/cuda-121-jax-pin.yaml
@@ -0,0 +1,190 @@
name: Nightly Containers on CUDA 12.1 (JAX pinned)

on:
  schedule:
    - cron: '30 9 * * *'  # 09:30 UTC = 01:30 AM Pacific Standard Time (02:30 AM during daylight saving time)
  workflow_dispatch:
    inputs:
      JAX_BASE_IMAGE:
        type: string
        description: 'Base Multiarch JAX Image'
        default: 'ghcr.io/nvidia/jax-toolbox-internal:6473019396-jax-multiarch'
        required: true
      REPO_T5X:
        type: string
        description: URL of T5X repository to check out
        required: false
        default: "https://github.com/nvjax-svc-0/t5x.git"
      REF_T5X:
        type: string
        description: Git commit, tag, or branch for T5X
        required: false
        default: unpin-tfds-gpu-extra
      REPO_TE:
        type: string
        description: URL of TE repository to check out
        required: false
        default: "https://github.com/NVIDIA/TransformerEngine.git"
      REF_TE:
        type: string
        description: Git commit, tag, or branch for TE
        required: false
        default: v0.13
      PUBLISH:
        type: boolean
        description: Publish dated images and update the 'latest' tag?
        default: false
        required: false

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

permissions:
  contents: read  # to fetch code
  actions:  write # to cancel previous workflows
  packages: write # to upload container

env:
  DEFAULT_JAX_BASE_IMAGE: ghcr.io/nvidia/jax-toolbox-internal:6473019396-jax-multiarch
  DEFAULT_REPO_T5X: https://github.com/nvjax-svc-0/t5x.git
  DEFAULT_REF_T5X: unpin-tfds-gpu-extra
  DEFAULT_REPO_TE: https://github.com/NVIDIA/TransformerEngine.git
  DEFAULT_REF_TE: v0.13

jobs:

  metadata:
    runs-on: ubuntu-22.04
    outputs:
      BUILD_DATE: ${{ steps.meta.outputs.BUILD_DATE }}
      JAX_BASE_IMAGE: ${{ steps.meta.outputs.JAX_BASE_IMAGE }}
      REPO_T5X: ${{ steps.meta.outputs.REPO_T5X }}
      REF_T5X: ${{ steps.meta.outputs.REF_T5X }}
      REPO_TE: ${{ steps.meta.outputs.REPO_TE }}
      REF_TE: ${{ steps.meta.outputs.REF_TE }}
    steps:
      - name: Set build date and base image
        id: meta
        shell: bash -x -e {0}
        run: |
          BUILD_DATE=$(TZ='US/Los_Angeles' date '+%Y-%m-%d')
          echo "BUILD_DATE=${BUILD_DATE}" >> $GITHUB_OUTPUT
          if [[ -z "${{ inputs.JAX_BASE_IMAGE }}" ]]; then
            echo "JAX_BASE_IMAGE=${{ env.DEFAULT_JAX_BASE_IMAGE }}" >> $GITHUB_OUTPUT
          else
            echo "JAX_BASE_IMAGE=${{ inputs.JAX_BASE_IMAGE }}" >> $GITHUB_OUTPUT
          fi
          if [[ -z "${{ inputs.REPO_T5X }}" ]]; then
            echo "REPO_T5X=${{ env.DEFAULT_REPO_T5X }}" >> $GITHUB_OUTPUT
          else
            echo "REPO_T5X=${{ inputs.REPO_T5X }}" >> $GITHUB_OUTPUT
          fi
          if [[ -z "${{ inputs.REF_T5X }}" ]]; then
            echo "REF_T5X=${{ env.DEFAULT_REF_T5X }}" >> $GITHUB_OUTPUT
          else
            echo "REF_T5X=${{ inputs.REF_T5X }}" >> $GITHUB_OUTPUT
          fi
          if [[ -z "${{ inputs.REPO_TE }}" ]]; then
            echo "REPO_TE=${{ env.DEFAULT_REPO_TE }}" >> $GITHUB_OUTPUT
          else
            echo "REPO_TE=${{ inputs.REPO_TE }}" >> $GITHUB_OUTPUT
          fi
          if [[ -z "${{ inputs.REF_TE }}" ]]; then
            echo "REF_TE=${{ env.DEFAULT_REF_TE }}" >> $GITHUB_OUTPUT
          else
            echo "REF_TE=${{ inputs.REF_TE }}" >> $GITHUB_OUTPUT
          fi

  build-pax:
    needs: [metadata]
    uses: ./.github/workflows/_build_pax.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.metadata.outputs.JAX_BASE_IMAGE }}
    secrets: inherit

  build-rosetta-pax:
    uses: ./.github/workflows/_build_rosetta.yaml
    needs: [metadata, build-pax]
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-pax.outputs.DOCKER_TAGS }}
      BASE_LIBRARY: pax
      PLATFORMS: '["amd64"]'
    secrets: inherit

  build-t5x:
    needs: [metadata]
    uses: ./.github/workflows/_build_t5x.yaml
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.metadata.outputs.JAX_BASE_IMAGE }}
      REPO_T5X: ${{ needs.metadata.outputs.REPO_T5X }}
      REF_T5X: ${{ needs.metadata.outputs.REF_T5X }}
      REPO_TE: ${{ needs.metadata.outputs.REPO_TE }}
      REF_TE: ${{ needs.metadata.outputs.REF_TE }}
    secrets: inherit

  build-rosetta-t5x:
    uses: ./.github/workflows/_build_rosetta.yaml
    needs: [metadata, build-t5x]
    with:
      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-t5x.outputs.DOCKER_TAGS }}
      BASE_LIBRARY: t5x
      PLATFORMS: '["amd64"]'
    secrets: inherit

  build-summary:
    needs: [metadata, build-t5x, build-rosetta-t5x, build-pax, build-rosetta-pax]
    if: always()
    runs-on: ubuntu-22.04
    steps:
      - name: Generate job summary for container build
        shell: bash -x -e {0}
        run: |
          cat > $GITHUB_STEP_SUMMARY << EOF
          # Images created

          | Image        | Link                                               |
          | ------------ | -------------------------------------------------- |
          | JAX (input)  | ${{ needs.metadata.outputs.JAX_BASE_IMAGE }}       |
          | T5X          | ${{ needs.build-t5x.outputs.DOCKER_TAGS }}         |
          | ROSETTA(T5X) | ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAGS }} |
          | PAX          | ${{ needs.build-pax.outputs.DOCKER_TAGS }}         |
          | ROSETTA(pax) | ${{ needs.build-rosetta-pax.outputs.DOCKER_TAGS }} |
          EOF

  test-jax:
    needs: metadata
    uses: ./.github/workflows/_test_jax.yaml
    with:
      JAX_IMAGE: ${{ needs.metadata.outputs.JAX_BASE_IMAGE }}
    secrets: inherit

  test-pax:
    needs: build-pax
    uses: ./.github/workflows/_test_pax.yaml
    with:
      PAX_IMAGE: ${{ needs.build-pax.outputs.DOCKER_TAGS }}
    secrets: inherit

  test-t5x:
    needs: build-t5x
    uses: ./.github/workflows/_test_t5x.yaml
    with:
      T5X_IMAGE: ${{ needs.build-t5x.outputs.DOCKER_TAGS }}
    secrets: inherit

  # TODO(terry): The rosetta tests are missing here; they can only be added
  #   after a fix for the TB log collision is pushed.

  finalize:
    if: always()
    # TODO: use dynamic matrix to make dependencies self-updating
    needs: [build-summary, test-jax, test-pax]
    uses: ./.github/workflows/_finalize.yaml
    with:
      PUBLISH_BADGE: false
    secrets: inherit
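
The five near-identical `if [[ -z ... ]]` blocks in the `metadata` job all follow one pattern: use the `workflow_dispatch` input when it is non-empty, otherwise fall back to the matching `DEFAULT_*` value from `env`. On scheduled runs the inputs are empty, so the defaults apply. The pattern can be sketched as a small shell function. The helper name `emit_output` is hypothetical and does not appear in the PR, and a temporary file stands in for the `$GITHUB_OUTPUT` file that GitHub Actions provides.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of this PR): append NAME=value to the
# output file, preferring the supplied value and falling back to the
# default when the supplied value is empty.
emit_output() {
  local name="$1" value="$2" default="$3"
  echo "${name}=${value:-$default}" >> "$GITHUB_OUTPUT"
}

# Stand-in for the per-step output file provided by GitHub Actions.
GITHUB_OUTPUT="$(mktemp)"

# Empty input, as on a scheduled run: the default is written.
emit_output REF_TE "" "v0.13"
# Explicit input, as on a manual dispatch: the input is written.
emit_output REF_T5X "main" "unpin-tfds-gpu-extra"

cat "$GITHUB_OUTPUT"
```

This relies on the `${value:-default}` parameter expansion, which treats an empty string the same as an unset variable and so matches the `-z` checks in the workflow.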