From a3bb2e5542536ab6685c347d363d6d084b4583b0 Mon Sep 17 00:00:00 2001 From: Gregory Shikhman Date: Fri, 23 Aug 2024 19:21:22 +0000 Subject: [PATCH] Implement index_select op. Fixes #7454. --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 13 +++---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 2c14784c821..9b0452734f1 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -78,7 +78,6 @@ "igamma", "igammac", "index_reduce", - "index_select", "isclose", "kthvalue", "lgamma", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index bfaf6228678..6703b3e2b93 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -109,16 +109,9 @@ def _aten_select(x, dim, indexes): @op(torch.ops.aten.index_select) @op(torch.ops.aten.select_copy) def _aten_index_select(x, dim, index): - if isinstance(index, jax.Array): - index = index.astype(jnp.int64) - - dims = [] - for i in range(len(x.shape)): - if i == dim: - dims.append(index) - else: - dims.append(slice(None, None, None)) - return x[tuple(dims)] + if x.shape == (): + return x + return jnp.take(x, index, dim) @op(torch.ops.aten.mean)