From 4f7e7924cd47f8132ce98a67fe7601b09b7371e0 Mon Sep 17 00:00:00 2001
From: Hiujin Gwok <70586936+GwokHiujin@users.noreply.github.com>
Date: Mon, 28 Oct 2024 12:27:02 +0800
Subject: [PATCH] [Operator] slice&select scatter (#143)

* add Ops & UT & Bench

* add full zero ones Ops & UT & Bench

* split normal op

* [Operator] init slice&select scatter

* code format

* PR comment

* split test_special_ops

* add K-S test

* split special perf

* Exponential added. (#138)

  * exponential added.
  * Added K-S tests to exponential_, fp64 corrected.
  * aligned with aten prototype
  * Exponential_ uses uint64 offsets in Triton kernel.
  * Update pyproject config for new test dependencies.

* resolve conflict

* Use int64 indexing when needed & fix argmax (#146)

  1. fix amax, argmax and triu: use int64 indexing when the largest
     tensor's size_in_bytes exceeds int32's max;
  2. change the tiling scheme for argmax to loop over the reduction
     dimension, instead of using a data-size-dependent tile size

* test for op

* Making libentry thread safe (#136)

  * libentry is now lock protected.
  * Add multithreading tests for libentry.
  * polish code.

* add argparse

* fix desc

* fix num

* Update test_specific_ops.py

* split UT files

* fix

* [Operator] Optimize CrossEntropyLoss (#131)

  Reimplement cross_entropy_loss forward and backward; support
  indices/probabilities/weight/reduction/ignore_index/label_smoothing;
  performs better than torch eager on large-scale tensors.

* [Test] Test for op (#151)

* [chore] solve slice&select scatter's test cases

* [fix] fix slice&select scatter's test cases

* [chore] remove out-of-range indices in select_scatter's test cases

* [chore] simplify slice_scatter's test cases

* [fix] Add back the range that was deleted by mistake

* Merge branch 'master' into slice&select_scatter

* [chore] reformat

* [fix] typo

* [chore] Considering perf, pause the replacement of some ATen operators

  * slice_scatter
  * select_scatter
  * index_select

* [fix] Add libentry in op.cumsum

* [fix] Del slice&select scatter's perf tests

* [Chore] Add pytest mark for slice&select scatter's test

* [Fix] Correct slice_scatter test

* [Fix] Replace CPU Tensor

---------

Co-authored-by: Bowen12992
Co-authored-by: Tongxin Bai
Co-authored-by: Clement Chan
Co-authored-by: Bowen <81504862+Bowen12992@users.noreply.github.com>
Co-authored-by: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com>
---
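Note for reviewers: both operators are out-of-place aten ops. A minimal
sketch of the semantics the two kernels implement, using only the public
torch API (shapes here are illustrative, not taken from this patch):

    import torch
    import flag_gems

    x = torch.randn(4, 8, device="cuda")

    # select_scatter embeds a 1-D src as row `index` along `dim`
    src_row = torch.randn(8, device="cuda")
    y = torch.select_scatter(x, src_row, dim=0, index=2)

    # slice_scatter embeds src into the strided slice x[:, 1:7:2]
    src_cols = torch.randn(4, 3, device="cuda")
    z = torch.slice_scatter(x, src_cols, dim=1, start=1, end=7, step=2)

    # with this patch applied, the same calls dispatch to the Triton kernels
    with flag_gems.use_gems():
        y = torch.select_scatter(x, src_row, dim=0, index=2)
        z = torch.slice_scatter(x, src_cols, dim=1, start=1, end=7, step=2)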
 src/flag_gems/__init__.py           |  4 +-
 src/flag_gems/ops/__init__.py       |  4 ++
 src/flag_gems/ops/select_scatter.py | 86 ++++++++++++++++++++++++++
 src/flag_gems/ops/slice_scatter.py  | 96 +++++++++++++++++++++++++++++
 src/flag_gems/utils/__init__.py     |  4 +-
 src/flag_gems/utils/shape_utils.py  | 84 ++++++++++++++++++-------
 tests/test_reduction_ops.py         | 61 +++++++++++++++++-
 7 files changed, 314 insertions(+), 25 deletions(-)
 create mode 100644 src/flag_gems/ops/select_scatter.py
 create mode 100644 src/flag_gems/ops/slice_scatter.py

diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 30c1f9379..8ece28a89 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -142,8 +142,10 @@ def enable(lib=aten_lib):
     lib.impl("fill.Scalar", fill_scalar, "CUDA")
     lib.impl("fill.Tensor", fill_tensor, "CUDA")
     lib.impl("flip", flip, "CUDA")
-    lib.impl("tile", tile, "CUDA")
+    lib.impl("slice_scatter", slice_scatter, "CUDA")
+    lib.impl("select_scatter", select_scatter, "CUDA")
     lib.impl("index_select", index_select, "CUDA")
+    lib.impl("tile", tile, "CUDA")
     lib.impl("masked_fill", masked_fill, "CUDA")
     lib.impl("_unique2", _unique2, "CUDA")
     lib.impl("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa, "CUDA")
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 364802de0..1c5292394 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -80,9 +80,11 @@
 from .rms_norm import rms_norm
 from .rsqrt import rsqrt
 from .scatter import scatter
+from .select_scatter import select_scatter
 from .sigmoid import sigmoid
 from .silu import silu
 from .sin import sin
+from .slice_scatter import slice_scatter
 from .softmax import softmax
 from .stack import stack
 from .sub import sub
@@ -226,6 +228,8 @@
     "where_self",
     "where_scalar_self",
     "where_scalar_other",
+    "select_scatter",
+    "slice_scatter",
     "masked_fill",
     "_unique2",
     "_upsample_bicubic2d_aa",
diff --git a/src/flag_gems/ops/select_scatter.py b/src/flag_gems/ops/select_scatter.py
new file mode 100644
index 000000000..40c15cb2a
--- /dev/null
+++ b/src/flag_gems/ops/select_scatter.py
@@ -0,0 +1,86 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from ..utils import libentry, offsetCalculator, restride_dim
+
+
+def cfggen():
+    block_m = [1, 2, 4, 8]
+    configs = [
+        triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
+    ]
+    return configs
+
+
+@libentry()
+@triton.autotune(configs=cfggen(), key=["M", "N"])
+@triton.jit
+def select_scatter_kernel(
+    inp,
+    inp_indices,
+    src,
+    src_offsets,
+    M,
+    N,
+    index,
+    stride_dim,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
+    rows_mask = rows_offsets < M
+
+    for off in range(0, N, BLOCK_N):
+        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
+        cols_mask = cols_offsets < N
+
+        offsets = rows_offsets * N + cols_offsets
+        mask = rows_mask and cols_mask
+
+        indices = tl.load(inp_indices + offsets, mask=mask, other=0)
+        src_indices = tl.load(src_offsets + offsets, mask=mask, other=0)
+        cur_src = tl.load(src + src_indices, mask=mask, other=0)
+
+        indices += index * stride_dim
+        tl.store(inp + indices, cur_src, mask=mask)
+
+
+def select_scatter(inp, src, dim, index):
+    logging.debug("GEMS SELECT_SCATTER")
+    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
+    assert index >= -inp.size(dim) and index < inp.size(dim), "Invalid index"
+    dim = dim % inp.ndim
+    index = index % inp.size(dim)
+    out = inp.clone().contiguous()
+    src = src.contiguous()
+
+    valid_shape = list(inp.shape)
+    del valid_shape[dim]
+    assert (
+        list(src.shape) == valid_shape
+    ), "Expected src to have a size equal to the slice of self"
+
+    src_expanded_shape = list(inp.shape)
+    src_expanded_shape[dim] = 1
+    out_strided = restride_dim(out, dim, src_expanded_shape)
+    idx = torch.arange(0, src.numel(), device=inp.device).reshape(src_expanded_shape)
+    indices = offsetCalculator(
+        out_strided, idx, out.stride(), dim, isInp=False
+    ).squeeze(dim=dim)
+    src_offsets = offsetCalculator(src, idx, src.stride(), dim, isInp=False).squeeze(
+        dim=dim
+    )
+
+    N = valid_shape[src.ndim - 1]
+    M = src.numel() // N
+
+    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
+    select_scatter_kernel[grid](
+        out, indices, src, src_offsets, M, N, index, out.stride(dim)
+    )
+
+    return out
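Note: the kernel above gathers precomputed input and src offsets and copies
src into the selected slice. For reference, the equivalent eager-PyTorch
semantics can be written in a few lines (a sketch, not part of the patch;
`select_scatter_ref` is a hypothetical helper name):

    import torch

    def select_scatter_ref(inp, src, dim, index):
        # out-of-place: clone inp, then overwrite the selected slice with src
        out = inp.clone()
        out.select(dim, index).copy_(src)
        return out

    inp, src = torch.randn(3, 4), torch.randn(4)
    assert torch.equal(select_scatter_ref(inp, src, 0, 1),
                       torch.select_scatter(inp, src, 0, 1))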
diff --git a/src/flag_gems/ops/slice_scatter.py b/src/flag_gems/ops/slice_scatter.py
new file mode 100644
index 000000000..71881a5c4
--- /dev/null
+++ b/src/flag_gems/ops/slice_scatter.py
@@ -0,0 +1,96 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from ..utils import libentry, offsetCalculator, restride_dim
+
+
+def cfggen():
+    block_m = [1, 2, 4, 8]
+    configs = [
+        triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
+    ]
+    return configs
+
+
+@libentry()
+@triton.autotune(configs=cfggen(), key=["M", "N"])
+@triton.jit
+def slice_scatter_kernel(
+    inp,
+    inp_indices,
+    src,
+    src_offsets,
+    M,
+    N,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
+    rows_mask = rows_offsets < M
+
+    for off in range(0, N, BLOCK_N):
+        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
+        cols_mask = cols_offsets < N
+
+        offsets = rows_offsets * N + cols_offsets
+        mask = rows_mask and cols_mask
+
+        indices = tl.load(inp_indices + offsets, mask=mask, other=0)
+        src_indices = tl.load(src_offsets + offsets, mask=mask, other=0)
+        cur_src = tl.load(src + src_indices, mask=mask, other=0)
+
+        tl.store(inp + indices, cur_src, mask=mask)
+
+
+def slice_scatter(inp, src, dim=0, start=None, end=None, step=1):
+    logging.debug("GEMS SLICE_SCATTER")
+    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
+    assert step > 0, "slice step must be positive"
+    dim = dim % inp.ndim
+    out = inp.clone().contiguous()
+    src = src.contiguous()
+    size_dim = inp.size(dim)
+
+    if start is None:
+        start = 0
+    if end is None:
+        end = size_dim
+
+    range = end - start
+    if end < start:
+        range = 0
+    elif (end - start) > size_dim:
+        range = size_dim
+        start = 0
+        end = size_dim
+
+    if range == 0:
+        return out
+
+    valid_shape = list(inp.shape)
+    valid_shape[dim] = (range + (step - 1)) // step
+    assert (
+        list(src.shape) == valid_shape
+    ), "Expected src to have a size equal to the slice of self"
+
+    storage_offset = out.storage_offset() + start * out.stride(dim)
+    out_strided = restride_dim(out, dim, valid_shape, step, storage_offset)
+    idx = torch.arange(0, src.numel(), device=inp.device).reshape(valid_shape)
+    strides = list(out.stride())
+    strides[dim] *= step
+    indices = (
+        offsetCalculator(out_strided, idx, strides, dim, isInp=False) + storage_offset
+    )
+    src_offsets = offsetCalculator(src, idx, src.stride(), dim, isInp=False)
+
+    N = valid_shape[src.ndim - 1]
+    M = src.numel() // N
+
+    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
+    slice_scatter_kernel[grid](out, indices, src, src_offsets, M, N)
+
+    return out
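Note: the equivalent eager-PyTorch semantics for slice_scatter, for
reference (a sketch, not part of the patch; `slice_scatter_ref` is a
hypothetical helper name):

    import torch

    def slice_scatter_ref(inp, src, dim=0, start=None, end=None, step=1):
        # out-of-place: clone inp, then assign src into the strided slice
        out = inp.clone()
        index = (slice(None),) * (dim % inp.ndim) + (slice(start, end, step),)
        out[index] = src
        return out

    inp, src = torch.randn(4, 8), torch.randn(4, 3)
    assert torch.equal(
        slice_scatter_ref(inp, src, dim=1, start=1, end=7, step=2),
        torch.slice_scatter(inp, src, dim=1, start=1, end=7, step=2),
    )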
diff --git a/src/flag_gems/utils/__init__.py b/src/flag_gems/utils/__init__.py
index 6704a12c1..7340f729b 100644
--- a/src/flag_gems/utils/__init__.py
+++ b/src/flag_gems/utils/__init__.py
@@ -4,7 +4,7 @@
     broadcastable,
     broadcastable_to,
     dim_compress,
-    offset_calculator,
+    offsetCalculator,
     restride_dim,
 )
 
@@ -13,7 +13,7 @@
     "pointwise_dynamic",
     "dim_compress",
     "restride_dim",
-    "offset_calculator",
+    "offsetCalculator",
     "broadcastable_to",
     "broadcastable",
 ]
diff --git a/src/flag_gems/utils/shape_utils.py b/src/flag_gems/utils/shape_utils.py
index c1d88f7bf..e96eb771f 100644
--- a/src/flag_gems/utils/shape_utils.py
+++ b/src/flag_gems/utils/shape_utils.py
@@ -223,27 +223,6 @@ def can_use_int32_index(a):
     return True
 
 
-def offsetCalculator(inp, idx, strides, dim, isInp):
-    ndim = inp.ndim
-    shape = list(inp.shape)
-    offsets = 0
-    idx_dim = 0
-    for d in range(0, ndim):
-        mod = idx % shape[d]
-        add_on = mod * strides[d]
-        offsets += add_on
-        if d == dim:
-            idx_dim = add_on
-        idx = idx // shape[d]
-    # FIXME: Should we write a fast div/mod
-    # to boost the '%' and '//'? (Since they may be run many times)
-    # See also:
-    # - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
-    # - Division by Invariant Integers Using Multiplication,
-    #   Torbjörn Granlund and Peter L. Montgomery, 1994.
-    return (offsets) if not isInp else (offsets - idx_dim)
-
-
 def restride_dim(src, dim, shape, step=0, storage_offset=None):
     strides = list(src.stride())
     strides[dim] *= step
@@ -290,6 +269,48 @@ def add_on_kernel(
 
 
 def offset_calculator(inp, idx, strides, dim, isInp):
+    """
+    Calculate the flat index (a.k.a. offset) for a given linear index into a
+    multi-dimensional array. The formula can be seen in:
+    - https://numpy.org/doc/stable/reference/arrays.ndarray.html#internal-memory-layout-of-an-ndarray
+    - https://numpy.org/devdocs/user/basics.indexing.html#single-element-indexing
+
+
+    Parameters:
+        inp (tensor): The input multi-dimensional array from which the offset is calculated.
+        idx (tensor): The linear index for which the offset is to be calculated.
+        strides (list of int): A list containing the stride lengths for each dimension of the input array.
+        dim (int): The specific dimension for which the index offset needs to be calculated.
+        isInp (bool): A flag indicating whether the tensor 'inp' is the parameter 'self'
+                      in scatter/gather/index_* operators or not.
+
+                      In operators such as scatter/gather and index_*, when the input tensor 'inp'
+                      is the 'self' tensor to be processed, we may need to modify its offsets later.
+                      For instance, in the scatter operator, the offset is calculated using the formula:
+
+                          inp_offset = origin_offset - stride[dim] * n_dim + stride[dim] * index
+
+                      In this case, we return the fixed part of the formula:
+
+                          origin_offset - stride[dim] * n_dim
+
+                      to facilitate subsequent modifications.
+                      For other types of input 'inp', we return the complete calculation result
+                      of origin_offsets directly.
+
+
+    Returns:
+        The calculated offset. If isInp is True, the fixed offset is returned; otherwise, the origin offset is returned.
+
+
+    Note:
+        The division and modulus operations below may be worth replacing with a
+        fast div/mod routine, since they can run many times.
+        See also:
+        - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
+        - Division by Invariant Integers Using Multiplication,
+          Torbjörn Granlund and Peter L. Montgomery, 1994.
+    """
     ndim = inp.ndim
     shape = list(inp.shape)
     offsets = torch.zeros_like(inp, dtype=torch.int32, device=inp.device)
@@ -309,3 +330,24 @@ def offset_calculator(inp, idx, strides, dim, isInp):
             idx_dim = add_on
         idx = idx // shape[d]
     return offsets if not isInp else (offsets - idx_dim)
+
+
+def offsetCalculator(inp, idx, strides, dim, isInp):
+    ndim = inp.ndim
+    shape = list(inp.shape)
+    offsets = 0
+    idx_dim = 0
+    for d in range(0, ndim):
+        mod = idx % shape[d]
+        add_on = mod * strides[d]
+        offsets += add_on
+        if d == dim:
+            idx_dim = add_on
+        idx = idx // shape[d]
+    # FIXME: Should we write a fast div/mod
+    # to boost the '%' and '//'? (Since they may be run many times)
+    # See also:
+    # - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
+    # - Division by Invariant Integers Using Multiplication,
+    #   Torbjörn Granlund and Peter L. Montgomery, 1994.
+    return (offsets) if not isInp else (offsets - idx_dim)
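Note: a small worked example of the offset formula the docstring cites
(offset = sum over d of coord[d] * stride[d]); illustrative only:

    import torch

    # a contiguous (2, 3) tensor has strides (3, 1): element [i, j] lives
    # at flat offset i * 3 + j * 1
    t = torch.arange(6).reshape(2, 3)
    i, j = 1, 2
    offset = i * t.stride(0) + j * t.stride(1)
    assert t.flatten()[offset] == t[i, j]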
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index ec9f3ca4a..542516452 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -443,9 +443,68 @@ def test_accuracy_gather(inp_shape, dim, dtype):
     gems_assert_equal(res_out, ref_out)
 
 
+@pytest.mark.select_scatter
+@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
+@pytest.mark.parametrize("dim", DIM_LIST)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_select_scatter(shape, dim, dtype):
+    import random
+
+    index = random.randint(0, shape[dim] - 1)
+    inp = torch.randn(shape, dtype=dtype, device="cuda")
+
+    src_shape = list(inp.shape)
+    del src_shape[dim]
+    src = torch.randn(src_shape, dtype=dtype, device="cuda")
+
+    ref_inp = to_reference(inp)
+    ref_src = to_reference(src)
+    ref_out = torch.select_scatter(ref_inp, dim=dim, index=index, src=ref_src)
+    with flag_gems.use_gems():
+        res_out = torch.select_scatter(inp, dim=dim, index=index, src=src)
+
+    gems_assert_equal(res_out, ref_out)
+
+
+@pytest.mark.slice_scatter
+@pytest.mark.parametrize(("dim", "shape"), DIM_SHAPE)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.parametrize("start", [16, 64])
+@pytest.mark.parametrize("end", [1024, 256])
+@pytest.mark.parametrize("step", [1, 2])
+def test_accuracy_slice_scatter(shape, dim, dtype, start, end, step):
+    inp = torch.randn(shape, dtype=dtype, device="cuda")
+
+    range = end - start
+    valid_shape = list(inp.shape)
+    if end < start:
+        range = 0
+    elif (end - start) > valid_shape[dim]:
+        range = valid_shape[dim]
+        start = 0
+        end = valid_shape[dim]
+
+    valid_shape[dim] = (range + (step - 1)) // step
+
+    src = torch.randn(valid_shape, dtype=dtype, device="cuda")
+
+    ref_inp = to_reference(inp)
+    ref_src = to_reference(src)
+    ref_out = torch.slice_scatter(
+        ref_inp, dim=dim, src=ref_src, start=start, end=end, step=step
+    )
+    with flag_gems.use_gems():
+        res_out = torch.slice_scatter(
+            inp, dim=dim, src=src, start=start, end=end, step=step
+        )
+
+    gems_assert_equal(res_out, ref_out)
+
+
 # TODO: failed at (200, 40999, 3)
 @pytest.mark.index_select
-@pytest.mark.parametrize("dim, shape", DIM_SHAPE)
+@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
+@pytest.mark.parametrize("dim", DIM_LIST)
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 def test_accuracy_index_select(shape, dim, dtype):
     inp = torch.randn(shape, dtype=dtype, device="cuda")
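Note: assuming the new `select_scatter` and `slice_scatter` pytest marks are
registered (e.g. in pyproject.toml), the added accuracy tests can be run in
isolation:

    import pytest

    # run only the tests added in this patch, selected by their marks
    pytest.main(["-m", "select_scatter or slice_scatter",
                 "tests/test_reduction_ops.py"])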