Commit

cover mul
(cherry picked from commit f55abc88ae361e89da675a1aa1e4a19e7a5c762a)
Siyuan Liu committed Feb 6, 2024
1 parent 03f8d94 commit 9df9f0c
Showing 2 changed files with 36 additions and 6 deletions.
test/stablehlo/test_unbounded_dynamism.py (18 additions, 1 deletion)

@@ -29,7 +29,6 @@ def forward(self, *args):
    ep = torch.export.export(m, args=args, constraints=constraints)
    return ep

-  @unittest.skip("Unbounded Dynamism not supported on add.")
  def test_add(self):
    args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768)))
    constraints = [
@@ -42,6 +41,24 @@ def test_add(self):
                                             constraints)
    shlo_module = exported_program_to_stablehlo(ep)
    shlo_text = shlo_module.get_stablehlo_text()
+    self.assertTrue(
+        re.search(
+            r'tensor<\?x197x768xf32>.*tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>',
+            shlo_text) is not None)
+
+  def test_add_scalar(self):
+    args = (torch.rand((10, 197, 768)), 0.345)
+    constraints = [
+        torch.export.dynamic_dim(args[0], 0),
+    ]
+    ep = self._test_export_dynamism_wrapper(torch.ops.aten.add.Tensor, args,
+                                            constraints)
+    shlo_module = exported_program_to_stablehlo(ep)
+    shlo_text = shlo_module.get_stablehlo_text()
+    self.assertTrue(
+        re.search(
+            r'tensor<f32>.*tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>',
+            shlo_text) is not None)
+
  def test_addmm(self):
    args = (torch.rand((5)), torch.rand((10, 5)), torch.rand((5, 5)))
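For context, a minimal sketch of the scalar case the new test covers, assuming the torch_xla.stablehlo import used elsewhere in this test file; the module name AddScalar and the final regex check are illustrative, not part of the commit:

import re
import torch
from torch_xla.stablehlo import exported_program_to_stablehlo


class AddScalar(torch.nn.Module):

  def forward(self, x):
    # Tensor + Python scalar routes through aten.add.Tensor with a scalar other.
    return x + 0.345


args = (torch.rand(10, 197, 768),)
# Mark the batch dimension as dynamic, mirroring test_add_scalar.
constraints = [torch.export.dynamic_dim(args[0], 0)]
ep = torch.export.export(AddScalar(), args=args, constraints=constraints)
shlo_text = exported_program_to_stablehlo(ep).get_stablehlo_text()
# With this commit, the scalar is emitted as a rank-0 tensor<f32> constant
# rather than a constant expanded to the (dynamic) tensor shape.
print(re.search(r'tensor<f32>.*tensor<\?x197x768xf32>', shlo_text) is not None)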
torch_xla/csrc/tensor_methods.cpp (18 additions, 5 deletions)

@@ -738,14 +738,20 @@ XLATensorPtr add(const XLATensorPtr& input, const XLATensorPtr& other,
  xla::Shape input_shape = input->shape().get();
  xla::Shape other_shape = other->shape().get();
  torch::lazy::Value constant;
+  const torch::lazy::BackendDevice& device = input->GetDevice();
  if (!input_shape.is_dynamic() && !other_shape.is_dynamic()) {
    constant = XLAGraphExecutor::Get()->GetIrValueForScalar(
-        alpha, other->shape(), logical_element_type, input->GetDevice());
+        alpha,
+        xla::ShapeUtil::MakeScalarShape(
+            MakeXlaPrimitiveType(*logical_element_type, &device)),
+        logical_element_type, device);
  } else {
    SymIntElements sym_int_elements(other->GetIrValue());
    constant = XLAGraphExecutor::Get()->GetIrValueForScalar(
-        alpha, other->shape(), sym_int_elements, logical_element_type,
-        input->GetDevice());
+        alpha,
+        xla::ShapeUtil::MakeScalarShape(
+            MakeXlaPrimitiveType(*logical_element_type, &device)),
+        sym_int_elements, logical_element_type, device);
  }

  return input->CreateFrom(input->GetIrValue() + other->GetIrValue() * constant,
@@ -755,12 +761,19 @@ XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
                 const at::Scalar& alpha,
                 c10::optional<at::ScalarType> logical_element_type) {
+  const torch::lazy::BackendDevice& device = input->GetDevice();
  torch::lazy::Value other_constant =
      XLAGraphExecutor::Get()->GetIrValueForScalar(
-          other, input->shape(), logical_element_type, input->GetDevice());
+          other,
+          xla::ShapeUtil::MakeScalarShape(
+              MakeXlaPrimitiveType(*logical_element_type, &device)),
+          logical_element_type, device);
  torch::lazy::Value alpha_constant =
      XLAGraphExecutor::Get()->GetIrValueForScalar(
-          alpha, input->shape(), logical_element_type, input->GetDevice());
+          alpha,
+          xla::ShapeUtil::MakeScalarShape(
+              MakeXlaPrimitiveType(*logical_element_type, &device)),
+          logical_element_type, device);
  return input->CreateFrom(
      input->GetIrValue() + other_constant * alpha_constant,
      logical_element_type);
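The effect of the C++ change is that the scalar alpha (and, in the scalar overload, the scalar other) is now materialized as a rank-0 constant of the logical element type and broadcast implicitly, instead of being expanded to the other operand's shape, which cannot be expressed when that shape has unbounded dynamic dimensions. A rough sketch of the tensor-tensor path, assuming the same export helpers as the test above; the module name AddAlpha and alpha=2.0 are illustrative (the in-tree test uses the default alpha):

import torch
from torch_xla.stablehlo import exported_program_to_stablehlo


class AddAlpha(torch.nn.Module):

  def forward(self, x, y):
    # aten.add.Tensor(x, y, alpha): lowered on the XLA side as x + y * alpha.
    return torch.add(x, y, alpha=2.0)


args = (torch.rand(10, 197, 768), torch.rand(10, 197, 768))
constraints = [
    torch.export.dynamic_dim(args[0], 0),
    torch.export.dynamic_dim(args[1], 0),
]
ep = torch.export.export(AddAlpha(), args=args, constraints=constraints)
# alpha should show up as a rank-0 constant in the lowered StableHLO text.
print(exported_program_to_stablehlo(ep).get_stablehlo_text())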
