Skip to content

Commit

Permalink
[torch_xla2] Fix scatter op
Browse files Browse the repository at this point in the history
  • Loading branch information
dvhg committed Sep 20, 2024
1 parent 65c3333 commit e4865e3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
10 changes: 10 additions & 0 deletions experimental/torch_xla2/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3651,6 +3651,16 @@ def test_aten_select_scatter_2(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs)

def test_aten_select_scatter_3(self):
args = (
torch.randn((10, 10)).to(torch.float32),
torch.randint(0, 10, (10,)).to(torch.int64),
-1,
0,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs)

def test_aten_sigmoid_0(self):
args = (torch.randn((10, 10)).to(torch.float32),)
kwargs = dict()
Expand Down
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@
"resize_as_",
"rot90",
"rsub",
"scatter",
"scatter_reduce",
"searchsorted",
"special.airy_ai",
Expand Down
14 changes: 10 additions & 4 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,12 +1441,16 @@ def _aten_atan(self):


# aten.scatter_reduce
@op(torch.ops.aten.scatter)
@op(torch.ops.aten.scatter_reduce)
def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
if isinstance(src, float):
dtype = _torch_binary_scalar_type(src, input)
src = jnp.array(src, dtype=dtype)
input_indexes, source_indexes = _scatter_index(dim, index)
if reduce == "sum":
if reduce == "sum" or reduce == "add":
return input.at[input_indexes].add(src[source_indexes])
elif reduce == "prod":
elif reduce == "prod" or reduce == "multiply":
return input.at[input_indexes].multiply(src[source_indexes])
elif reduce == "mean":
return input.at[input_indexes].add(src[source_indexes])
Expand All @@ -1455,7 +1459,7 @@ def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
elif reduce == "amin":
return input.at[input_indexes].min(src[source_indexes])
else:
raise RuntimeError("Unknow reduction type: ", reduce)
raise RuntimeError("Unknown reduction type: ", reduce)


# aten.acos
Expand Down Expand Up @@ -1663,10 +1667,12 @@ def _aten_reciprocal(a):
return 1 / a


# aten.scatter
# aten.select_scatter
@op(torch.ops.aten.select_scatter)
def _aten_select_scatter(input, src, dim, index):
input_indexes = []
if dim < 0:
dim += len(input.shape)
for x in range(len(input.shape)):
if x == dim:
input_indexes.append(index)
Expand Down

0 comments on commit e4865e3

Please sign in to comment.