diff --git a/install.py b/install.py index cb8d426..425c21e 100644 --- a/install.py +++ b/install.py @@ -113,17 +113,16 @@ def setup_hip(args: argparse.Namespace): # checkout submodules checkout_submodules(REPO_PATH) # install submodules + if args.fa3 or args.all: + logger.info("[tritonbench] installing fa3...") + from tools.flash_attn.install import install_fa3 + install_fa3() if args.fbgemm or args.all: logger.info("[tritonbench] installing FBGEMM...") install_fbgemm() if args.fa2 or args.all: logger.info("[tritonbench] installing fa2 from source...") install_fa2(compile=True) - if args.fa3 or args.all: - logger.info("[tritonbench] installing fa3...") - from tools.flash_attn.install import install_fa3 - - install_fa3() if args.colfax: logger.info("[tritonbench] installing colfax cutlass-kernels...") from tools.cutlass_kernels.install import install_colfax_cutlass