From a5777adf6d33ed31f124fbf73bf89808c82d565f Mon Sep 17 00:00:00 2001
From: ashors1 <71393111+ashors1@users.noreply.github.com>
Date: Tue, 20 Feb 2024 21:58:31 -0800
Subject: [PATCH 1/5] Add 5b pax test with fused attention (#506)

---
 .github/container/test-pax.sh            | 33 ++++++++++++----
 .github/workflows/_test_pax_rosetta.yaml | 50 ++++++++++++++++++------
 2 files changed, 63 insertions(+), 20 deletions(-)

diff --git a/.github/container/test-pax.sh b/.github/container/test-pax.sh
index 050e9f286..8d825e272 100755
--- a/.github/container/test-pax.sh
+++ b/.github/container/test-pax.sh
@@ -17,6 +17,8 @@ usage() {
     echo "  --dtype              Data type, defaults to bfloat16."
     echo "  --enable-te          If set, will run with env var ENABLE_TE=1."
     echo "  --enable-dropout     If set, will set DROPOUT_PROB to 0.1."
+    echo "  --enable-fused-attn  Whether to test fused attention through TE."
+    echo "  --run-5b             Whether to run GPT5B rather than the default 126M."
     echo "  --evaluate           Whether to test evaluation rather than training."
     echo "  -s, --steps          Number of steps to run, defaults to 500."
     echo "  --multiprocess       Enable the multiprocess GPU mode."
@@ -30,7 +32,7 @@ usage() {
     exit $1
 }
 
-args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
+args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,enable-fused-attn,run-5b,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
 if [[ $? -ne 0 ]]; then
     exit 1
 fi
@@ -48,8 +50,10 @@
 TP=1
 PP=1
 NODES=1
 ENABLE_TE=0
+NVTE_FUSED_ATTN=0
 DROPOUT=0
 EVALUATE=0
+RUN_5B=0
 ADDITIONAL_ARGS=""
 
 eval set -- "$args"
@@ -75,6 +79,14 @@ while [ : ]; do
             DROPOUT='0.1'
             shift 1
             ;;
+        --enable-fused-attn)
+            NVTE_FUSED_ATTN=1
+            shift 1
+            ;;
+        --run-5b)
+            RUN_5B=1
+            shift 1
+            ;;
         --evaluate)
             EVALUATE=1
             shift 1
             ;;
@@ -136,6 +148,7 @@
 print_var NGPUS
 print_var OUTPUT
 print_var MULTIPROCESS
 print_var ENABLE_TE
+print_var NVTE_FUSED_ATTN
 print_var EVALUATE
 print_var DROPOUT
 print_var DP
@@ -162,7 +175,6 @@
 tp = ${TP}
 pp = ${PP}
 num_gpus = ${NGPUS}
 percore_batch_size = ${BATCH_PER_GPU}
-steps = ${STEPS}
 dtype = "${DTYPE}"
 dropout = float(${DROPOUT})
@@ -280,7 +292,6 @@ if pp > 1:
         NUM_STAGES = pp
         PERCORE_BATCH_SIZE = percore_batch_size
         FPROP_DTYPE = dtype
-        MAX_STEPS = steps
 
     def task(self):
         task_p = super().task()
@@ -296,7 +307,6 @@ else:
         DCN_MESH_SHAPE = [dcn_dp, dcn_fsdp, 1]
         PERCORE_BATCH_SIZE = percore_batch_size
         FPROP_DTYPE = dtype
-        MAX_STEPS = steps
         DROPOUT_PROB = dropout
@@ -327,10 +337,18 @@
 set -ex
 export XLA_PYTHON_CLIENT_MEM_FRACTION=${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.65}
 export ENABLE_TE=$ENABLE_TE
+export NVTE_FUSED_ATTN=$NVTE_FUSED_ATTN
+
+CONFIG=ci_configs.Synthetic126MCI
+if [[ ${RUN_5B} -ne 0 ]]; then
+    CONFIG=paxml.contrib.gpu.scripts_gpu.configs.Synthetic5B
+    ADDITIONAL_ARGS="--fdl.DCN_MESH_SHAPE=[1,${NODES},1] --fdl.ICI_MESH_SHAPE=[${DP},${FSDP},${TP}] ${ADDITIONAL_ARGS} --fdl.PERCORE_BATCH_SIZE=${BATCH_PER_GPU}"
+fi
+
 if [[ ${EVALUATE} -ne 0 ]]; then
     ## train for 0 steps to generate an initial checkpoint
     python -m paxml.main \
-        --fdl_config=ci_configs.Synthetic126MCI \
+        --fdl_config=${CONFIG} \
         --fdl.MAX_STEPS=0 \
         --job_log_dir=${OUTPUT} \
         --alsologtostderr \
...
     ## restore from initial checkpoint for eval
     python -m paxml.main \
-        --fdl_config=ci_configs.Synthetic126MCI \
+        --fdl_config=${CONFIG} \
         --job_log_dir=${OUTPUT} \
         --mode='eval' \
         --alsologtostderr \
@@ -350,9 +368,10 @@ if [[ ${EVALUATE} -ne 0 ]]; then
     rm -rf ${OUTPUT}/checkpoints
 else
     python -m paxml.main \
-        --fdl_config=ci_configs.Synthetic126MCI \
+        --fdl_config=${CONFIG} \
         --job_log_dir=${OUTPUT} \
         --alsologtostderr \
+        --fdl.MAX_STEPS=${STEPS} \
         --enable_checkpoint_saving=False \
         $ADDITIONAL_ARGS \
         $([[ $MULTIPROCESS != 0 ]] && echo --multiprocess_gpu)
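
The script changes above make the new 5B case reproducible outside CI. A minimal sketch of a hand-run invocation on a single 8-GPU node (the output path is hypothetical, and the flag values mirror the 5B_fused_attn_1 matrix entry added in the workflow diff below):

    # Hypothetical local run of the updated script: GPT-5B config,
    # TE enabled, cuDNN fused attention on, batch size 2 per GPU.
    bash .github/container/test-pax.sh \
        --output /tmp/5B_fused_attn_1 \
        --dtype bfloat16 \
        --batch-per-gpu 2 \
        --steps 300 \
        --fsdp 8 \
        --nodes 1 \
        --enable-te \
        --enable-fused-attn \
        --run-5b

Internally this selects paxml.contrib.gpu.scripts_gpu.configs.Synthetic5B and exports NVTE_FUSED_ATTN=1, which tells Transformer Engine to use its fused attention backend.
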
diff --git a/.github/workflows/_test_pax_rosetta.yaml b/.github/workflows/_test_pax_rosetta.yaml
index 4bea7c36c..0f41d2c84 100644
--- a/.github/workflows/_test_pax_rosetta.yaml
+++ b/.github/workflows/_test_pax_rosetta.yaml
@@ -81,7 +81,7 @@ jobs:
       run: |
         cd $GITHUB_WORKSPACE
         alias sshx='ssh -o "ServerAliveInterval 7" ${{ secrets.CLUSTER_LOGIN_USER }}@${{ vars.HOSTNAME_SLURM_LOGIN }}'
-        sshx "date && hostname && sinfo"  
+        sshx "date && hostname && sinfo"
         sshx mkdir -p ${{ steps.meta.outputs.MODEL_PATH }}
         JOB=$(sshx sbatch --parsable << EOF
         #!/bin/bash
@@ -210,12 +210,35 @@ jobs:
   rosetta-pax-multi-node-te:
     strategy:
       matrix:
-        PARALLEL_CONFIG:
-        - [1, 1, 1, 1]
-        - [1, 8, 1, 1]
-        - [1, 1, 8, 1]
-        - [1, 4, 1, 2]
-        - [1, 16, 1, 1]
+        include:
+          - TEST_NAME: 1DP1FSDP1TP1PP_TE
+            PARALLEL_CONFIG: [1, 1, 1, 1]
+            BATCH_SIZE: 4
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: 8DP1FSDP1TP1PP_TE
+            PARALLEL_CONFIG: [1, 8, 1, 1]
+            ADDITIONAL_ARGS: ""
+            BATCH_SIZE: 4
+          - TEST_NAME: 1DP8FSDP1TP1PP_TE
+            PARALLEL_CONFIG: [1, 1, 8, 1]
+            BATCH_SIZE: 4
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: 4DP1FSDP2TP1PP_TE
+            PARALLEL_CONFIG: [1, 4, 1, 2]
+            BATCH_SIZE: 4
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: 16DP1FSDP1TP1PP_TE
+            PARALLEL_CONFIG: [1, 16, 1, 1]
+            BATCH_SIZE: 4
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: 5B_fused_attn_1
+            PARALLEL_CONFIG: [1, 1, 8, 1]
+            BATCH_SIZE: 2
+            ADDITIONAL_ARGS: "--run-5b --enable-fused-attn"
+          - TEST_NAME: 5B_fused_attn_0
+            PARALLEL_CONFIG: [1, 1, 8, 1]
+            BATCH_SIZE: 2
+            ADDITIONAL_ARGS: "--run-5b"
       fail-fast: false
     runs-on: ubuntu-22.04
@@ -249,7 +272,7 @@ jobs:
       run: |
         cd $GITHUB_WORKSPACE
         IMAGE="$(echo ${{inputs.PAX_IMAGE}} | sed 's/\//#/')"
-        TEST_CASE_NAME=${{ matrix.PARALLEL_CONFIG[1] }}DP${{ matrix.PARALLEL_CONFIG[2] }}FSDP${{ matrix.PARALLEL_CONFIG[3] }}TP${{ matrix.PARALLEL_CONFIG[0] }}PP_TE
+        TEST_CASE_NAME=${{ matrix.TEST_NAME }}
         TOTAL_TASKS=$((${{ matrix.PARALLEL_CONFIG[0] }} * ${{ matrix.PARALLEL_CONFIG[1] }} * ${{ matrix.PARALLEL_CONFIG[2] }} * ${{ matrix.PARALLEL_CONFIG[3] }}))
         MAX_GPUS_PER_NODE=8
         NODES=$(((TOTAL_TASKS+MAX_GPUS_PER_NODE-1)/MAX_GPUS_PER_NODE))
@@ -267,7 +290,7 @@ jobs:
      shell: bash -O expand_aliases -x -e {0}
      run: |
        alias sshx='ssh -o "ServerAliveInterval 7" ${{ secrets.CLUSTER_LOGIN_USER }}@${{ vars.HOSTNAME_SLURM_LOGIN }}'
-        sshx "date && hostname && sinfo"  
+        sshx "date && hostname && sinfo"
        sshx mkdir -p ${{ steps.meta.outputs.MODEL_PATH }}
        JOB=$(sshx sbatch --parsable << EOF
        #!/bin/bash
@@ -287,7 +310,7 @@
         test-pax.sh \
             --output /output/${{ steps.meta.outputs.TEST_CASE_NAME }} \
             --dtype bfloat16 \
-            --batch-per-gpu 4 \
+            --batch-per-gpu ${{ matrix.BATCH_SIZE }} \
             --steps 300 \
             --pipeline-parallel ${{ matrix.PARALLEL_CONFIG[0] }} \
             --data-parallel ${{ matrix.PARALLEL_CONFIG[1] }} \
             --fsdp ${{ matrix.PARALLEL_CONFIG[2] }} \
             --tensor-parallel ${{ matrix.PARALLEL_CONFIG[3] }} \
             --nodes ${{ steps.meta.outputs.NODES }} \
             --enable-te \
-            --additional-args --fdl.PACKED_INPUT=False \
+            --additional-args "--fdl.PACKED_INPUT=False" \
+            ${{ matrix.ADDITIONAL_ARGS }} \
             $([[ ${{ steps.meta.outputs.TOTAL_TASKS }} > 1 ]] && echo --multiprocess)
         EOF
         )
@@ -834,7 +858,7 @@ jobs:
      run: |
        cd $GITHUB_WORKSPACE
        alias sshx='ssh -o "ServerAliveInterval 7" ${{ secrets.CLUSTER_LOGIN_USER }}@${{ vars.HOSTNAME_SLURM_LOGIN }}'
-        sshx "date && hostname && sinfo"  
+        sshx "date && hostname && sinfo"
        sshx mkdir -p ${{ steps.meta.outputs.MODEL_PATH }}
        JOB=$(sshx sbatch --parsable << EOF
        #!/bin/bash
@@ -1021,7 +1045,7 @@ jobs:
          ENDPOINT_FILENAME: 'rosetta-pax-test-status.json'
          PUBLISH: false
          SCRIPT: |
-            EXIT_STATUSES="rosetta-pax-*DP*FSDP*TP*PP*/*-status.json rosetta-pax-${GITHUB_RUN_ID}-*DP_TE_dropout/*-status.json"
+            EXIT_STATUSES="rosetta-pax-*/*-status.json"
            PASSED_TESTS=$(jq -r '. | select ((.state == "COMPLETED") and (.exitcode == "0")) | .state' $EXIT_STATUSES | wc -l)
            FAILED_TESTS=$(jq -r '. | select ((.state != "COMPLETED") or (.exitcode != "0")) | .state' $EXIT_STATUSES | wc -l)
            TOTAL_TESTS=$(ls $EXIT_STATUSES | wc -l)
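
For reference, PARALLEL_CONFIG is ordered [PP, DP, FSDP, TP], and the node count comes from ceiling division over the 8 GPUs per node. A small sketch of the same arithmetic the workflow runs, using the values from the new 5B entries:

    # PARALLEL_CONFIG=[PP, DP, FSDP, TP]; the 5B cases use [1, 1, 8, 1].
    PP=1 DP=1 FSDP=8 TP=1
    MAX_GPUS_PER_NODE=8
    TOTAL_TASKS=$((PP * DP * FSDP * TP))                                   # 8
    NODES=$(((TOTAL_TASKS + MAX_GPUS_PER_NODE - 1) / MAX_GPUS_PER_NODE))   # ceil(8/8) = 1
    echo "tasks=${TOTAL_TASKS} nodes=${NODES}"

So both 5B cases fit on a single 8-GPU node, while the 16DP case spans two.
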
From f450e3c70eec3e7dcf031e8db0413959796a45cd Mon Sep 17 00:00:00 2001
From: "Yu-Hang \"Maxin\" Tang"
Date: Tue, 20 Feb 2024 23:00:35 -0800
Subject: [PATCH 2/5] update badges (#576)

---
 README.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 00dd78673..35be3f93b 100644
--- a/README.md
+++ b/README.md
@@ -43,8 +43,8 @@
 [two badge rows updated here; the <a>/<img> HTML was lost in extraction]
@@ -92,7 +92,7 @@ rosetta
 [one rosetta badge line updated; HTML lost in extraction]
@@ -115,7 +115,7 @@ rosetta
 [one rosetta badge line updated; HTML lost in extraction]

From 249ea8ab5dfa4917094adb4686b0c52afb60a69b Mon Sep 17 00:00:00 2001
From: Olli Lupton
Date: Wed, 21 Feb 2024 08:05:22 +0100
Subject: [PATCH 3/5] Apply XLA's patches to XLA's tag of Triton (#570)

XLA maintains a list of patches in third_party/triton/workspace.bzl that
should be applied on top of its specific tag of Triton. Previously we did
not apply these when building the Pallas containers. This PR applies them.
---
 .github/container/Dockerfile.pallas              |  4 +++-
 .github/container/bump-openxla-triton.sh         | 18 ++++++++++++++----
 .../patches/openxla-triton/0_cl602723852.patch   | 15 +++++++++++++++
 .../patches/openxla-triton/1_cl602997103.patch   | 14 ++++++++++++++
 4 files changed, 46 insertions(+), 5 deletions(-)
 create mode 100644 .github/container/patches/openxla-triton/0_cl602723852.patch
 create mode 100644 .github/container/patches/openxla-triton/1_cl602997103.patch

diff --git a/.github/container/Dockerfile.pallas b/.github/container/Dockerfile.pallas
index 9f274703e..3666ea847 100644
--- a/.github/container/Dockerfile.pallas
+++ b/.github/container/Dockerfile.pallas
@@ -10,8 +10,10 @@ FROM ${BASE_IMAGE} as builder
 ARG SRC_PATH_TRITON
 
 # bump-openxla-triton.sh ensures that the commit of openxla-triton referenced
-# in the manifest file is consistent with the commit of xla
+# in the manifest file is consistent with the commit of xla, and that any extra
+# patches are available under patches/openxla-triton
 RUN get-source.sh -l openxla-triton -m ${MANIFEST_FILE}
+RUN cd /opt/openxla-triton && for patch in /opt/manifest.d/patches/openxla-triton/*.patch; do patch -p1 < "${patch}"; done && git diff
 
 RUN <<"EOF" bash -ex
 mkdir -p "${SRC_PATH_TRITON}/dist"
diff --git a/.github/container/bump-openxla-triton.sh b/.github/container/bump-openxla-triton.sh
index 0c0257eee..cec0930cc 100755
--- a/.github/container/bump-openxla-triton.sh
+++ b/.github/container/bump-openxla-triton.sh
@@ -6,7 +6,7 @@ usage() {
 [hunk garbled in extraction; the remnant "cat < /dev/null && pwd )" and the later ${SCRIPT_DIR} references suggest it adds the standard SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) line]
 xla_url=$(yq e ".xla.url" $MANIFEST)
 xla_tracking_ref=$(yq e ".xla.tracking_ref" $MANIFEST)
@@ -58,8 +59,17 @@
 xla_commit=$(yq e ".xla.latest_verified_commit" $MANIFEST)
 git clone --branch "${xla_tracking_ref}" --single-branch "${xla_url}" "${xla_repo}"
 (cd "${xla_repo}" && git checkout "${xla_commit}")
 # Extract the openxla/triton tag used by XLA. Even though it is called
-# TRITON_COMMIT it is a tag. In principle we should also account for the
-# patches in this .bzl file, but skip that for now.
-openxla_triton_tag=$(sed -n -e 's#\s\+TRITON_COMMIT = "\(cl[0-9]\+\)"#\1#p' "${xla_repo}/third_party/triton/workspace.bzl")
+# TRITON_COMMIT it is a tag.
+workspace_file="${xla_repo}/third_party/triton/workspace.bzl"
+openxla_triton_tag=$(sed -n -e 's#\s\+TRITON_COMMIT = "\(cl[0-9]\+\)"#\1#p' "${workspace_file}")
+# Extract Triton patch files applied by XLA
+patch_files=$(python3 -c 'import ast, sys; tree = ast.parse(sys.stdin.read()); print(" ".join(elem.value.removeprefix("//third_party/triton:").removesuffix(".patch") for node in ast.walk(tree) if isinstance(node, ast.keyword) and node.arg == "patch_file" for elem in node.value.elts))' < "${workspace_file}")
+i=0
+# Remove old patch files
+rm -vf ${SCRIPT_DIR}/patches/openxla-triton/*.patch
+for patch_file in ${patch_files}; do
+  cp -v "${xla_repo}/third_party/triton/${patch_file}.patch" "${SCRIPT_DIR}/patches/openxla-triton/${i}_${patch_file}.patch"
+  i=$((i+1))
+done
 rm -rf "${xla_repo}"
 yq e ".openxla-triton.latest_verified_commit = \"${openxla_triton_tag}\"" -i $MANIFEST
diff --git a/.github/container/patches/openxla-triton/0_cl602723852.patch b/.github/container/patches/openxla-triton/0_cl602723852.patch
new file mode 100644
index 000000000..850124cbc
--- /dev/null
+++ b/.github/container/patches/openxla-triton/0_cl602723852.patch
@@ -0,0 +1,15 @@
+Remove once b/322980485 is fixed.
+diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py
+--- a/python/triton/backends/__init__.py
++++ b/python/triton/backends/__init__.py
+@@ -46,5 +46,8 @@ def _discover_backends():
+                                _find_concrete_subclasses(driver, DriverBase))
+     return backends
+ 
+-
+-backends = _discover_backends()
++from triton.backends.nvidia.driver import CudaDriver
++from triton.backends.nvidia.compiler import CUDABackend
++backends = {
++    "nvidia": Backend(CUDABackend, CudaDriver)
++}
diff --git a/.github/container/patches/openxla-triton/1_cl602997103.patch b/.github/container/patches/openxla-triton/1_cl602997103.patch
new file mode 100644
index 000000000..6e955e23e
--- /dev/null
+++ b/.github/container/patches/openxla-triton/1_cl602997103.patch
@@ -0,0 +1,14 @@
+==== triton/python/src/ir.cc#3 - /google/src/cloud/joelwee/mlir_da784a25557e29996bd33638d51d569ddf989faf_1706700588/triton/python/src/ir.cc ====
+# action=edit type=text
+--- triton/python/src/ir.cc	2024-01-30 04:43:56.000000000 -0800
++++ triton/python/src/ir.cc	2024-01-31 04:40:21.000000000 -0800
+@@ -271,7 +271,8 @@
+            })
+       .def("get_num_arguments", &mlir::Block::getNumArguments)
+       .def("dump", &mlir::Block::dump)
+-      .def("move_before", &mlir::Block::moveBefore)
++      .def("move_before",
++           [](mlir::Block &self, mlir::Block *tgt) { self.moveBefore(tgt); })
+       .def("insert_before", &mlir::Block::insertBefore)
+       .def("get_parent", &mlir::Block::getParent, ret::reference)
+       .def("merge_block_before",
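
The Python one-liner in bump-openxla-triton.sh avoids fragile regex parsing of the Bazel file by reading it as a Python AST and collecting every patch_file= keyword argument. A standalone sketch of the same extraction against a made-up workspace.bzl (the file content here is illustrative, not XLA's actual file):

    # Write a miniature workspace.bzl with the shape the script expects.
    printf '%s\n' \
      'def repo():' \
      '    tf_http_archive(' \
      '        name = "triton",' \
      '        patch_file = ["//third_party/triton:cl602723852.patch", "//third_party/triton:cl602997103.patch"],' \
      '    )' > /tmp/workspace.bzl
    # Same AST walk as in the script; prints: cl602723852 cl602997103
    python3 -c 'import ast, sys; tree = ast.parse(sys.stdin.read()); print(" ".join(elem.value.removeprefix("//third_party/triton:").removesuffix(".patch") for node in ast.walk(tree) if isinstance(node, ast.keyword) and node.arg == "patch_file" for elem in node.value.elts))' < /tmp/workspace.bzl

Note that str.removeprefix/removesuffix require Python 3.9+, and the numeric prefix added when copying (0_, 1_) preserves XLA's application order.
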
From f050909e817df81bd605873607dafc9e431be5ae Mon Sep 17 00:00:00 2001
From: "Yu-Hang \"Maxin\" Tang"
Date: Tue, 20 Feb 2024 23:09:17 -0800
Subject: [PATCH 4/5] Fix metadata badge format (#577)

---
 .github/workflows/_finalize.yaml | 6 +++++-
 .github/workflows/ci.yaml        | 1 +
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/_finalize.yaml b/.github/workflows/_finalize.yaml
index 1b8c77531..eb2e9d70f 100644
--- a/.github/workflows/_finalize.yaml
+++ b/.github/workflows/_finalize.yaml
@@ -36,7 +36,11 @@ jobs:
          source .github/workflows/scripts/to_json.sh
 
          badge_label='workflow metadata'
-          badge_message="Run ${{ github.run_id }}, ${{ inputs.BUILD_DATE || github.event.created_at }}"
+          if [[ -n "${{ inputs.BUILD_DATE }}" ]]; then
+            badge_message="${{ inputs.BUILD_DATE }}: run #${{ github.run_id }}"
+          else
+            badge_message="run #${{ github.run_id }}"
+          fi
 
          schemaVersion=1 \
          label="${badge_label}" \
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 646af9eed..1b538eed1 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -199,5 +199,6 @@ jobs:
     if: "!cancelled()"
     uses: ./.github/workflows/_finalize.yaml
     with:
+      BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
       PUBLISH_BADGE: ${{ needs.metadata.outputs.PUBLISH == 'true' }}
     secrets: inherit
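
The new conditional yields a stable badge message whether or not a build date is available. A quick local sketch of the two branches (the BUILD_DATE and RUN_ID values are made up; to_json.sh is the repo helper that turns the key=value pairs into shields.io endpoint JSON):

    # Hypothetical stand-ins for the workflow context variables.
    RUN_ID=7993412345
    for BUILD_DATE in "2024-02-21" ""; do
        if [[ -n "${BUILD_DATE}" ]]; then
            badge_message="${BUILD_DATE}: run #${RUN_ID}"
        else
            badge_message="run #${RUN_ID}"
        fi
        echo "${badge_message}"
    done
    # Prints:
    #   2024-02-21: run #7993412345
    #   run #7993412345
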
From 8fdff5c62d10bf27e6ad0990eb34052671de11cf Mon Sep 17 00:00:00 2001
From: Terry Kong
Date: Wed, 21 Feb 2024 22:41:25 -0800
Subject: [PATCH 5/5] Remove chex pin from pax build causing failure due to
 pax using later than 0.1.7 now (#578)

Failing build:
https://github.com/NVIDIA/JAX-Toolbox/actions/runs/7971250761/job/21763211927#step:10:240
---
 .github/container/Dockerfile.pax.arm64 | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.github/container/Dockerfile.pax.arm64 b/.github/container/Dockerfile.pax.arm64
index 3e703fb63..58ac5d33c 100644
--- a/.github/container/Dockerfile.pax.arm64
+++ b/.github/container/Dockerfile.pax.arm64
@@ -97,7 +97,6 @@ RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-
 RUN <<"EOF" bash -ex
 echo "tensorflow==2.13.0" >> /opt/pip-tools.d/requirements-paxml.in
 echo "tensorflow_datasets==4.9.2" >> /opt/pip-tools.d/requirements-paxml.in
-echo "chex==0.1.7" >> /opt/pip-tools.d/requirements-paxml.in
 echo "auditwheel" >> /opt/pip-tools.d/requirements-paxml.in
 get-source.sh -l paxml -m ${MANIFEST_FILE}
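
With the pin removed, pip-tools resolves chex transitively from paxml's own requirements instead of forcing 0.1.7. A quick sanity check inside the built container might look like this (a sketch; any installed chex exposes __version__):

    # Show which chex the resolver picked and what pulled it in.
    pip show chex | grep -E '^(Name|Version|Required-by)'
    python3 -c "import chex; print(chex.__version__)"
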