Merge branch 'main' into cancel-ondemand-slurm-on-error
yhtang authored Feb 22, 2024
2 parents 4f5f152 + 8fdff5c commit 863799a
Showing 10 changed files with 119 additions and 31 deletions.
4 changes: 3 additions & 1 deletion .github/container/Dockerfile.pallas
@@ -10,8 +10,10 @@ FROM ${BASE_IMAGE} as builder
ARG SRC_PATH_TRITON

# bump-openxla-triton.sh ensures that the commit of openxla-triton referenced
# in the manifest file is consistent with the commit of xla
# in the manifest file is consistent with the commit of xla, and that any extra
# patches are available under patches/openxla-triton
RUN get-source.sh -l openxla-triton -m ${MANIFEST_FILE}
RUN cd /opt/openxla-triton && for patch in /opt/manifest.d/patches/openxla-triton/*.patch; do patch -p1 < "${patch}"; done && git diff

RUN <<"EOF" bash -ex
mkdir -p "${SRC_PATH_TRITON}/dist"
1 change: 0 additions & 1 deletion .github/container/Dockerfile.pax.arm64
@@ -97,7 +97,6 @@ RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-
RUN <<"EOF" bash -ex
echo "tensorflow==2.13.0" >> /opt/pip-tools.d/requirements-paxml.in
echo "tensorflow_datasets==4.9.2" >> /opt/pip-tools.d/requirements-paxml.in
echo "chex==0.1.7" >> /opt/pip-tools.d/requirements-paxml.in
echo "auditwheel" >> /opt/pip-tools.d/requirements-paxml.in

get-source.sh -l paxml -m ${MANIFEST_FILE}
18 changes: 14 additions & 4 deletions .github/container/bump-openxla-triton.sh
@@ -6,7 +6,7 @@ usage() {
cat <<EOF
This script is a utility for updating the commit reference for openxla-triton
in a manifest YAML file used to build JAX-Toolbox images. The commit is derived
from the commit for xla contained in the manifest.
from the commit for xla contained in the manifest, along with the patches.
Usage: $0 [OPTION]...
-h, --help Print usage.
@@ -50,6 +50,7 @@ if [[ -z "${MANIFEST:-}" ]]; then
fi

set -eou pipefail
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

xla_url=$(yq e ".xla.url" $MANIFEST)
xla_tracking_ref=$(yq e ".xla.tracking_ref" $MANIFEST)
@@ -58,8 +59,17 @@ xla_commit=$(yq e ".xla.latest_verified_commit" $MANIFEST)
git clone --branch "${xla_tracking_ref}" --single-branch "${xla_url}" "${xla_repo}"
(cd "${xla_repo}" && git checkout "${xla_commit}")
# Extract the openxla/triton tag used by XLA. Even though it is called
# TRITON_COMMIT it is a tag. In principle we should also account for the
# patches in this .bzl file, but skip that for now.
openxla_triton_tag=$(sed -n -e 's#\s\+TRITON_COMMIT = "\(cl[0-9]\+\)"#\1#p' "${xla_repo}/third_party/triton/workspace.bzl")
# TRITON_COMMIT it is a tag.
workspace_file="${xla_repo}/third_party/triton/workspace.bzl"
openxla_triton_tag=$(sed -n -e 's#\s\+TRITON_COMMIT = "\(cl[0-9]\+\)"#\1#p' "${workspace_file}")
# Extract Triton patch files applied by XLA
patch_files=$(python3 -c 'import ast, sys; tree = ast.parse(sys.stdin.read()); print(" ".join(elem.value.removeprefix("//third_party/triton:").removesuffix(".patch") for node in ast.walk(tree) if isinstance(node, ast.keyword) and node.arg == "patch_file" for elem in node.value.elts))' < "${workspace_file}")
i=0
# Remove old patch files
rm -vf ${SCRIPT_DIR}/patches/openxla-triton/*.patch
for patch_file in ${patch_files}; do
cp -v "${xla_repo}/third_party/triton/${patch_file}.patch" "${SCRIPT_DIR}/patches/openxla-triton/${i}_${patch_file}.patch"
i=$((i+1))
done
rm -rf "${xla_repo}"
yq e ".openxla-triton.latest_verified_commit = \"${openxla_triton_tag}\"" -i $MANIFEST
15 changes: 15 additions & 0 deletions .github/container/patches/openxla-triton/0_cl602723852.patch
@@ -0,0 +1,15 @@
Remove once b/322980485 is fixed.
diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py
--- a/python/triton/backends/__init__.py
+++ b/python/triton/backends/__init__.py
@@ -46,5 +46,8 @@ def _discover_backends():
_find_concrete_subclasses(driver, DriverBase))
return backends

-
-backends = _discover_backends()
+from triton.backends.nvidia.driver import CudaDriver
+from triton.backends.nvidia.compiler import CUDABackend
+backends = {
+ "nvidia": Backend(CUDABackend, CudaDriver)
+}
14 changes: 14 additions & 0 deletions .github/container/patches/openxla-triton/1_cl602997103.patch
@@ -0,0 +1,14 @@
==== triton/python/src/ir.cc#3 - /google/src/cloud/joelwee/mlir_da784a25557e29996bd33638d51d569ddf989faf_1706700588/triton/python/src/ir.cc ====
# action=edit type=text
--- triton/python/src/ir.cc 2024-01-30 04:43:56.000000000 -0800
+++ triton/python/src/ir.cc 2024-01-31 04:40:21.000000000 -0800
@@ -271,7 +271,8 @@
})
.def("get_num_arguments", &mlir::Block::getNumArguments)
.def("dump", &mlir::Block::dump)
- .def("move_before", &mlir::Block::moveBefore)
+ .def("move_before",
+ [](mlir::Block &self, mlir::Block *tgt) { self.moveBefore(tgt); })
.def("insert_before", &mlir::Block::insertBefore)
.def("get_parent", &mlir::Block::getParent, ret::reference)
.def("merge_block_before",
33 changes: 26 additions & 7 deletions .github/container/test-pax.sh
@@ -17,6 +17,8 @@ usage() {
echo " --dtype Batch size, defaults to bfloat16."
echo " --enable-te If set, will run with env var ENABLE_TE=1."
echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1."
echo " --enable-fused-attn Whether to test fused attention through TE."
echo " --run-5b Whether run GPT5B rather than the default 126M."
echo " --evaluate Whether to test evaluation rather than training."
echo " -s, --steps Number of steps to run, defaults to 500."
echo " --multiprocess Enable the multiprocess GPU mode."
@@ -30,7 +32,7 @@ usage() {
exit $1
}

args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,enable-fused-attn,run-5b,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
if [[ $? -ne 0 ]]; then
exit $1
fi
@@ -48,8 +50,10 @@ TP=1
PP=1
NODES=1
ENABLE_TE=0
NVTE_FUSED_ATTN=0
DROPOUT=0
EVALUATE=0
RUN_5B=0
ADDITIONAL_ARGS=""

eval set -- "$args"
@@ -75,6 +79,14 @@ while [ : ]; do
DROPOUT='0.1'
shift 1
;;
--enable-fused-attn)
NVTE_FUSED_ATTN=1
shift 1
;;
--run-5b)
RUN_5B=1
shift 1
;;
--evaluate)
EVALUATE=1
shift 1
@@ -136,6 +148,7 @@ print_var NGPUS
print_var OUTPUT
print_var MULTIPROCESS
print_var ENABLE_TE
print_var NVTE_FUSED_ATTN
print_var EVALUATE
print_var DROPOUT
print_var DP
@@ -162,7 +175,6 @@ tp = ${TP}
pp = ${PP}
num_gpus = ${NGPUS}
percore_batch_size = ${BATCH_PER_GPU}
steps = ${STEPS}
dtype = "${DTYPE}"
dropout = float(${DROPOUT})
@@ -280,7 +292,6 @@ if pp > 1:
NUM_STAGES = pp
PERCORE_BATCH_SIZE = percore_batch_size
FRPOP_DTYPE = dtype
MAX_STEPS = steps
def task(self):
task_p = super().task()
@@ -296,7 +307,6 @@ else:
DCN_MESH_SHAPE = [dcn_dp, dcn_fsdp, 1]
PERCORE_BATCH_SIZE = percore_batch_size
FRPOP_DTYPE = dtype
MAX_STEPS = steps
DROPOUT_PROB = dropout
@@ -327,10 +337,18 @@ set -ex

export XLA_PYTHON_CLIENT_MEM_FRACTION=${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.65}
export ENABLE_TE=$ENABLE_TE
export NVTE_FUSED_ATTN=$NVTE_FUSED_ATTN

CONFIG=ci_configs.Synthetic126MCI
if [[ ${RUN_5B} -ne 0 ]]; then
CONFIG=paxml.contrib.gpu.scripts_gpu.configs.Synthetic5B
ADDITIONAL_ARGS="--fdl.DCN_MESH_SHAPE=[1,${NODES},1] --fdl.ICI_MESH_SHAPE=[${DP},${FSDP},${TP}] ${ADDITIONAL_ARGS} --fdl.PERCORE_BATCH_SIZE=${BATCH_PER_GPU}"
fi

if [[ ${EVALUATE} -ne 0 ]]; then
## train for 0 steps to generate an initial checkpoint
python -m paxml.main \
--fdl_config=ci_configs.Synthetic126MCI \
--fdl_config=${CONFIG} \
--fdl.MAX_STEPS=0 \
--job_log_dir=${OUTPUT} \
--alsologtostderr \
Expand All @@ -339,7 +357,7 @@ if [[ ${EVALUATE} -ne 0 ]]; then

## restore from initial checkpoint for eval
python -m paxml.main \
--fdl_config=ci_configs.Synthetic126MCI \
--fdl_config=${CONFIG} \
--job_log_dir=${OUTPUT} \
--mode='eval' \
--alsologtostderr \
Expand All @@ -350,9 +368,10 @@ if [[ ${EVALUATE} -ne 0 ]]; then
rm -rf ${OUTPUT}/checkpoints
else
python -m paxml.main \
--fdl_config=ci_configs.Synthetic126MCI \
--fdl_config=${CONFIG} \
--job_log_dir=${OUTPUT} \
--alsologtostderr \
--fdl.MAX_STEPS=${STEPS} \
--enable_checkpoint_saving=False \
$ADDITIONAL_ARGS \
$([[ $MULTIPROCESS != 0 ]] && echo --multiprocess_gpu)
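
Taken together, --run-5b switches the config from ci_configs.Synthetic126MCI to paxml.contrib.gpu.scripts_gpu.configs.Synthetic5B and --enable-fused-attn toggles NVTE_FUSED_ATTN, so the same driver now covers the 5B fused-attention cases added to CI below. A hypothetical single-node invocation mirroring the 5B_fused_attn_1 matrix entry (the output path is a placeholder):

test-pax.sh \
  --output /tmp/pax-5B-fused-attn \
  --dtype bfloat16 \
  --batch-per-gpu 2 \
  --steps 300 \
  --fsdp 8 \
  --nodes 1 \
  --enable-te \
  --enable-fused-attn \
  --run-5b \
  --multiprocess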
6 changes: 5 additions & 1 deletion .github/workflows/_finalize.yaml
@@ -36,7 +36,11 @@ jobs:
source .github/workflows/scripts/to_json.sh
badge_label='workflow metadata'
badge_message="Run ${{ github.run_id }}, ${{ inputs.BUILD_DATE || github.event.created_at }}"
if [[ -n "${{ inputs.BUILD_DATE }}" ]]; then
badge_message="${{ inputs.BUILD_DATE }}: run #${{ github.run_id }}"
else
badge_message="run #${{ github.run_id }}"
fi
schemaVersion=1 \
label="${badge_label}" \
50 changes: 37 additions & 13 deletions .github/workflows/_test_pax_rosetta.yaml
@@ -81,7 +81,7 @@ jobs:
run: |
cd $GITHUB_WORKSPACE
alias sshx='ssh -o "ServerAliveInterval 7" ${{ secrets.CLUSTER_LOGIN_USER }}@${{ vars.HOSTNAME_SLURM_LOGIN }}'
sshx "date && hostname && sinfo"
sshx "date && hostname && sinfo"
sshx mkdir -p ${{ steps.meta.outputs.MODEL_PATH }}
JOB=$(sshx sbatch --parsable << EOF
#!/bin/bash
@@ -210,12 +210,35 @@ jobs:
rosetta-pax-multi-node-te:
strategy:
matrix:
PARALLEL_CONFIG:
- [1, 1, 1, 1]
- [1, 8, 1, 1]
- [1, 1, 8, 1]
- [1, 4, 1, 2]
- [1, 16, 1, 1]
include:
- TEST_NAME: 1DP1FSDP1TP1PP_TE
PARALLEL_CONFIG: [1, 1, 1, 1]
BATCH_SIZE: 4
ADDITIONAL_ARGS: ""
- TEST_NAME: 8DP1FSDP1TP1PP_TE
PARALLEL_CONFIG: [1, 8, 1, 1]
ADDITIONAL_ARGS: ""
BATCH_SIZE: 4
- TEST_NAME: 1DP8FSDP1TP1PP_TE
PARALLEL_CONFIG: [1, 1, 8, 1]
BATCH_SIZE: 4
ADDITIONAL_ARGS: ""
- TEST_NAME: 4DP1FSDP2TP1PP_TE
PARALLEL_CONFIG: [1, 4, 1, 2]
BATCH_SIZE: 4
ADDITIONAL_ARGS: ""
- TEST_NAME: 16DP1FSDP1TP1PP_TE
PARALLEL_CONFIG: [1, 16, 1, 1]
BATCH_SIZE: 4
ADDITIONAL_ARGS: ""
- TEST_NAME: 5B_fused_attn_1
PARALLEL_CONFIG: [1, 1, 8, 1]
BATCH_SIZE: 2
ADDITIONAL_ARGS: "--run-5b --enable-fused-attn"
- TEST_NAME: 5B_fused_attn_0
PARALLEL_CONFIG: [1, 1, 8, 1]
BATCH_SIZE: 2
ADDITIONAL_ARGS: "--run-5b"
fail-fast: false

runs-on: ubuntu-22.04
@@ -249,7 +272,7 @@ jobs:
run: |
cd $GITHUB_WORKSPACE
IMAGE="$(echo ${{inputs.PAX_IMAGE}} | sed 's/\//#/')"
TEST_CASE_NAME=${{ matrix.PARALLEL_CONFIG[1] }}DP${{ matrix.PARALLEL_CONFIG[2] }}FSDP${{ matrix.PARALLEL_CONFIG[3] }}TP${{ matrix.PARALLEL_CONFIG[0] }}PP_TE
TEST_CASE_NAME=${{ matrix.TEST_NAME }}
TOTAL_TASKS=$((${{ matrix.PARALLEL_CONFIG[0] }} * ${{ matrix.PARALLEL_CONFIG[1] }} * ${{ matrix.PARALLEL_CONFIG[2] }} * ${{ matrix.PARALLEL_CONFIG[3] }}))
MAX_GPUS_PER_NODE=8
NODES=$(((TOTAL_TASKS+MAX_GPUS_PER_NODE-1)/MAX_GPUS_PER_NODE))
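
For reference, PARALLEL_CONFIG is ordered [PP, DP, FSDP, TP] (matching the --pipeline-parallel/--data-parallel/--fsdp/--tensor-parallel flags passed below), and the node count is a ceiling division over 8 GPUs per node. A standalone sketch of the same arithmetic for the 16DP1FSDP1TP1PP_TE entry:

TOTAL_TASKS=$((1 * 16 * 1 * 1))                                        # 16 GPUs in total
MAX_GPUS_PER_NODE=8
NODES=$(((TOTAL_TASKS + MAX_GPUS_PER_NODE - 1) / MAX_GPUS_PER_NODE))   # ceil(16 / 8) = 2
echo "${NODES}"                                                        # -> 2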
@@ -267,7 +290,7 @@ jobs:
shell: bash -O expand_aliases -x -e {0}
run: |
alias sshx='ssh -o "ServerAliveInterval 7" ${{ secrets.CLUSTER_LOGIN_USER }}@${{ vars.HOSTNAME_SLURM_LOGIN }}'
sshx "date && hostname && sinfo"
sshx "date && hostname && sinfo"
sshx mkdir -p ${{ steps.meta.outputs.MODEL_PATH }}
JOB=$(sshx sbatch --parsable << EOF
#!/bin/bash
@@ -287,15 +310,16 @@
test-pax.sh \
--output /output/${{ steps.meta.outputs.TEST_CASE_NAME }} \
--dtype bfloat16 \
--batch-per-gpu 4 \
--batch-per-gpu ${{ matrix.BATCH_SIZE }} \
--steps 300 \
--pipeline-parallel ${{ matrix.PARALLEL_CONFIG[0] }} \
--data-parallel ${{ matrix.PARALLEL_CONFIG[1] }} \
--fsdp ${{ matrix.PARALLEL_CONFIG[2] }} \
--tensor-parallel ${{ matrix.PARALLEL_CONFIG[3] }} \
--nodes ${{ steps.meta.outputs.NODES }} \
--enable-te \
--additional-args --fdl.PACKED_INPUT=False \
--additional-args "--fdl.PACKED_INPUT=False" \
${{ matrix.ADDITIONAL_ARGS }} \
$([[ ${{ steps.meta.outputs.TOTAL_TASKS }} > 1 ]] && echo --multiprocess)
EOF
)
@@ -834,7 +858,7 @@ jobs:
run: |
cd $GITHUB_WORKSPACE
alias sshx='ssh -o "ServerAliveInterval 7" ${{ secrets.CLUSTER_LOGIN_USER }}@${{ vars.HOSTNAME_SLURM_LOGIN }}'
sshx "date && hostname && sinfo"
sshx "date && hostname && sinfo"
sshx mkdir -p ${{ steps.meta.outputs.MODEL_PATH }}
JOB=$(sshx sbatch --parsable << EOF
#!/bin/bash
@@ -1021,7 +1045,7 @@ jobs:
ENDPOINT_FILENAME: 'rosetta-pax-test-status.json'
PUBLISH: false
SCRIPT: |
EXIT_STATUSES="rosetta-pax-*DP*FSDP*TP*PP*/*-status.json rosetta-pax-${GITHUB_RUN_ID}-*DP_TE_dropout/*-status.json"
EXIT_STATUSES="rosetta-pax-*/*-status.json"
PASSED_TESTS=$(jq -r '. | select ((.state == "COMPLETED") and (.exitcode == "0")) | .state' $EXIT_STATUSES | wc -l)
FAILED_TESTS=$(jq -r '. | select ((.state != "COMPLETED") or (.exitcode != "0")) | .state' $EXIT_STATUSES | wc -l)
TOTAL_TESTS=$(ls $EXIT_STATUSES | wc -l)
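
The aggregation counts any *-status.json whose Slurm state is COMPLETED with a zero exit code as a pass; everything else is a failure. A minimal sketch of the shape the jq filters expect (only .state and .exitcode are taken from the filters above; the file name is hypothetical):

mkdir -p rosetta-pax-example
cat > rosetta-pax-example/0-status.json <<'JSONEOF'
{ "state": "COMPLETED", "exitcode": "0" }
JSONEOF
jq -r '. | select ((.state == "COMPLETED") and (.exitcode == "0")) | .state' \
  rosetta-pax-example/0-status.json | wc -l   # -> 1 (counted as a passed test)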
1 change: 1 addition & 0 deletions .github/workflows/ci.yaml
@@ -199,5 +199,6 @@ jobs:
if: "!cancelled()"
uses: ./.github/workflows/_finalize.yaml
with:
BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
PUBLISH_BADGE: ${{ needs.metadata.outputs.PUBLISH == 'true' }}
secrets: inherit
8 changes: 4 additions & 4 deletions README.md
@@ -43,8 +43,8 @@
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-jax-unit-test-V100.json&logo=nvidia&label=V100">
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-jax-unit-test-A100.json&logo=nvidia&label=A100">
<br>
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fte-unit-test-status.json&logo=nvidia&label=TE%20Unit">
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fte-multi-gpu-test-status.json&logo=nvidia&label=TE%20Multi%20GPU">
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-te-unit-test.json&logo=nvidia&label=TE%20Unit">
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-te-2g-test.json&logo=nvidia&label=TE%20Multi%20GPU">
</td>
</tr>
<tr>
@@ -92,7 +92,7 @@
<td>rosetta</td>
<td>
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-t5x-amd64.json&logo=docker&label=amd64">
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-t5x-arm64.json&logo=docker&label=arm64">
<!-- <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-t5x-arm64.json&logo=docker&label=arm64"> -->
</td>
<td>
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Frosetta-t5x-overall-test-status.json&logo=nvidia">
@@ -115,7 +115,7 @@
<td>rosetta</td>
<td>
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-pax-amd64.json&logo=docker&label=amd64">
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-pax-arm64.json&logo=docker&label=arm64">
<!-- <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-pax-arm64.json&logo=docker&label=arm64"> -->
</td>
<td>
<img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Frosetta-pax-overall-test-status.json&logo=nvidia">
