Skip to content

Commit

Permalink
Implement index_select op.
Browse files Browse the repository at this point in the history
Fixes #7454.
  • Loading branch information
cornmander committed Aug 23, 2024
1 parent e4de7e5 commit a3bb2e5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 11 deletions.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@
"igamma",
"igammac",
"index_reduce",
"index_select",
"isclose",
"kthvalue",
"lgamma",
Expand Down
13 changes: 3 additions & 10 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,9 @@ def _aten_select(x, dim, indexes):
@op(torch.ops.aten.index_select)
@op(torch.ops.aten.select_copy)
def _aten_index_select(x, dim, index):
if isinstance(index, jax.Array):
index = index.astype(jnp.int64)

dims = []
for i in range(len(x.shape)):
if i == dim:
dims.append(index)
else:
dims.append(slice(None, None, None))
return x[tuple(dims)]
if x.shape == ():
return x
return jnp.take(x, index, dim)


@op(torch.ops.aten.mean)
Expand Down

0 comments on commit a3bb2e5

Please sign in to comment.