diff --git a/tools/flash_attn/install.py b/tools/flash_attn/install.py index f06d438..f3b2303 100644 --- a/tools/flash_attn/install.py +++ b/tools/flash_attn/install.py @@ -38,7 +38,7 @@ def install_fa3(): FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper") env = os.environ.copy() # nvcc will spawn cicc process and will cost ~1G memory - env["MAX_JOBS"] = "8" - env["NVCC_THREADS"] = "1" + # env["MAX_JOBS"] = "8" + # env["NVCC_THREADS"] = "1" cmd = [sys.executable, "setup.py", "install"] subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()), env=env)