Skip to content

Commit

Permalink
add unbounded dynamism test for some aten ops, support dynamism on add (
Browse files Browse the repository at this point in the history
#6443)

Co-authored-by: Siyuan Liu <[email protected]>
  • Loading branch information
lsy323 and Siyuan Liu authored Feb 9, 2024
1 parent 1cd3f46 commit 8d91ff5
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 46 deletions.
10 changes: 3 additions & 7 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,6 @@ function run_pt_xla_debug {
PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
}

function run_stablehlo_compile {
echo "Running in StableHlo Compile mode: $@"
XLA_STABLEHLO_COMPILE=1 run_test "$@"
}

function run_xla_backend_mp {
echo "Running XLA backend multiprocessing test: $@"
MASTER_ADDR=localhost MASTER_PORT=6000 run_test "$@"
Expand Down Expand Up @@ -201,8 +196,9 @@ function run_xla_op_tests3 {
# TODO(qihqi): this test require tensorflow to run. need to setup separate
# CI with tf.
run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_inference.py"
run_stablehlo_compile "$CDIR/stablehlo/test_stablehlo_compile.py"
run_stablehlo_compile "$CDIR/stablehlo/test_implicit_broadcasting.py"
run_test "$CDIR/stablehlo/test_stablehlo_compile.py"
run_test "$CDIR/stablehlo/test_implicit_broadcasting.py"
run_test "$CDIR/stablehlo/test_unbounded_dynamism.py"
run_test "$CDIR/spmd/test_xla_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
run_test "$CDIR/spmd/test_xla_virtual_device.py"
Expand Down
4 changes: 2 additions & 2 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,13 +819,13 @@ def test_mark_sharding_ir(self):
(0, 1))
hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
self.assertIn(
'%custom-call.10 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.9), custom_call_target="Sharding", sharding=',
'%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=',
hlo)

actual += 0
hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
self.assertIn(
'%add.15 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.13, f32[1,128]{1,0} %broadcast.14)',
'%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.10, f32[1,128]{1,0} %broadcast.11)',
hlo)

self.assertTrue(torch.allclose(expected, actual.cpu()))
Expand Down
1 change: 1 addition & 0 deletions test/stablehlo/test_pt2e_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def test_per_channel_qdq(self):
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1)
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1)

@unittest.skip("Failed because PT2E BC break change on constant folding.")
def test_resnet18(self):
# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
Expand Down
12 changes: 8 additions & 4 deletions test/stablehlo/test_stablehlo_compile.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
import unittest

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch
import torchvision
import unittest
import torch_xla.debug.metrics as met
import torch_xla.debug.metrics_compare_utils as mcu
import numpy as np
import torchvision

os.environ['XLA_STABLEHLO_COMPILE'] = '1'


class StableHloCompileTest(unittest.TestCase):
Expand Down
243 changes: 216 additions & 27 deletions test/stablehlo/test_unbounded_dynamism.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import sys
import unittest

Expand All @@ -14,45 +15,233 @@

class UnboundedDynamismExportTest(unittest.TestCase):

def test_simply_add(self):
a = torch.tensor([[1, 2], [2, 4]], device=device)
torch_xla._XLAC._xla_mark_dynamic(a, 0)
b = torch.tensor([[1, 2], [2, 4]], device=device)
torch_xla._XLAC._xla_mark_dynamic(b, 0)
c = a * b
hlo_content = torch_xla._XLAC._get_xla_tensors_hlo([c])
self.assertTrue(
"(p0.1: s64[?,2], p1.2: s64[?,2]) -> (s64[?,2])" in hlo_content)

def test_export_dynamism(self):
def _test_export_dynamism_wrapper(self, f, args, constraints):

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, x, y):
return x * y
def forward(self, *args):
return f(*args)

m = M()
ep = torch.export.export(m, args=args, constraints=constraints)
return ep

def test_add(self):
args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768)))
constraints = [
torch.export.dynamic_dim(args[0], 0),
torch.export.dynamic_dim(args[1], 0),
torch.export.dynamic_dim(args[0],
0) == torch.export.dynamic_dim(args[1], 0),
]
ep = self._test_export_dynamism_wrapper(torch.ops.aten.add.Tensor, args,
constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(
r'tensor<\?x197x768xf32>.*tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>',
shlo_text) is not None)

def test_add_scalar(self):
args = (torch.rand((10, 197, 768)), 0.345)
constraints = [
torch.export.dynamic_dim(args[0], 0),
]
ep = self._test_export_dynamism_wrapper(torch.ops.aten.add.Tensor, args,
constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(
r'tensor<f32>.*tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>',
shlo_text) is not None)

