Skip to content

Commit

Permalink
Enhance Benchmarking for repeat_interleave Operation (FlagOpen#274)
Browse files Browse the repository at this point in the history
* Relocate select and slice benchmarks to test_select_and_slice_perf.py

* sort keys for summary result

* clean cuda cache after benchmark

* fix repeat_interleave

* modify format for summary info
  • Loading branch information
kiddyjinjin authored Nov 6, 2024
1 parent 3e47645 commit 2bd92c1
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 208 deletions.
6 changes: 6 additions & 0 deletions benchmark/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ def setup_once(request):
print(note_info)


@pytest.fixture(scope="module", autouse=True)
def clear_cuda_cache():
    """Empty the CUDA caching allocator after each benchmark module.

    The fixture is module-scoped and ``autouse``, so it wraps every test
    module without being requested explicitly. It does nothing before the
    tests run; the work happens in the teardown half after ``yield``.
    """
    yield
    # Teardown: executed once the module's tests have finished.
    torch.cuda.empty_cache()


@pytest.fixture()
def extract_and_log_op_attributes(request):
print("")
Expand Down
10 changes: 9 additions & 1 deletion benchmark/core_shapes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ TensorRepeatBenchmark:
- [512, 512]
- [512, 1024]
- [512, 2048]
shape_desc: "((B), M, N) * 3"
shape_desc: "(B), M, N"

GenericBenchmarkExcluse1D:
shapes:
Expand All @@ -108,6 +108,14 @@ GenericBenchmark2DOnly:
- [4096, 4096]
- [1024, 65536]

UnaryReductionBenchmark:
shapes:
- [1048576,] # 1024 * 1024
- [64, 64]
- [4096, 4096]
- [64, 512, 512]
- [1024, 1024, 1024]

UnaryPointwiseBenchmark:
shapes:
- [1073741824,] # 1024 * 1024 * 1024
Expand Down
45 changes: 24 additions & 21 deletions benchmark/summary_for_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@


@dataclass
class SummaryResult:
class SummaryResultOverDtype:
op_name: str = ""
float16_speedup: float = 0.0
float32_speedup: float = 0.0
Expand All @@ -51,13 +51,13 @@ class SummaryResult:
def __str__(self) -> str:
return (
f"{self.op_name:<30} "
f"{self.float16_speedup:<15.6f} "
f"{self.float32_speedup:<15.6f} "
f"{self.bfloat16_speedup:<15.6f} "
f"{self.int16_speedup:<15.6f} "
f"{self.int32_speedup:<15.6f} "
f"{self.bool_speedup:<15.6f} "
f"{self.cfloat_speedup:<15.6f}"
f"{self.float16_speedup:<20.6f} "
f"{self.float32_speedup:<20.6f} "
f"{self.bfloat16_speedup:<20.6f} "
f"{self.int16_speedup:<20.6f} "
f"{self.int32_speedup:<20.6f} "
f"{self.bool_speedup:<20.6f} "
f"{self.cfloat_speedup:<20.6f}"
)


Expand Down Expand Up @@ -99,13 +99,13 @@ def parse_log(log_file_path: str) -> List[BenchmarkResult]:
return benchmark_results


def calculate_avg_speedup(metrics):
def calculate_avg_speedup_over_dtype(metrics):
    """Return the mean of the non-``None`` ``speedup`` values in *metrics*.

    Entries whose ``speedup`` is ``None`` are ignored. When nothing is left
    to average (empty input or all ``None``), ``0.0`` is returned.
    """
    valid = []
    for entry in metrics:
        value = entry.speedup
        if value is not None:
            valid.append(value)

    if not valid:
        return 0.0
    return sum(valid) / len(valid)


def summary_for_plot(benchmark_results):
summary = defaultdict(SummaryResult)
summary = defaultdict(SummaryResultOverDtype)

dtype_mapping = {
"torch.float16": "float16_speedup",
Expand All @@ -114,7 +114,7 @@ def summary_for_plot(benchmark_results):
"torch.int16": "int16_speedup",
"torch.int32": "int32_speedup",
"torch.bool": "bool_speedup",
"torch.cfloat": "cfloat_speedup",
"torch.complex64": "cfloat_speedup",
}

for item in benchmark_results:
Expand All @@ -124,14 +124,14 @@ def summary_for_plot(benchmark_results):
else:
dtype_suffix = (
"_complex"
if "cfloat" in item.dtype
if "complex64" in item.dtype
else "_int"
if "int" in item.dtype
else "_bool"
)

op_name = item.op_name + dtype_suffix
avg_speedup = calculate_avg_speedup(item.result)
avg_speedup = calculate_avg_speedup_over_dtype(item.result)
cur_op_summary = summary[op_name]
cur_op_summary.op_name = op_name
setattr(
Expand All @@ -140,19 +140,22 @@ def summary_for_plot(benchmark_results):
avg_speedup,
)

# sort the keys based on `op_name`
sorted_summary = sorted(summary.values(), key=lambda x: x.op_name)

header = (
f"{'op_name':<30} "
f"{'float16_speedup':<16} "
f"{'float32_speedup':<16} "
f"{'bfloat16_speedup':<16} "
f"{'int16_speedup':<16} "
f"{'int32_speedup':<16} "
f"{'bool_speedup':<16} "
f"{'cfloat_speedup':<16}"
f"{'float16_speedup':<20} "
f"{'float32_speedup':<20} "
f"{'bfloat16_speedup':<20} "
f"{'int16_speedup':<20} "
f"{'int32_speedup':<20} "
f"{'bool_speedup':<20} "
f"{'cfloat_speedup':<20}"
)

print(header)
for result in summary.values():
for result in sorted_summary:
print(result)

return summary
Expand Down
25 changes: 21 additions & 4 deletions benchmark/test_norm_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

from .attri_util import FLOAT_DTYPES, BenchLevel
from .conftest import Config
from .performance_utils import GenericBenchmark, unary_input_fn
from .performance_utils import (
GenericBenchmark,
GenericBenchmarkExcluse1D,
unary_input_fn,
)


class NormBenchmark(GenericBenchmark):
Expand Down Expand Up @@ -75,15 +79,26 @@ def test_group_and_layer_norm_benchmark(op_name, torch_op, input_fn):
bench.run()


def weight_norm_input_fn(shape, dtype, device):
def weight_norm_interface_input_fn(shape, dtype, device):
    """Yield ``(v, g, dim)`` arguments for ``torch._weight_norm_interface``.

    ``v`` is a random tensor of the full ``shape``; ``g`` is a random 1-D
    tensor whose length equals ``shape[dim]``; ``dim`` is fixed to 0.
    """
    norm_dim = 0
    tensor_opts = {"dtype": dtype, "device": device}
    # v is drawn before g so the random stream is consumed in a fixed order.
    v = torch.randn(shape, **tensor_opts)
    g = torch.randn(shape[norm_dim], **tensor_opts)
    yield v, g, norm_dim


def weight_norm_input_fn(shape, dtype, device):
    """Yield ``(v, g, 0)`` arguments for ``torch._weight_norm``.

    Both ``v`` and ``g`` are random tensors of the full ``shape``; the
    trailing ``0`` is the dimension argument passed to the operator.
    """
    tensor_opts = {"dtype": dtype, "device": device}
    # v is drawn before g so the random stream is consumed in a fixed order.
    v = torch.randn(shape, **tensor_opts)
    g = torch.randn(shape, **tensor_opts)
    yield v, g, 0


norm_operations = [
("weight_norm_interface", torch._weight_norm_interface, weight_norm_input_fn),
(
"weight_norm_interface",
torch._weight_norm_interface,
weight_norm_interface_input_fn,
),
("weight_norm", torch._weight_norm, weight_norm_input_fn),
("vector_norm", torch.linalg.vector_norm, unary_input_fn),
]

Expand All @@ -96,5 +111,7 @@ def weight_norm_input_fn(shape, dtype, device):
],
)
def test_weight_vector_norm_benchmark(op_name, torch_op, input_fn):
bench = GenericBenchmark(input_fn=input_fn, op_name=op_name, torch_op=torch_op)
bench = GenericBenchmarkExcluse1D(
input_fn=input_fn, op_name=op_name, torch_op=torch_op
)
bench.run()
170 changes: 1 addition & 169 deletions benchmark/test_reduction_perf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import random
from typing import Generator

