diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 30c1f9379..8ece28a89 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -142,8 +142,10 @@ def enable(lib=aten_lib): lib.impl("fill.Scalar", fill_scalar, "CUDA") lib.impl("fill.Tensor", fill_tensor, "CUDA") lib.impl("flip", flip, "CUDA") - lib.impl("tile", tile, "CUDA") + lib.impl("slice_scatter", slice_scatter, "CUDA") + lib.impl("select_scatter", select_scatter, "CUDA") lib.impl("index_select", index_select, "CUDA") + lib.impl("tile", tile, "CUDA") lib.impl("masked_fill", masked_fill, "CUDA") lib.impl("_unique2", _unique2, "CUDA") lib.impl("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa, "CUDA") diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 364802de0..1c5292394 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -80,9 +80,11 @@ from .rms_norm import rms_norm from .rsqrt import rsqrt from .scatter import scatter +from .select_scatter import select_scatter from .sigmoid import sigmoid from .silu import silu from .sin import sin +from .slice_scatter import slice_scatter from .softmax import softmax from .stack import stack from .sub import sub @@ -226,6 +228,8 @@ "where_self", "where_scalar_self", "where_scalar_other", + "select_scatter", + "slice_scatter", "masked_fill", "_unique2", "_upsample_bicubic2d_aa", diff --git a/src/flag_gems/ops/select_scatter.py b/src/flag_gems/ops/select_scatter.py new file mode 100644 index 000000000..40c15cb2a --- /dev/null +++ b/src/flag_gems/ops/select_scatter.py @@ -0,0 +1,86 @@ +import logging + +import torch +import triton +import triton.language as tl + +from ..utils import libentry, offsetCalculator, restride_dim + + +def cfggen(): + block_m = [1, 2, 4, 8] + configs = [ + triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m + ] + return configs + + +@libentry() +@triton.autotune(configs=cfggen(), key=["M", "N"]) +@triton.jit +def select_scatter_kernel( + inp, + inp_indices, + src, + src_offsets, + M, + N, + index, + stride_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + rows_mask = rows_offsets < M + + for off in range(0, N, BLOCK_N): + cols_offsets = off + tl.arange(0, BLOCK_N)[None, :] + cols_mask = cols_offsets < N + + offsets = rows_offsets * N + cols_offsets + mask = rows_mask and cols_mask + + indices = tl.load(inp_indices + offsets, mask=mask, other=0) + src_indices = tl.load(src_offsets + offsets, mask=mask, other=0) + cur_src = tl.load(src + src_indices, mask=mask, other=0) + + indices += index * stride_dim + tl.store(inp + indices, cur_src, mask=mask) + + +def select_scatter(inp, src, dim, index): + logging.debug("GEMS SELECT_SCATTER") + assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" + assert index >= -inp.size(dim) and index < inp.size(dim), "Invalid index" + dim = dim % inp.ndim + index = index % inp.size(dim) + out = inp.clone().contiguous() + src = src.contiguous() + + valid_shape = list(inp.shape) + del valid_shape[dim] + assert ( + list(src.shape) == valid_shape + ), "Expected src to have a size equal to the slice of self" + + src_expanded_shape = list(inp.shape) + src_expanded_shape[dim] = 1 + out_strided = restride_dim(out, dim, src_expanded_shape) + idx = torch.arange(0, src.numel(), device=inp.device).reshape(src_expanded_shape) + indices = offsetCalculator( + out_strided, idx, out.stride(), dim, isInp=False + ).squeeze(dim=dim) + src_offsets = offsetCalculator(src, idx, src.stride(), dim, isInp=False).squeeze( + dim=dim + ) + + N = valid_shape[src.ndim - 1] + M = src.numel() // N + + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) + select_scatter_kernel[grid]( + out, indices, src, src_offsets, M, N, index, out.stride(dim) + ) + + return out diff --git a/src/flag_gems/ops/slice_scatter.py b/src/flag_gems/ops/slice_scatter.py new file mode 100644 index 000000000..71881a5c4 --- /dev/null +++ b/src/flag_gems/ops/slice_scatter.py @@ -0,0 +1,96 @@ +import logging + +import torch +import triton +import triton.language as tl + +from ..utils import libentry, offsetCalculator, restride_dim + + +def cfggen(): + block_m = [1, 2, 4, 8] + configs = [ + triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m + ] + return configs + + +@libentry() +@triton.autotune(configs=cfggen(), key=["M", "N"]) +@triton.jit +def slice_scatter_kernel( + inp, + inp_indices, + src, + src_offsets, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + rows_mask = rows_offsets < M + + for off in range(0, N, BLOCK_N): + cols_offsets = off + tl.arange(0, BLOCK_N)[None, :] + cols_mask = cols_offsets < N + + offsets = rows_offsets * N + cols_offsets + mask = rows_mask and cols_mask + + indices = tl.load(inp_indices + offsets, mask=mask, other=0) + src_indices = tl.load(src_offsets + offsets, mask=mask, other=0) + cur_src = tl.load(src + src_indices, mask=mask, other=0) + + tl.store(inp + indices, cur_src, mask=mask) + + +def slice_scatter(inp, src, dim=0, start=None, end=None, step=1): + logging.debug("GEMS SLICE_SCATTER") + assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" + assert step > 0, "slice step must be positive" + dim = dim % inp.ndim + out = inp.clone().contiguous() + src = src.contiguous() + size_dim = inp.size(dim) + + if start is None: + start = 0 + if end is None: + end = size_dim + + range = end - start + if end < start: + range = 0 + elif (end - start) > size_dim: + range = size_dim + start = 0 + end = size_dim + + if range == 0: + return out + + valid_shape = list(inp.shape) + valid_shape[dim] = (range + (step - 1)) // step + assert ( + list(src.shape) == valid_shape + ), "Expected src to have a size equal to the slice of self" + + storage_offset = out.storage_offset() + start * out.stride(dim) + out_strided = restride_dim(out, dim, valid_shape, step, storage_offset) + idx = torch.arange(0, src.numel(), device=inp.device).reshape(valid_shape) + strides = list(out.stride()) + strides[dim] *= step + indices = ( + offsetCalculator(out_strided, idx, strides, dim, isInp=False) + storage_offset + ) + src_offsets = offsetCalculator(src, idx, src.stride(), dim, isInp=False) + + N = valid_shape[src.ndim - 1] + M = src.numel() // N + + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) + slice_scatter_kernel[grid](out, indices, src, src_offsets, M, N) + + return out diff --git a/src/flag_gems/utils/__init__.py b/src/flag_gems/utils/__init__.py index 6704a12c1..7340f729b 100644 --- a/src/flag_gems/utils/__init__.py +++ b/src/flag_gems/utils/__init__.py @@ -4,7 +4,7 @@ broadcastable, broadcastable_to, dim_compress, - offset_calculator, + offsetCalculator, restride_dim, ) @@ -13,7 +13,7 @@ "pointwise_dynamic", "dim_compress", "restride_dim", - "offset_calculator", + "offsetCalculator", "broadcastable_to", "broadcastable", ] diff --git a/src/flag_gems/utils/shape_utils.py b/src/flag_gems/utils/shape_utils.py index c1d88f7bf..e96eb771f 100644 --- a/src/flag_gems/utils/shape_utils.py +++ b/src/flag_gems/utils/shape_utils.py @@ -223,27 +223,6 @@ def can_use_int32_index(a): return True -def offsetCalculator(inp, idx, strides, dim, isInp): - ndim = inp.ndim - shape = list(inp.shape) - offsets = 0 - idx_dim = 0 - for d in range(0, ndim): - mod = idx % shape[d] - add_on = mod * strides[d] - offsets += add_on - if d == dim: - idx_dim = add_on - idx = idx // shape[d] - # FIXME: Should we write a fast div/mod - # to boost the '%' and '//'? (Since they may be run many times) - # See also: - # - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html - # - Division by Invariant Integers Using Multiplication, - # Torbjörn Granlund and Peter L. Montgomery, 1994. - return (offsets) if not isInp else (offsets - idx_dim) - - def restride_dim(src, dim, shape, step=0, storage_offset=None): strides = list(src.stride()) strides[dim] *= step @@ -290,6 +269,48 @@ def add_on_kernel( def offset_calculator(inp, idx, strides, dim, isInp): + """ + Calculate the flat index(a.k.a offset) for a given ravel index in a multi-dimensional array. + The formula can be seen in: + - https://numpy.org/doc/stable/reference/arrays.ndarray.html#internal-memory-layout-of-an-ndarray + - https://numpy.org/devdocs/user/basics.indexing.html#single-element-indexing + + + Parameters: + inp (tensor): The input multi-dimensional array from which the offset is calculated. + idx (tensor): The linear index for which the offset is to be calculated. + strides (list of int): A list containing the stride lengths for each dimension of the input array. + dim (int): The specific dimension for which the index offset needs to be calculated. + isInp (bool): A flag indicating whether the tensor 'inp' is the parameter 'self' + in scatter/gather/index_* operators or not. + + In operators such as scatter/gather and index_*, when the input tensor 'inp' + is the 'self' tensor to be processed, we may need to modify its offsets later. + For instance, in the scatter operator, the offset is calculated using the formula: + + inp_offset = origin_offset - stride[dim] * n_dim + stride[dim] * index. + + In this case, we return the fixed part of the formula: + + origin_offset - stride[dim] * n_dim, + + to facilitate subsequent modifications. + For other types of input 'inp', we return the complete calculation result + of origin_offsets directly. + + + Returns: + The calculated offset. If isInp is True, the fixed offset is returned; otherwise, the origin offset is returned. + + + Note: + The function includes a comment suggesting the potential optimization of division and modulus operations, + which may be beneficial if this function is called frequently. + See also: + - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html + - Division by Invariant Integers Using Multiplication, + Torbjörn Granlund and Peter L. Montgomery, 1994. + """ ndim = inp.ndim shape = list(inp.shape) offsets = torch.zeros_like(inp, dtype=torch.int32, device=inp.device) @@ -309,3 +330,24 @@ def offset_calculator(inp, idx, strides, dim, isInp): idx_dim = add_on idx = idx // shape[d] return offsets if not isInp else (offsets - idx_dim) + + +def offsetCalculator(inp, idx, strides, dim, isInp): + ndim = inp.ndim + shape = list(inp.shape) + offsets = 0 + idx_dim = 0 + for d in range(0, ndim): + mod = idx % shape[d] + add_on = mod * strides[d] + offsets += add_on + if d == dim: + idx_dim = add_on + idx = idx // shape[d] + # FIXME: Should we write a fast div/mod + # to boost the '%' and '//'? (Since they may be run many times) + # See also: + # - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html + # - Division by Invariant Integers Using Multiplication, + # Torbjörn Granlund and Peter L. Montgomery, 1994. + return (offsets) if not isInp else (offsets - idx_dim) diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index ec9f3ca4a..542516452 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -443,9 +443,68 @@ def test_accuracy_gather(inp_shape, dim, dtype): gems_assert_equal(res_out, ref_out) +@pytest.mark.select_scatter +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIM_LIST) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_select_scatter(shape, dim, dtype): + import random + + index = random.randint(0, shape[dim] - 1) + inp = torch.randn(shape, dtype=dtype, device="cuda") + + src_shape = list(inp.shape) + del src_shape[dim] + src = torch.randn(src_shape, dtype=dtype, device="cuda") + + ref_inp = to_reference(inp) + ref_src = to_reference(src) + ref_out = torch.select_scatter(ref_inp, dim=dim, index=index, src=ref_src) + with flag_gems.use_gems(): + res_out = torch.select_scatter(inp, dim=dim, index=index, src=src) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.slice_scatter +@pytest.mark.parametrize(("dim", "shape"), DIM_SHAPE) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("start", [16, 64]) +@pytest.mark.parametrize("end", [1024, 256]) +@pytest.mark.parametrize("step", [1, 2]) +def test_accuracy_slice_scatter(shape, dim, dtype, start, end, step): + inp = torch.randn(shape, dtype=dtype, device="cuda") + + range = end - start + valid_shape = list(inp.shape) + if end < start: + range = 0 + elif (end - start) > valid_shape[dim]: + range = valid_shape[dim] + start = 0 + end = valid_shape[dim] + + valid_shape[dim] = (range + (step - 1)) // step + + src = torch.randn(valid_shape, dtype=dtype, device="cuda") + + ref_inp = to_reference(inp) + ref_src = to_reference(src) + ref_out = torch.slice_scatter( + ref_inp, dim=dim, src=ref_src, start=start, end=end, step=step + ) + with flag_gems.use_gems(): + res_out = torch.slice_scatter( + inp, dim=dim, src=src, start=start, end=end, step=step + ) + + gems_assert_equal(res_out, ref_out) + + # TODO: failed at (200, 40999, 3) @pytest.mark.index_select -@pytest.mark.parametrize("dim, shape", DIM_SHAPE) +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIM_LIST) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_index_select(shape, dim, dtype): inp = torch.randn(shape, dtype=dtype, device="cuda")