diff --git a/.github/workflows/_linux-test-h100.yml b/.github/workflows/_linux-test-h100.yml
index 614b7c63..d6dcf18e 100644
--- a/.github/workflows/_linux-test-h100.yml
+++ b/.github/workflows/_linux-test-h100.yml
@@ -27,6 +27,14 @@ jobs:
           sudo nvidia-smi -pm 1
           sudo ldconfig
           nvidia-smi
+      - name: Install Tritonbench
+        run: |
+          # TODO: remove this when the new docker image rolls out
+          mkdir -p /workspace/tritonbench/.data
+          # speed up install and skip compilation
+          ln -s /workspace/tritonbench/.data .
+          . "${SETUP_SCRIPT}"
+          python install.py --colfax --tk --hstu
       - name: Test Tritonbench operators on H100 GPU
         run: |
           bash ./.ci/tritonbench/test-gpu.sh
diff --git a/.github/workflows/docker-rocm.yaml b/.github/workflows/docker-rocm.yaml
index 17a3f03f..690ec8bf 100644
--- a/.github/workflows/docker-rocm.yaml
+++ b/.github/workflows/docker-rocm.yaml
@@ -48,7 +48,7 @@ jobs:
         if: github.event_name != 'pull_request'
         run: |
           # Extract pytorch version from the docker
-          PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
+          PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:rocm-latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
           export DOCKER_TAG=$(awk '{match($0, /dev[0-9]+/, arr); print arr[0]}' <<< "${PYTORCH_VERSION}")
           docker push ghcr.io/pytorch-labs/tritonbench:rocm-${DOCKER_TAG}
           docker push ghcr.io/pytorch-labs/tritonbench:rocm-latest
diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py
index b9c0025d..575cc703 100644
--- a/test/test_gpu/main.py
+++ b/test/test_gpu/main.py
@@ -18,11 +18,17 @@
     fbcode_skip_file_path = "fb/skip_tests_h100_fbcode.yaml"
     SKIP_FILE = importlib.resources.files(__package__).joinpath(fbcode_skip_file_path)
 else:
+    SKIP_FILE_NAME = "skip_tests_h100_pytorch.yaml"
+    try:
+        # test if this is the Triton main branch
+        import triton.tools.experimental_descriptor  # noqa: F401
+
+        SKIP_FILE_NAME = "skip_tests_h100_triton_main.yaml"
+    except ModuleNotFoundError:
+        pass
     import os
 
-    SKIP_FILE = os.path.abspath(
-        os.path.join(os.path.dirname(__file__), "skip_tests_h100_pytorch.yaml")
-    )
+    SKIP_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), SKIP_FILE_NAME))
 
 with open(SKIP_FILE, "r") as f:
     skip_tests = yaml.safe_load(f)
diff --git a/test/test_gpu/skip_tests_h100_triton_main.yaml b/test/test_gpu/skip_tests_h100_triton_main.yaml
new file mode 100644
index 00000000..4727832b
--- /dev/null
+++ b/test/test_gpu/skip_tests_h100_triton_main.yaml
@@ -0,0 +1,44 @@
+# Tests we skip on triton-main + OSS CI
+# triton-main is the latest development branch of OSS Triton
+# We need to skip kernels that are broken on triton-main
+# Usage:
+#   op-name: to skip an entire operator
+#   op-name:\n\t- impl-name to skip an impl
+flash_attention:
+  # ThunderKittens cannot handle the default input shapes
+  - tk
+  # FIXME: triton_tutorial_* kernels are broken
+  - triton_tutorial_flash_v2
+  - triton_tutorial_flash_v2_opt
+  - triton_tutorial_flash_v2_tma
+  - triton_tutorial_flash_v2_ws
+  - triton_tutorial_flash_v2_tma_ws
+fp8_attention:
+  # fb-only kernel
+  - colfax_fmha
+  # FIXME: the triton_flash_v2 kernel is broken
+  - triton_flash_v2
+# fb-only kernels
+fp8_fused_quant_gemm_rowwise:
+fp8_gemm:
+  # FIXME: triton_*_persistent kernels are broken
+  - triton_persistent_fp8_gemm
+  - triton_tma_persistent_fp8_gemm
+gemm:
+  # out of shared memory
+  - triton_tma_persistent_matmul
+  # out of shared memory
+  - triton_tma_persistent_cached_matmul
+  # internal-only kernels
+  - hstu_triton_matmul
+  - colfax_cutlass_matmul
+# jagged tests are slow, so disable them in OSS
+jagged_layer_norm:
+jagged_mean:
+jagged_softmax:
+jagged_sum:
+# FIXME: ragged attention aborts (core dump) on Triton main
+ragged_attention:
+test_op:
+fwd_only_ops:
+  - flash_attention
diff --git a/tools/cutlass_kernels/install.py b/tools/cutlass_kernels/install.py
index 9ea9d4a7..d7c50c08 100644
--- a/tools/cutlass_kernels/install.py
+++ b/tools/cutlass_kernels/install.py
@@ -82,7 +82,7 @@ def test_colfax_cutlass(colfax_cutlass_lib: str):
 
 def install_colfax_cutlass():
     # compile colfax_cutlass kernels
-    output_dir = COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath(".data")
+    output_dir = REPO_PATH.joinpath(".data", "cutlass_kernels")
     output_dir.mkdir(parents=True, exist_ok=True)
     cmd = ["nvcc"]
     cmd.extend(COMPILER_FLAGS)
diff --git a/tools/tk/install.py b/tools/tk/install.py
index 0c547699..b3d3c03e 100644
--- a/tools/tk/install.py
+++ b/tools/tk/install.py
@@ -92,7 +92,7 @@ def test_tk_attn_h100_fwd(tk_lib):
 
 def install_tk():
     # compile thunderkitten kernels
-    output_dir = TRITONBENCH_TK_PATH.joinpath(".data")
+    output_dir = REPO_PATH.joinpath(".data", "tk")
     output_dir.mkdir(parents=True, exist_ok=True)
     cmd = ["nvcc"]
     cmd.extend(COMPILER_FLAGS)
diff --git a/tritonbench/operators/int4_gemm/int4_gemm.py b/tritonbench/operators/int4_gemm/int4_gemm.py
index 614be6ef..8b611451 100644
--- a/tritonbench/operators/int4_gemm/int4_gemm.py
+++ b/tritonbench/operators/int4_gemm/int4_gemm.py
@@ -12,8 +12,8 @@
 from typing import Any, List, Optional
 
 import torch
+import triton
 import triton.language as tl
-import triton.ops
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
diff --git a/tritonbench/utils/env_utils.py b/tritonbench/utils/env_utils.py
index fbfca74c..f3debfcf 100644
--- a/tritonbench/utils/env_utils.py
+++ b/tritonbench/utils/env_utils.py
@@ -27,13 +27,13 @@ def set_env():
 
     # set cutlass dir
     # by default we use the cutlass version built with pytorch
-    import torch
+    import torch._inductor.config as inductor_config
 
-    current_cutlass_dir = torch._inductor.config.cuda.cutlass_dir
+    current_cutlass_dir = inductor_config.cuda.cutlass_dir
     if not os.path.exists(current_cutlass_dir):
         tb_cutlass_dir = REPO_PATH.joinpath("submodules", "cutlass")
         if tb_cutlass_dir.is_dir():
-            torch._inductor.config.cuda.cutlass_dir = str(tb_cutlass_dir)
+            inductor_config.cuda.cutlass_dir = str(tb_cutlass_dir)
 
 
 def set_random_seed():
diff --git a/tritonbench/utils/loader.py b/tritonbench/utils/loader.py
index 2c505bef..06d7bb9f 100644
--- a/tritonbench/utils/loader.py
+++ b/tritonbench/utils/loader.py
@@ -5,5 +5,5 @@ def load_library(library_path: str):
     import torch
 
     prefix, _delimiter, so_file = library_path.partition("/")
-    so_full_path = REPO_PATH.joinpath("tools", prefix, ".data", so_file).resolve()
+    so_full_path = REPO_PATH.joinpath(".data", prefix, so_file).resolve()
    torch.ops.load_library(str(so_full_path))
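
Note on the skip-file format introduced above: `test/test_gpu/main.py` loads the YAML with `yaml.safe_load`, so a bare `op-name:` key parses to `None` (skip the whole operator) while a key with a list skips only the listed impls. The diff does not show the code that consumes `skip_tests`, so the sketch below is only an illustration of that format; the `should_skip` helper is hypothetical and not part of the repo.

```python
import yaml

def should_skip(skip_tests: dict, op: str, impl: str) -> bool:
    # Hypothetical helper mirroring the documented format:
    # a bare `op-name:` entry (value None) skips every impl of the operator;
    # an entry with a list skips only the listed impls.
    if op not in skip_tests:
        return False
    impls = skip_tests[op]
    return impls is None or impl in impls

with open("test/test_gpu/skip_tests_h100_triton_main.yaml") as f:
    skip_tests = yaml.safe_load(f)

print(should_skip(skip_tests, "jagged_sum", "triton_jagged_sum"))  # True: whole op skipped
print(should_skip(skip_tests, "gemm", "hstu_triton_matmul"))       # True: impl listed
print(should_skip(skip_tests, "gemm", "triton_matmul"))            # False: impl not listed
```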