From f87e971854019d744ddb7954b2f57415b3fc675b Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Wed, 13 Dec 2023 10:29:49 -0800
Subject: [PATCH] [Backport]Handle negative dim for Diagonal Scatter (#6123) (#6129)

---
 test/test_operations.py          | 7 +++++++
 torch_xla/csrc/aten_xla_type.cpp | 9 +++++++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index fbac40dbba9..17ef3359655 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -1200,6 +1200,13 @@ def test_fn(a):
 
     self.runAtenTest(torch.rand(4, 3), test_fn)
 
+  def test_diagonal_scatter_negative_dim(self):
+
+    def test_fn(input, src):
+      return torch.diagonal_scatter(input, src, 0, dim1=-1, dim2=0)
+
+    self.runAtenTest([torch.zeros(3, 3), torch.ones(3)], test_fn)
+
   def test_scatter_add_bool(self):
     xla_device = xm.xla_device()
     a = torch.tensor([[True, True, True, True, True],
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 560afa0cf8e..b2cd3c0420d 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1105,10 +1105,15 @@ at::Tensor XLANativeFunctions::diagonal_scatter(const at::Tensor& base,
                                                 int64_t dim2) {
   auto base_ = bridge::GetXlaTensor(base);
   auto mutated_view_ = bridge::GetXlaTensor(mutated_view);
+  int64_t base_rank = bridge::GetXlaTensor(base)->shape().get().rank();
+  int64_t canonical_dim1 =
+      torch::lazy::GetCanonicalDimensionIndex(dim1, base_rank);
+  int64_t canonical_dim2 =
+      torch::lazy::GetCanonicalDimensionIndex(dim2, base_rank);
   return bridge::AtenFromXlaTensor(
       base_->CreateFrom(torch::lazy::MakeNode(
-          base_->GetIrValue(), mutated_view_->GetIrValue(), offset, dim1,
-          dim2)));
+          base_->GetIrValue(), mutated_view_->GetIrValue(), offset,
+          canonical_dim1, canonical_dim2)));
 }
 
 at::Tensor XLANativeFunctions::div(const at::Tensor& self,
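
Note (not part of the patch above): the sketch below is a minimal Python illustration of the behavior this backport fixes, assuming only the public torch.diagonal_scatter semantics. The helper canonicalize_dim is hypothetical; it stands in for the torch::lazy::GetCanonicalDimensionIndex call added in aten_xla_type.cpp, which wraps a negative dimension index into the range [0, rank).

    import torch

    def canonicalize_dim(dim: int, rank: int) -> int:
        # Wrap a negative dimension index into [0, rank), mirroring what the
        # GetCanonicalDimensionIndex call is used for on the C++ side.
        return dim + rank if dim < 0 else dim

    base = torch.zeros(3, 3)
    src = torch.ones(3)

    # On a rank-2 tensor, dim1=-1 canonicalizes to dim1=1, so both calls
    # write the same diagonal. Before this patch, the raw -1 was passed
    # straight through to the IR node instead of the canonical index.
    assert canonicalize_dim(-1, base.dim()) == 1
    out_neg = torch.diagonal_scatter(base, src, 0, dim1=-1, dim2=0)
    out_pos = torch.diagonal_scatter(base, src, 0, dim1=1, dim2=0)
    assert torch.equal(out_neg, out_pos)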