Skip to content

Commit

Permalink
[Pallas] Support _histogram (#7115)
Browse files Browse the repository at this point in the history
Summary:
This pull request implements a limited version of pytorch.org/docs/stable/generated/torch.histc.html to support _make_group_metadata. In the future, we can consider properly lower the op or just make this python version more completed.

Test Plan:
PJRT_DEVICE=TPU python test/test_gmm.py
  • Loading branch information
alanwaketan authored May 25, 2024
1 parent 3369bf7 commit cb80583
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 1 deletion.
58 changes: 57 additions & 1 deletion test/test_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata
from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram
from torch_xla import runtime as xr
from torch_xla._internal import tpu

Expand Down Expand Up @@ -183,6 +183,62 @@ def test_make_group_metadata(self):
torch.all(torch.from_numpy(np.array(jax_meta[i])) == torch_meta[i]))
self.assertEqual(jax_num_tiles, torch_meta[-1].item())

def test_histogram(self):
test_grids = [
{
'input': [1, 4, 4, 1, 2, 3],
'min': 1,
'max': 4,
},
{
'input': [1, 4, 4, 1, 2, 3],
'min': 2,
'max': 3,
},
{
'input': [1, 4, 4, 1, 2, 3],
'min': 0,
'max': 5,
},
{
'input': [],
'min': 0,
'max': 5,
},
]

for test_grid in test_grids:
torch_chart = torch.histc(
torch.tensor(test_grid['input'], dtype=torch.float),
bins=test_grid['max'] - test_grid['min'] + 1,
min=test_grid['min'],
max=test_grid['max'],
)

chart, _ = _histogram(
torch.tensor(test_grid['input'], dtype=torch.int32).to("xla"),
min=test_grid['min'],
max=test_grid['max'],
)

self.assertTrue(torch.all(torch_chart == chart.cpu()))

def test_histogram_raise(self):
with self.assertRaisesRegex(AssertionError,
"input must be of torch.int32 dtype."):
_histogram(
torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.float),
min=4,
max=5,
)

with self.assertRaisesRegex(AssertionError, "min must be less than max."):
_histogram(
torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.int32),
min=4,
max=3,
)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
16 changes: 16 additions & 0 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,22 @@ def _calculate_num_tiles(x: int, tx: int) -> int:
return tiles


def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
"""
Compute the histogram of a int32 tensor. The bin edges are defined by the min and max values, with step = 1.
"""
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
assert min < max, "min must be less than max."

def searchsorted(sorted_sequence: torch.Tensor,
values_to_search: torch.Tensor) -> torch.Tensor:
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)

bin_edges = torch.linspace(
min, max, max - min + 1, dtype=input.dtype).to(input.device)
return searchsorted(bin_edges, input), bin_edges


# This can only be ran in cpu now as repeat_interleave is not lowered to xla.
def _make_group_metadata(
*,
Expand Down

0 comments on commit cb80583

Please sign in to comment.