@unittest.skip("Unbounded Dynamism not supported on addmm.")
def test_addmm(self):
args = (torch.rand((5)), torch.rand((10, 5)), torch.rand((5, 5)))
constraints = [
torch.export.dynamic_dim(args[1], 0),
]
ep = self._test_export_dynamism_wrapper(torch.ops.aten.addmm.default, args,
constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r'tensor<\?x5xf32>.*->.*tensor<\?x5xf32>', shlo_text)
is not None)

def test_bmm(self):
args = (
torch.rand((24, 197, 64)),
torch.rand((24, 64, 197)),
)
constraints = [
torch.export.dynamic_dim(args[0], 0),
torch.export.dynamic_dim(args[1], 0),
torch.export.dynamic_dim(args[0],
0) == torch.export.dynamic_dim(args[1], 0),
]
ep = self._test_export_dynamism_wrapper(torch.ops.aten.bmm.default, args,
constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(
r'%arg.: tensor<\?x64x197xf32>.*%arg.: tensor<\?x197x64xf32>.*->.*tensor<\?x197x197xf32>',
shlo_text) is not None)

def test_cat(self):
args = ([torch.rand((10, 1, 768)), torch.rand((10, 196, 768))], 1)
constraints = [
torch.export.dynamic_dim(args[0][0], 0),
torch.export.dynamic_dim(args[0][1], 0),
torch.export.dynamic_dim(args[0][0],
0) == torch.export.dynamic_dim(args[0][1], 0),
]
ep = self._test_export_dynamism_wrapper(torch.ops.aten.cat.default, args,
constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(
r'%arg.: tensor<\?x196x768xf32>.*%arg.: tensor<\?x1x768xf32>.*->.*tensor<\?x197x768xf32>',
shlo_text) is not None)

@unittest.skip("Unbounded Dynamism not supported on conv.")
def test_conv(self):
args = (
torch.rand((10, 3, 224, 224)),
torch.rand((5, 3, 16, 16)),
torch.rand((5)),
[16, 16],
[0, 0],
[1, 1],
False,
[0, 0],
1,
)
constraints = [
torch.export.dynamic_dim(args[0], 0),
torch.export.dynamic_dim(args[0], 0) < 16,
]
ep = self._test_export_dynamism_wrapper(torch.ops.aten.convolution.default,
args, constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r'tensor<\?x3x224x224xf32>.*->.*tensor<\?x5x14x14xf32>',
shlo_text) is not None)

def test_div(self):
args = (torch.rand((10, 12, 197)), 8.0)
constraints = [
torch.export.dynamic_dim(args[0], 0),
]
ep = self._test_export_dynamism_wrapper(torch.ops.aten.div.Tensor, args,
constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r'tensor<\?x12x197xf32>.*->.*tensor<\?x12x197xf32>',
shlo_text) is not None)

@unittest.skip("xla::Erf doesn't support unbounded dynamic input.")
def test_gelu(self):
args = (torch.rand((3, 5)),)
constraints = [
torch.export.dynamic_dim(args[0], 0),
]
ep = self._test_export_dynamism_wrapper(torch.ops.aten.gelu, args,
constraints)
shlo_module = exported_program_to_stablehlo(ep)
# shlo_text = shlo_module.get_stablehlo_text()
# self.assertTrue(
# "(%arg0: tensor<?x2xi64>, %arg1: tensor<?x2xi64>) -> tensor<?x2xi64>" in
# shlo_text)

@unittest.skip("Unbounded Dynamism not supported on view.")
def test_native_layer_norm(self):
args = (
torch.rand((10, 197, 768)),
[768],
torch.rand((768)),
torch.rand((768)),
1e-12,
)
constraints = [
torch.export.dynamic_dim(args[0], 0),
]
ep = self._test_export_dynamism_wrapper(
torch.ops.aten.native_layer_norm.default, args, constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r"%arg.: tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>",
shlo_text) is not None)

def test_permute(self):
args = (torch.rand((10, 197, 12, 64)), [0, 2, 1, 3])
constraints = [
torch.export.dynamic_dim(args[0], 0),
]
ep = self._test_export_dynamism_wrapper(torch.ops.aten.permute.default,
args, constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(
r"%arg.: tensor<\?x197x12x64xf32>.*->.*tensor<\?x12x197x64xf32>",
shlo_text) is not None)

@unittest.skip("Unbounded Dynamism not supported on select..")
def test_select(self):
args = (torch.rand((10, 197, 768)), 1, 0)
constraints = [
torch.export.dynamic_dim(args[0], 0),
]
ep = self._test_export_dynamism_wrapper(torch.ops.aten.select.int, args,
constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r"%arg.: tensor<\?x197x768xf32>.*->.*tensor<\?x768xf32>",
shlo_text) is not None)

