diff --git a/scripts/train/benchmarking/submit_benchmarks.py b/scripts/train/benchmarking/submit_benchmarks.py index 2c02223409..14b72f3d74 100644 --- a/scripts/train/benchmarking/submit_benchmarks.py +++ b/scripts/train/benchmarking/submit_benchmarks.py @@ -412,7 +412,7 @@ def run_config(config: Tuple[str, int, int, str, str, int, str], command += """pip install -U git+https://github.com/mvpatel2000/composer.git@784f50be7fa8617ed562704c0207316ca2284e71 pip uninstall torch==2.0.1 --yes pip install --no-cache-dir --pre --index-url https://download.pytorch.org/whl/nightly/cu121 torch==2.1.0.dev20230821+cu121""" - if gpu_type == 'h100_80gb': # Required for flash-attn and FP8 training + if gpu_type == 'h100_80gb' and 'fp8' in precision: # Required for flash-attn and FP8 training command += f""" pip install flash-attn==1.0.7 --no-build-isolation pip install git+https://github.com/NVIDIA/TransformerEngine.git@v0.10