Skip to content

Commit

Permalink
[Operator] Add slice&select_scatter's benchmark (FlagOpen#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
GwokHiujin authored Oct 28, 2024
1 parent 4f7e792 commit 4bcb3ea
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions benchmark/test_reduction_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,68 @@ def gather_args(dtype, batch, size):
sizes=SIZES,
)
bench.run()


def test_slice_scatter_perf():
def slice_scatter_args(dtype, batch, size):
shape = [batch, size]
import random

dim = random.choice([0, 1])
start = 16
end = 1024
step = 2

inp = torch.randn(shape, dtype=dtype, device="cuda")

range = end - start
valid_shape = list(inp.shape)
if end < start:
range = 0
elif (end - start) > valid_shape[dim]:
range = valid_shape[dim]
start = 0
end = valid_shape[dim]

valid_shape[dim] = (range + (step - 1)) // step
src = torch.randn(valid_shape, dtype=dtype, device="cuda")
return (inp, src, dim, start, end, step)

bench = Benchmark(
op_name="slice_scatter",
torch_op=torch.slice_scatter,
arg_func=slice_scatter_args,
dtypes=FLOAT_DTYPES,
batch=REDUCTION_BATCH,
sizes=SIZES,
)
bench.run()


def test_select_scatter_perf():
def select_scatter_args(dtype, batch, size):
shape = [batch, size]
import random

dim = random.choice([0, 1])

import random

index = random.randint(0, shape[dim] - 1)
inp = torch.randn(shape, dtype=dtype, device="cuda")

src_shape = list(inp.shape)
del src_shape[dim]
src = torch.randn(src_shape, dtype=dtype, device="cuda")

return (inp, src, dim, index)

bench = Benchmark(
op_name="select_scatter",
torch_op=torch.select_scatter,
arg_func=select_scatter_args,
dtypes=FLOAT_DTYPES,
batch=REDUCTION_BATCH,
sizes=SIZES,
)
bench.run()

0 comments on commit 4bcb3ea

Please sign in to comment.