import pytest
Expand All @@ -22,6 +21,7 @@ def set_more_shapes(self):
more_shapes_1d = [
(4,),
(1024,),
(1024 * 1024 * 1024,),
]
more_shapes_2d = [(1024, 2**i) for i in range(0, 20, 4)]
more_shapes_3d = [(64, 64, 2**i) for i in range(0, 15, 4)]
Expand Down Expand Up @@ -95,23 +95,6 @@ def cumsum_input_fn(shape, cur_dtype, device):
yield inp, 1


def index_select_input_fn(shape, cur_dtype, device):
    """Yield ``(inp, dim, index)`` arguments for ``torch.index_select``.

    The index tensor holds random positions along dimension 0; its length
    is 10% of that dimension's size, rounded down.
    """
    from math import floor

    inp = generate_tensor_input(shape, cur_dtype, device)
    dim, threshold = 0, 0.1
    dim_extent = inp.size(dim)
    num_selected = floor(dim_extent * threshold)
    index = torch.randint(0, dim_extent, [num_selected], device=device)
    yield inp, dim, index


def masked_select_input_fn(shape, cur_dtype, device):
    """Yield ``(inp, mask)`` arguments for ``torch.masked_select``.

    The mask is built by generating a second tensor of the same shape and
    dtype and marking the positions whose value is below 0.3.
    """
    inp = generate_tensor_input(shape, cur_dtype, device)
    mask_source = generate_tensor_input(shape, cur_dtype, device)
    mask = mask_source < 0.3
    yield inp, mask