@unittest.skip("Unbounded Dynamism not supported on slice.")
def test_slice(self):
args = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
constraints = [
torch.export.dynamic_dim(args[0], 0),
]
ep = self._test_export_dynamism_wrapper(torch.ops.aten.slice.Tensor, args,
constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(
r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x3x224x224xf32>",
shlo_text) is not None)

example_args = (torch.tensor([[1, 2], [2, 4]], device=device),
torch.tensor([[1, 2], [2, 4]], device=device))
@unittest.skip("Unbounded Dynamism not supported on softmax.")
def test_softmax(self):
args = (torch.rand((10, 12, 197, 197)), -1, False)
constraints = [
# First dimension of each input is a dynamic batch size
torch.export.dynamic_dim(example_args[0], 0),
torch.export.dynamic_dim(example_args[1], 0),
# The dynamic batch size between the inputs are equal
torch.export.dynamic_dim(example_args[0],
0) == torch.export.dynamic_dim(
example_args[1], 0),
torch.export.dynamic_dim(args[0], 0),
]
ep = torch.export.export(M(), args=example_args, constraints=constraints)
ep = self._test_export_dynamism_wrapper(torch.ops.aten._softmax.default,
args, constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text("forward")
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
"(%arg0: tensor<?x2xi64>, %arg1: tensor<?x2xi64>) -> tensor<?x2xi64>" in
shlo_text)
re.search(
r"%arg.: tensor<\?x12x197x197xf32>.*->.*tensor<\?x12x197x197xf32>",
shlo_text) is not None)


if __name__ == '__main__':
if __name__ == "__main__":
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
29 changes: 23 additions & 6 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,14 +751,20 @@ XLATensorPtr add(const XLATensorPtr& input, const XLATensorPtr& other,
xla::Shape input_shape = input->shape().get();
xla::Shape other_shape = other->shape().get();
torch::lazy::Value constant;
const torch::lazy::BackendDevice& device = input->GetDevice();
if (!input_shape.is_dynamic() && !other_shape.is_dynamic()) {
constant = XLAGraphExecutor::Get()->GetIrValueForScalar(
alpha, other->shape(), logical_element_type, input->GetDevice());
alpha,
xla::ShapeUtil::MakeScalarShape(
MakeXlaPrimitiveType(other->dtype(), &device)),
logical_element_type, device);
} else {
SymIntElements sym_int_elements(other->GetIrValue());
constant = XLAGraphExecutor::Get()->GetIrValueForScalar(
alpha, other->shape(), sym_int_elements, logical_element_type,
input->GetDevice());
alpha,
xla::ShapeUtil::MakeScalarShape(
MakeXlaPrimitiveType(other->dtype(), &device)),
sym_int_elements, logical_element_type, device);
}

return input->CreateFrom(input->GetIrValue() + other->GetIrValue() * constant,
Expand All @@ -768,12 +774,19 @@ XLATensorPtr add(const XLATensorPtr& input, const XLATensorPtr& other,
XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
const at::Scalar& alpha,
c10::optional<at::ScalarType> logical_element_type) {
const torch::lazy::BackendDevice& device = input->GetDevice();
torch::lazy::Value other_constant =
XLAGraphExecutor::Get()->GetIrValueForScalar(
other, input->shape(), logical_element_type, input->GetDevice());
other,
xla::ShapeUtil::MakeScalarShape(
MakeXlaPrimitiveType(input->dtype(), &device)),
logical_element_type, device);
torch::lazy::Value alpha_constant =
XLAGraphExecutor::Get()->GetIrValueForScalar(
alpha, input->shape(), logical_element_type, input->GetDevice());
alpha,
xla::ShapeUtil::MakeScalarShape(
MakeXlaPrimitiveType(input->dtype(), &device)),
logical_element_type, device);
return input->CreateFrom(
input->GetIrValue() + other_constant * alpha_constant,
logical_element_type);
Expand Down Expand Up @@ -1860,8 +1873,12 @@ XLATensorPtr mul(const XLATensorPtr& input, const XLATensorPtr& other,

XLATensorPtr mul(const XLATensorPtr& input, const at::Scalar& other,
c10::optional<at::ScalarType> logical_element_type) {
const torch::lazy::BackendDevice& device = input->GetDevice();
torch::lazy::Value constant = XLAGraphExecutor::Get()->GetIrValueForScalar(
other, input->shape(), logical_element_type, input->GetDevice());
other,
xla::ShapeUtil::MakeScalarShape(
MakeXlaPrimitiveType(input->dtype(), &device)),
logical_element_type, device);
return input->CreateFrom(input->GetIrValue() * constant,
logical_element_type);
}
Expand Down

0 comments on commit 8d91ff5

Please sign in to comment.