Skip to content

Commit

Permalink
Fix the PR CI (#72)
Browse files Browse the repository at this point in the history
Summary:
We still need to patch HSTU at runtime, we do not need to patch xformers as it is already installed.

Also, move all compilation artifacts to `$REPO_DIR/.data` so that we do not need to recompile colfax and tk.

Pull Request resolved: #72

Reviewed By: adamomainz

Differential Revision: D66363425

Pulled By: xuzhao9

fbshipit-source-id: e3fe1a49bc89973e12a10dbd3cba0e8f1b9297bd
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Nov 22, 2024
1 parent 17e025d commit e8f5ba4
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 11 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/_linux-test-h100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ jobs:
sudo nvidia-smi -pm 1
sudo ldconfig
nvidia-smi
- name: Install Tritonbench
run: |
# todo: remove this when the new docker rolls out
mkdir -p /workspace/tritonbench/.data
# speedup install and skip compile
ln -s /workspace/tritonbench/.data .
. "${SETUP_SCRIPT}"
python install.py --colfax --tk --hstu
- name: Test Tritonbench operators on H100 GPU
run: |
bash ./.ci/tritonbench/test-gpu.sh
2 changes: 1 addition & 1 deletion .github/workflows/docker-rocm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
if: github.event_name != 'pull_request'
run: |
# Extract pytorch version from the docker
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:rocm-latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
export DOCKER_TAG=$(awk '{match($0, /dev[0-9]+/, arr); print arr[0]}' <<< "${PYTORCH_VERSION}")
docker push ghcr.io/pytorch-labs/tritonbench:rocm-${DOCKER_TAG}
docker push ghcr.io/pytorch-labs/tritonbench:rocm-latest
Expand Down
12 changes: 9 additions & 3 deletions test/test_gpu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@
fbcode_skip_file_path = "fb/skip_tests_h100_fbcode.yaml"
SKIP_FILE = importlib.resources.files(__package__).joinpath(fbcode_skip_file_path)
else:
SKIP_FILE_NAME = "skip_tests_h100_pytorch.yaml"
try:
# test if it is Triton main branch
import triton.tools.experimental_descriptor # @manual # noqa: F401

SKIP_FILE_NAME = "skip_tests_h100_triton_main.yaml"
except ModuleNotFoundError:
pass
import os

SKIP_FILE = os.path.abspath(
os.path.join(os.path.dirname(__file__), "skip_tests_h100_pytorch.yaml")
)
SKIP_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), SKIP_FILE_NAME))

with open(SKIP_FILE, "r") as f:
skip_tests = yaml.safe_load(f)
Expand Down
1 change: 1 addition & 0 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ fp8_gemm:
# triton_*_persistent requires triton-main
- triton_persistent_fp8_gemm
- triton_tma_persistent_fp8_gemm
# fbgemm fp8 gemm requires triton-main (desc_helper.fill_2d_tma_descriptor)
fp8_gemm_rowwise:
gemm:
# triton_*_persistent_* requires triton-main
Expand Down
40 changes: 40 additions & 0 deletions test/test_gpu/skip_tests_h100_triton_main.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Tests we skip in triton-pytorch + OSS CI
# triton-pytorch is the triton version bundled with pytorch nightly
# We need to skip kernels that only work on triton-main
# Usage:
# op-name: to skip an entire operator
# op-name:\n\t- impl-name to skip an impl
flash_attention:
# thunderkittens cannot handle the default input shapes
- tk
# _ws kernels require Triton with warp specialization
- triton_tutorial_flash_v2_ws
- triton_tutorial_flash_v2_tma_ws
fp8_attention:
# fb-only kernel
- colfax_fmha
# fb-only kernels
fp8_fused_quant_gemm_rowwise:
fp8_gemm:
# FIXME: out of shared memory
- triton_persistent_fp8_gemm
# FIXME: out of shared memory
- triton_tma_persistent_fp8_gemm
gemm:
# FIXME: out of shared memory
- triton_tma_persistent_matmul
# FIXME: out of shared memory
- triton_tma_persistent_cached_matmul
# internal only kernels
- hstu_triton_matmul
- colfax_cutlass_matmul
# jagged tests are slow, so disable them in OSS
jagged_layer_norm:
jagged_mean:
jagged_softmax:
jagged_sum:
# FIXME: ragged attention will Abort (Core Dump) on Triton Main
ragged_attention:
test_op:
fwd_only_ops:
- flash_attention
2 changes: 1 addition & 1 deletion tools/cutlass_kernels/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_colfax_cutlass(colfax_cutlass_lib: str):

def install_colfax_cutlass():
# compile colfax_cutlass kernels
output_dir = COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath(".data")
output_dir = REPO_PATH.joinpath(".data", "cutlass_kernels")
output_dir.mkdir(parents=True, exist_ok=True)
cmd = ["nvcc"]
cmd.extend(COMPILER_FLAGS)
Expand Down
2 changes: 1 addition & 1 deletion tools/tk/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_tk_attn_h100_fwd(tk_lib):

def install_tk():
# compile thunderkitten kernels
output_dir = TRITONBENCH_TK_PATH.joinpath(".data")
output_dir = REPO_PATH.joinpath(".data", "tk")
output_dir.mkdir(parents=True, exist_ok=True)
cmd = ["nvcc"]
cmd.extend(COMPILER_FLAGS)
Expand Down
2 changes: 1 addition & 1 deletion tritonbench/operators/int4_gemm/int4_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from typing import Any, List, Optional

import torch
import triton
import triton.language as tl
import triton.ops

from tritonbench.utils.triton_op import (
BenchmarkOperator,
Expand Down
6 changes: 3 additions & 3 deletions tritonbench/utils/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
def set_env():
# set cutlass dir
# by default we use the cutlass version built with pytorch
import torch
import torch._inductor.config as inductor_config

current_cutlass_dir = torch._inductor.config.cuda.cutlass_dir
current_cutlass_dir = inductor_config.cuda.cutlass_dir
if not os.path.exists(current_cutlass_dir):
tb_cutlass_dir = REPO_PATH.joinpath("submodules", "cutlass")
if tb_cutlass_dir.is_dir():
torch._inductor.config.cuda.cutlass_dir = str(tb_cutlass_dir)
inductor_config.cuda.cutlass_dir = str(tb_cutlass_dir)


def set_random_seed():
Expand Down
2 changes: 1 addition & 1 deletion tritonbench/utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ def load_library(library_path: str):
import torch

prefix, _delimiter, so_file = library_path.partition("/")
so_full_path = REPO_PATH.joinpath("tools", prefix, ".data", so_file).resolve()
so_full_path = REPO_PATH.joinpath(".data", prefix, so_file).resolve()
torch.ops.load_library(str(so_full_path))

0 comments on commit e8f5ba4

Please sign in to comment.