@pytest.mark.parametrize(
"op_name, torch_op, input_fn, dtypes",
[
Expand Down Expand Up @@ -143,161 +126,10 @@ def masked_select_input_fn(shape, cur_dtype, device):
FLOAT_DTYPES + INT_DTYPES,
marks=pytest.mark.cumsum,
),
pytest.param(
"index_select",
torch.index_select,
index_select_input_fn,
FLOAT_DTYPES,
marks=pytest.mark.index_select,
),
pytest.param(
"masked_select",
torch.masked_select,
masked_select_input_fn,
FLOAT_DTYPES,
marks=pytest.mark.masked_select,
),
],
)
def test_generic_reduction_benchmark(op_name, torch_op, input_fn, dtypes):
bench = GenericBenchmark2DOnly(
input_fn=input_fn, op_name=op_name, torch_op=torch_op, dtypes=dtypes
)
bench.run()


class TensorSelectBenchmark(GenericBenchmark2DOnly):
    """2-D benchmark that drops shapes too small for the scatter input."""

    def set_more_shapes(self):
        """Return the parent's extra shapes restricted to 2-D shapes with both extents above 16."""

        def usable_for_scatter(shape):
            # this filter is for scatter
            return len(shape) == 2 and shape[0] > 16 and shape[1] > 16

        candidates = super().set_more_shapes()
        return list(filter(usable_for_scatter, candidates))


