Skip to content

Commit

Permalink
Fix the nightly docker build (#120)
Browse files Browse the repository at this point in the history
Summary:
After the submodule update, the FA3 CUTLASS kernels now cost much more memory to build and will exhaust the CI machine.

Make the following changes:
1. Build FA3 with MAX_JOBS=4 and NVCC_THREADS=1.
2. Disable xformers build as it also requires FA3, we will try to enable its build later.
3. Also disable colfax and TK build for now - will fix that later.

Pull Request resolved: #120

Test Plan: CI

Reviewed By: adamomainz

Differential Revision: D67524390

Pulled By: xuzhao9

fbshipit-source-id: fc889fae5d51d7d2d974d2996e33e3f31d8db98e
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Dec 20, 2024
1 parent 06c28ed commit 5cc3976
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 14 deletions.
3 changes: 3 additions & 0 deletions docker/tritonbench-nightly.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ RUN cd /workspace/tritonbench && \
# which is from NVIDIA driver
RUN sudo apt update && sudo apt-get install -y libnvidia-compute-550 patchelf patch

# Workaround: installing Ninja from setup.py hits "Failed to decode METADATA with UTF-8" error
RUN . ${SETUP_SCRIPT} && pip install ninja

# Install Tritonbench
RUN cd /workspace/tritonbench && \
bash .ci/tritonbench/install.sh
Expand Down
23 changes: 10 additions & 13 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,14 @@ def install_fa2(compile=False):
if compile:
# compile from source (slow)
FA2_PATH = REPO_PATH.joinpath("submodules", "flash-attention")
cmd = [sys.executable, "setup.py", "install"]
cmd = ["pip", "install", "-e", "."]
subprocess.check_call(cmd, cwd=str(FA2_PATH.resolve()))
else:
# Install the pre-built binary
cmd = ["pip", "install", "flash-attn", "--no-build-isolation"]
subprocess.check_call(cmd)


def install_fa3():
FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper")
cmd = [sys.executable, "setup.py", "install"]
subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()))


def install_liger():
# Liger-kernel has a conflict dependency `triton` with pytorch,
# so we need to install it without dependencies
Expand Down Expand Up @@ -119,32 +113,35 @@ def setup_hip(args: argparse.Namespace):
# checkout submodules
checkout_submodules(REPO_PATH)
# install submodules
if args.fa3 or args.all:
# we need to install fa3 above all other dependencies
logger.info("[tritonbench] installing fa3...")
from tools.flash_attn.install import install_fa3

install_fa3()
if args.fbgemm or args.all:
logger.info("[tritonbench] installing FBGEMM...")
install_fbgemm()
if args.fa2 or args.all:
logger.info("[tritonbench] installing fa2 from source...")
install_fa2(compile=True)
if args.fa3 or args.all:
logger.info("[tritonbench] installing fa3...")
install_fa3()
if args.colfax or args.all:
if args.colfax:
logger.info("[tritonbench] installing colfax cutlass-kernels...")
from tools.cutlass_kernels.install import install_colfax_cutlass

install_colfax_cutlass()
if args.jax or args.all:
logger.info("[tritonbench] installing jax...")
install_jax()
if args.tk or args.all:
if args.tk:
logger.info("[tritonbench] installing thunderkittens...")
from tools.tk.install import install_tk

install_tk()
if args.liger or args.all:
logger.info("[tritonbench] installing liger-kernels...")
install_liger()
if args.xformers or args.all:
if args.xformers:
logger.info("[tritonbench] installing xformers...")
from tools.xformers.install import install_xformers

Expand Down
2 changes: 1 addition & 1 deletion submodules/ThunderKittens
Submodule ThunderKittens updated 383 files
14 changes: 14 additions & 0 deletions tools/flash_attn/hopper.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
diff --git a/hopper/setup.py b/hopper/setup.py
index f9f3cfd..132ce07 100644
--- a/hopper/setup.py
+++ b/hopper/setup.py
@@ -78,7 +78,8 @@ def check_if_cuda_home_none(global_option: str) -> None:


def append_nvcc_threads(nvcc_extra_args):
- return nvcc_extra_args + ["--threads", "4"]
+ nvcc_threads = os.getenv("NVCC_THREADS") or "4"
+ return nvcc_extra_args + ["--threads", nvcc_threads]


cmdclass = {}
48 changes: 48 additions & 0 deletions tools/flash_attn/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
import subprocess
import sys

from pathlib import Path

REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
CUR_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)))


def patch_fa3():
patches = ["hopper.patch"]
for patch_file in patches:
patch_file_path = os.path.join(CUR_DIR, patch_file)
submodule_path = str(
REPO_PATH.joinpath("submodules", "flash-attention").absolute()
)
try:
subprocess.check_output(
[
"patch",
"-p1",
"--forward",
"-i",
patch_file_path,
"-r",
"/tmp/rej",
],
cwd=submodule_path,
)
except subprocess.SubprocessError as e:
output_str = str(e.output)
if "previously applied" in output_str:
return
else:
print(str(output_str))
sys.exit(1)


def install_fa3():
patch_fa3()
FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper")
env = os.environ.copy()
# limit nvcc memory usage on the CI machine
env["MAX_JOBS"] = "8"
env["NVCC_THREADS"] = "1"
cmd = ["pip", "install", "-e", "."]
subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()), env=env)

0 comments on commit 5cc3976

Please sign in to comment.