implement Repeat with fixed output shape (#7114)
JackCaoG authored May 25, 2024
1 parent 22e912e commit 3369bf7
Showing 2 changed files with 81 additions and 0 deletions.
34 changes: 34 additions & 0 deletions test/test_operations.py
@@ -2734,6 +2734,40 @@ def fn(boxes, scores):
    self.runAtenTest((boxes, scores), fn)


class TestHelperFunction(test_utils.XlaTestCase):

  def test_repeat_truncated(self):
    from torch_xla.experimental.custom_kernel import repeat_with_fixed_output_size
    met.clear_all()
    device = torch_xla.device()
    total_repeat_length = 20
    input = torch.randn(10).to(device)
    repeats = torch.tensor([0, 1, 2, 0, 4, 0, 6, 7, 8, 9]).to(device)
    res = repeat_with_fixed_output_size(input, repeats, total_repeat_length)
    # make sure there is no graph break
    assert 'aten::' not in met.short_metrics_report()
    expected = torch.repeat_interleave(input, repeats)[:total_repeat_length]
    self.assertTrue(torch.allclose(res.cpu(), expected.cpu()))

  def test_repeat_extended(self):
    from torch_xla.experimental.custom_kernel import repeat_with_fixed_output_size
    met.clear_all()
    device = torch_xla.device()
    total_repeat_length = 100
    input = torch.randn(10).to(device)
    repeats = torch.tensor([0, 5, 2, 0, 4, 9, 6, 7, 8, 0]).to(device)
    res = repeat_with_fixed_output_size(input, repeats, total_repeat_length)
    # make sure there is no graph break
    assert 'aten::' not in met.short_metrics_report()
    base = torch.repeat_interleave(input, repeats)[:total_repeat_length]
    # the remaining space is filled with the last value in `input`.
    expected = torch.cat(
        (base,
         torch.repeat_interleave(input[-1],
                                 total_repeat_length - base.size()[0])))
    self.assertTrue(torch.allclose(res.cpu(), expected.cpu()))


if __name__ == '__main__':
  torch.set_default_dtype(torch.float32)
  torch.manual_seed(42)
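Both tests build their expected values the same way: repeat eagerly with `torch.repeat_interleave`, then either truncate to `total_repeat_length` or pad the tail with the last value of `input`. A minimal eager-mode sketch of that reference computation (the helper name `repeat_fixed_reference` is illustrative and not part of torch_xla):

import torch


def repeat_fixed_reference(input: torch.Tensor, repeats: torch.Tensor,
                           total_repeat_length: int) -> torch.Tensor:
  # plain repeat_interleave, whose length depends on the runtime sum of
  # repeats, truncated to the fixed length
  out = torch.repeat_interleave(input, repeats)[:total_repeat_length]
  pad = total_repeat_length - out.size(0)
  if pad > 0:
    # any remaining slots are filled with the last value of `input`
    out = torch.cat((out, torch.repeat_interleave(input[-1], pad)))
  return out

With the inputs of `test_repeat_truncated` the `pad > 0` branch is never taken; with the inputs of `test_repeat_extended` the trailing `total_repeat_length - repeats.sum()` entries are copies of `input[-1]`.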
47 changes: 47 additions & 0 deletions torch_xla/experimental/custom_kernel.py
@@ -655,6 +655,53 @@ def _make_group_metadata(
  return group_offsets, group_ids, m_tile_ids, num_tiles


# Repeat each element of the `input` tensor the number of times given by the
# corresponding entry in `repeats`. Both `input` and `repeats` are expected to
# be 1d tensors of the same shape. The output shape is [total_repeat_length].
# If `total_repeat_length` is larger than the length of the repeated tensor, the
# last value of `input` is used to fill up the rest; if it is smaller, the
# repeated tensor is truncated.
def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor,
                                  total_repeat_length: int):
  # currently only 1d input and 1d repeats are supported
  assert len(input.size()) == 1
  assert len(repeats.size()) == 1
  device = input.device

  # to better understand this code, let's assume
  # input.size() = [10]
  # repeats = [0, 1, 2, 0, 4, 0, 6, 7, 8, 9]
  # total_repeat_length = 20

  # shift the repeats by one
  # tensor([0, 0, 1, 2, 0, 4, 0, 6, 7, 8])
  exclusive_repeats = torch.roll(repeats, shifts=1)
  exclusive_repeats[0] = 0

  # tensor([ 0, 0, 1, 3, 3, 7, 7, 13, 20, 28])
  scatter_indices = torch.cumsum(exclusive_repeats, dim=0)
  # set the out-of-bound indices to 0 and count how many of them there are.
  # tensor([ 0, 0, 1, 3, 3, 7, 7, 13, 0, 0])
  valid_indices = torch.where(scatter_indices >= total_repeat_length,
                              torch.zeros_like(scatter_indices),
                              scatter_indices)
  out_of_bound_count = torch.where(scatter_indices >= total_repeat_length, 1,
                                   0).sum()

  # tensor([2, 1, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
  block_split_indicators = torch.zeros(
      total_repeat_length, dtype=torch.int64, device=device)
  block_split_indicators.scatter_add_(0, valid_indices,
                                      torch.ones_like(block_split_indicators))
  # out-of-bound indices were also scattered to index 0, so offset them out
  block_split_indicators[0] -= out_of_bound_count

  # each value in gather_indices represents an index into `input`.
  # tensor([1, 2, 2, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7])
  gather_indices = torch.cumsum(block_split_indicators, dim=0) - 1
  res = torch.gather(input, 0, gather_indices)
  return res


def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
        group_sizes: torch.Tensor) -> torch.Tensor:
  """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
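Putting it together, the new helper can be called the same way the tests above do; a minimal usage sketch, assuming an XLA device is available (the values reuse the worked example from the comments in `repeat_with_fixed_output_size`):

import torch
import torch_xla
from torch_xla.experimental.custom_kernel import repeat_with_fixed_output_size

device = torch_xla.device()
x = torch.randn(10).to(device)
repeats = torch.tensor([0, 1, 2, 0, 4, 0, 6, 7, 8, 9]).to(device)

# the output shape is fixed at [20] regardless of the runtime values in
# `repeats`, so the lowered graph keeps a static shape
out = repeat_with_fixed_output_size(x, repeats, total_repeat_length=20)
assert out.shape == torch.Size([20])

Because the output shape does not depend on `repeats.sum()`, the op does not fall back to `aten::` kernels at runtime; this is what the `'aten::' not in met.short_metrics_report()` assertions in the tests verify.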
