From 9df9f0c755556b34c70ffb92548068507d769217 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Tue, 6 Feb 2024 21:08:41 +0000
Subject: [PATCH] cover mul

(cherry picked from commit f55abc88ae361e89da675a1aa1e4a19e7a5c762a)
---
 test/stablehlo/test_unbounded_dynamism.py | 19 ++++++++++++++++++-
 torch_xla/csrc/tensor_methods.cpp         | 23 ++++++++++++++++++-----
 2 files changed, 36 insertions(+), 6 deletions(-)

diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py
index d489997284e6..d8da36c329ee 100644
--- a/test/stablehlo/test_unbounded_dynamism.py
+++ b/test/stablehlo/test_unbounded_dynamism.py
@@ -29,7 +29,6 @@ def forward(self, *args):
     ep = torch.export.export(m, args=args, constraints=constraints)
     return ep

-  @unittest.skip("Unbounded Dynamism not supported on add.")
   def test_add(self):
     args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768)))
     constraints = [
@@ -42,6 +41,24 @@ def test_add(self):
                                             constraints)
     shlo_module = exported_program_to_stablehlo(ep)
     shlo_text = shlo_module.get_stablehlo_text()
+    self.assertTrue(
+        re.search(
+            r'tensor<\?x197x768xf32>.*tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>',
+            shlo_text) is not None)
+
+  def test_add_scalar(self):
+    args = (torch.rand((10, 197, 768)), 0.345)
+    constraints = [
+        torch.export.dynamic_dim(args[0], 0),
+    ]
+    ep = self._test_export_dynamism_wrapper(torch.ops.aten.add.Tensor, args,
+                                            constraints)
+    shlo_module = exported_program_to_stablehlo(ep)
+    shlo_text = shlo_module.get_stablehlo_text()
+    self.assertTrue(
+        re.search(
+            r'tensor.*tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>',
+            shlo_text) is not None)

   def test_addmm(self):
     args = (torch.rand((5)), torch.rand((10, 5)), torch.rand((5, 5)))
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index a366221720c7..4e6b08fdf7b5 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -738,14 +738,20 @@ XLATensorPtr add(const XLATensorPtr& input, const XLATensorPtr& other,
   xla::Shape input_shape = input->shape().get();
   xla::Shape other_shape = other->shape().get();
   torch::lazy::Value constant;
+  const torch::lazy::BackendDevice& device = input->GetDevice();
   if (!input_shape.is_dynamic() && !other_shape.is_dynamic()) {
     constant = XLAGraphExecutor::Get()->GetIrValueForScalar(
-        alpha, other->shape(), logical_element_type, input->GetDevice());
+        alpha,
+        xla::ShapeUtil::MakeScalarShape(
+            MakeXlaPrimitiveType(*logical_element_type, &device)),
+        logical_element_type, device);
   } else {
     SymIntElements sym_int_elements(other->GetIrValue());
     constant = XLAGraphExecutor::Get()->GetIrValueForScalar(
-        alpha, other->shape(), sym_int_elements, logical_element_type,
-        input->GetDevice());
+        alpha,
+        xla::ShapeUtil::MakeScalarShape(
+            MakeXlaPrimitiveType(*logical_element_type, &device)),
+        sym_int_elements, logical_element_type, device);
   }

   return input->CreateFrom(input->GetIrValue() + other->GetIrValue() * constant,
@@ -755,12 +761,19 @@ XLATensorPtr add(const XLATensorPtr& input, const XLATensorPtr& other,
 XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
                  const at::Scalar& alpha,
                  c10::optional<at::ScalarType> logical_element_type) {
+  const torch::lazy::BackendDevice& device = input->GetDevice();
   torch::lazy::Value other_constant =
       XLAGraphExecutor::Get()->GetIrValueForScalar(
-          other, input->shape(), logical_element_type, input->GetDevice());
+          other,
+          xla::ShapeUtil::MakeScalarShape(
+              MakeXlaPrimitiveType(*logical_element_type, &device)),
+          logical_element_type, device);
   torch::lazy::Value alpha_constant =
       XLAGraphExecutor::Get()->GetIrValueForScalar(
-          alpha, input->shape(), logical_element_type, input->GetDevice());
+          alpha,
+          xla::ShapeUtil::MakeScalarShape(
+              MakeXlaPrimitiveType(*logical_element_type, &device)),
+          logical_element_type, device);
   return input->CreateFrom(
       input->GetIrValue() + other_constant * alpha_constant,
       logical_element_type);