@pytest.mark.scatter
def test_perf_scatter():
    """Benchmark ``torch.scatter`` on float dtypes using ``TensorSelectBenchmark``."""

    def scatter_input_fn(shape, dtype, device):
        """Yield ``(inp, dim, index, src)`` for one 2-D ``shape``.

        ``src`` is 1/16 of ``inp`` along each dimension, the scatter
        dimension is picked at random, and ``index`` has a random shape that
        fits inside both ``src`` and ``inp``.
        """
        batch, size = shape
        src_shape = [batch // 16, size // 16]
        inp = torch.randn(shape, dtype=dtype, device=device)
        src = torch.randn(src_shape, dtype=dtype, device=device)

        dim = random.choice([0, 1])
        # Upper bound (exclusive) for index values along the scatter dimension.
        size_dim = min(src_shape[dim], shape[dim])

        # Each index extent is at least 1 and no larger than src/inp allow.
        index_shape = [
            random.randint(1, min(src_shape[0], shape[0])),
            random.randint(1, min(src_shape[1], shape[1])),
        ]
        index = torch.empty(tuple(index_shape), dtype=torch.long, device=device)

        m, n = index_shape

        index_size_dim = index_shape[dim]
        # Make the indices unique along `dim`: every 1-D line of `index` that
        # runs along `dim` is filled with the prefix of a fresh random
        # permutation, so values within a line never repeat. The loop over
        # the scatter dimension itself collapses to a single iteration
        # because the slice below already covers that whole dimension.
        for i in range(1 if dim == 0 else m):
            for j in range(1 if dim == 1 else n):
                ii = [i, j]
                ii[dim] = slice(0, index.size(dim) + 1)
                index[tuple(ii)] = torch.randperm(size_dim)[0:index_size_dim]

        yield inp, dim, index, src

    bench = TensorSelectBenchmark(
        op_name="scatter",
        torch_op=torch.scatter,
        input_fn=scatter_input_fn,
        dtypes=FLOAT_DTYPES,
    )
    bench.run()


@pytest.mark.gather
def test_perf_gather():
    """Benchmark ``torch.gather`` on float dtypes using ``GenericBenchmark2DOnly``."""

    def gather_input_fn(shape, dtype, device):
        """Yield ``(inp, dim, index)`` for one 2-D ``shape``.

        The gather dimension is picked at random and ``index`` has a random
        shape no larger than ``inp`` in either dimension.
        """
        inp = torch.randn(shape, dtype=dtype, device=device)

        dim = random.choice([0, 1])
        # Upper bound (exclusive) for index values along the gather dimension.
        size_dim = shape[dim]
        index_shape = [
            random.randint(1, shape[0]),
            random.randint(1, shape[1]),
        ]
        index = torch.empty(tuple(index_shape), dtype=torch.long, device=device)

        m, n = index_shape

        index_size_dim = index_shape[dim]
        # Make the indices unique along `dim`: every 1-D line of `index` that
        # runs along `dim` is filled with the prefix of a fresh random
        # permutation, so values within a line never repeat. The loop over
        # the gather dimension itself collapses to a single iteration because
        # the slice below already covers that whole dimension.
        for i in range(1 if dim == 0 else m):
            for j in range(1 if dim == 1 else n):
                ii = [i, j]
                ii[dim] = slice(0, index.size(dim) + 1)
                index[tuple(ii)] = torch.randperm(size_dim)[0:index_size_dim]

        yield inp, dim, index

    bench = GenericBenchmark2DOnly(
        op_name="gather",
        torch_op=torch.gather,
        input_fn=gather_input_fn,
        dtypes=FLOAT_DTYPES,
    )
    bench.run()


@pytest.mark.slice_scatter
def test_slice_scatter_perf():
    """Benchmark ``torch.slice_scatter`` on float dtypes using ``GenericBenchmark2DOnly``."""

    def slice_scatter_input_fn(shape, dtype, device):
        """Yield ``(inp, src, dim, start, end, step)`` for one ``shape``.

        The nominal window is ``[16, 1024)`` with step 2 along a randomly
        chosen dimension. ``src`` is sized to match exactly the slice of
        ``inp`` that the yielded ``start``/``end``/``step`` describe.
        """
        dim = random.choice([0, 1])
        start = 16
        end = 1024
        step = 2

        inp = torch.randn(shape, dtype=dtype, device=device)

        valid_shape = list(inp.shape)
        dim_size = valid_shape[dim]
        if (end - start) > dim_size:
            # The window is wider than the dimension: cover all of it.
            start = 0
            end = dim_size
        else:
            # The window fits by width but may still run past the end of the
            # dimension (dim_size between 1008 and 1023). Clamp `end` so that
            # `src` is not sized for elements the slice does not contain.
            end = min(end, dim_size)
        # Named `span` rather than `range` to avoid shadowing the builtin.
        span = max(end - start, 0)

        # Ceil-divide: number of elements visited by start:end:step.
        valid_shape[dim] = (span + (step - 1)) // step
        src = torch.randn(valid_shape, dtype=dtype, device=device)
        yield inp, src, dim, start, end, step

    bench = GenericBenchmark2DOnly(
        op_name="slice_scatter",
        torch_op=torch.slice_scatter,
        input_fn=slice_scatter_input_fn,
        dtypes=FLOAT_DTYPES,
    )
    bench.run()


@pytest.mark.select_scatter
def test_select_scatter_perf():
    """Benchmark ``torch.select_scatter`` on float dtypes using ``GenericBenchmark2DOnly``."""

    def select_scatter_input_fn(shape, dtype, device):
        """Yield ``(inp, src, dim, index)`` for one ``shape``.

        ``src`` has the shape of ``inp`` with the randomly chosen ``dim``
        removed; ``index`` is a random valid position along ``dim``.
        """
        # Random draws happen in the same order as tensors are created:
        # dim, index, inp, then src.
        dim = random.choice([0, 1])
        index = random.randint(0, shape[dim] - 1)
        tensor_opts = {"dtype": dtype, "device": device}
        inp = torch.randn(shape, **tensor_opts)

        # Shape of inp without the selected dimension.
        src_shape = [
            extent for axis, extent in enumerate(inp.shape) if axis != dim
        ]
        src = torch.randn(src_shape, **tensor_opts)

        yield inp, src, dim, index

    benchmark_kwargs = {
        "op_name": "select_scatter",
        "torch_op": torch.select_scatter,
        "input_fn": select_scatter_input_fn,
        "dtypes": FLOAT_DTYPES,
    }
    bench = GenericBenchmark2DOnly(**benchmark_kwargs)
    bench.run()
Loading

0 comments on commit 2bd92c1

Please sign in to comment.