From e4865e36372a4acd4cbc1581917d1d930e5f9f7c Mon Sep 17 00:00:00 2001
From: David Huang
Date: Fri, 20 Sep 2024 06:11:43 +0000
Subject: [PATCH] [torch_xla2] Fix scatter op

---
 experimental/torch_xla2/test/test_core_aten_ops.py | 10 ++++++++++
 experimental/torch_xla2/test/test_ops.py           |  1 -
 experimental/torch_xla2/torch_xla2/ops/jaten.py    | 14 ++++++++++----
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py
index 68e37209a1b..d207bc22a82 100644
--- a/experimental/torch_xla2/test/test_core_aten_ops.py
+++ b/experimental/torch_xla2/test/test_core_aten_ops.py
@@ -3651,6 +3651,16 @@ def test_aten_select_scatter_2(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs)
 
+  def test_aten_select_scatter_3(self):
+    args = (
+        torch.randn((10, 10)).to(torch.float32),
+        torch.randint(0, 10, (10,)).to(torch.int64),
+        -1,
+        0,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs)
+
   def test_aten_sigmoid_0(self):
     args = (torch.randn((10, 10)).to(torch.float32),)
     kwargs = dict()
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index aa920af29f8..1c078d1e760 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -148,7 +148,6 @@
     "resize_as_",
     "rot90",
     "rsub",
-    "scatter",
     "scatter_reduce",
     "searchsorted",
     "special.airy_ai",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index dae6c5e908b..cb5cae464d9 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1441,12 +1441,16 @@ def _aten_atan(self):
 
 
 # aten.scatter_reduce
+@op(torch.ops.aten.scatter)
 @op(torch.ops.aten.scatter_reduce)
 def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
+  if isinstance(src, float):
+    dtype = _torch_binary_scalar_type(src, input)
+    src = jnp.array(src, dtype=dtype)
   input_indexes, source_indexes = _scatter_index(dim, index)
-  if reduce == "sum":
+  if reduce == "sum" or reduce == "add":
     return input.at[input_indexes].add(src[source_indexes])
-  elif reduce == "prod":
+  elif reduce == "prod" or reduce == "multiply":
     return input.at[input_indexes].multiply(src[source_indexes])
   elif reduce == "mean":
     return input.at[input_indexes].add(src[source_indexes])
@@ -1455,7 +1459,7 @@ def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
   elif reduce == "amin":
     return input.at[input_indexes].min(src[source_indexes])
   else:
-    raise RuntimeError("Unknow reduction type: ", reduce)
+    raise RuntimeError("Unknown reduction type: ", reduce)
 
 
 # aten.acos
@@ -1663,10 +1667,12 @@ def _aten_reciprocal(a):
   return 1 / a
 
 
-# aten.scatter
+# aten.select_scatter
 @op(torch.ops.aten.select_scatter)
 def _aten_select_scatter(input, src, dim, index):
   input_indexes = []
+  if dim < 0:
+    dim += len(input.shape)
   for x in range(len(input.shape)):
     if x == dim:
       input_indexes.append(index)
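
Note (not part of the patch): a minimal eager-mode sketch of the two behaviors the
new lowerings are exercised against: aten.select_scatter with a negative dim, and
the scatter family with a string reduction or a Python scalar source. The shapes,
values, and variable names below are illustrative assumptions, not taken from the
test suite; Tensor.scatter_ is used here only as the eager reference for the
scatter semantics.

    import torch

    # select_scatter with dim=-1: embed `src` as the slice at index 0 along the
    # last dimension (the negative dim the lowering now normalizes).
    inp = torch.randn(10, 10)
    src = torch.arange(10, dtype=torch.float32)
    out = torch.select_scatter(inp, src, -1, 0)
    assert torch.equal(out[:, 0], src)  # column 0 replaced by src

    # Eager scatter with a string reduction and with a scalar source, mirroring
    # the cases the shared scatter/scatter_reduce lowering now accepts
    # ("add"/"multiply" alongside "sum"/"prod", and a float src).
    index = torch.randint(0, 10, (10, 10))
    added = inp.clone().scatter_(0, index, torch.ones(10, 10), reduce="add")
    scaled = inp.clone().scatter_(0, index, 2.0, reduce="multiply")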