Fix some more core aten ops (#6462)
wonjoolee95 authored Feb 6, 2024
1 parent 6a0cb71 commit e8dfc86
Showing 3 changed files with 44 additions and 80 deletions.
114 changes: 35 additions & 79 deletions test/test_core_aten_ops.py
@@ -529,54 +529,6 @@ def test_aten_argmin_2(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.argmin, args, kwargs)
 
-  @unittest.skip
-  def test_aten_as_strided_0(self):
-    args = (
-        torch.randn((10, 10)).to(torch.float32),
-        [
-            0,
-            1,
-        ],
-        [
-            0,
-            1,
-        ],
-    )
-    kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs)
-
-  @unittest.skip
-  def test_aten_as_strided_1(self):
-    args = (
-        torch.randn((10, 10)).to(torch.float16),
-        [
-            0,
-            1,
-        ],
-        [
-            0,
-            1,
-        ],
-    )
-    kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs)
-
-  @unittest.skip
-  def test_aten_as_strided_2(self):
-    args = (
-        torch.randint(0, 10, (10, 10)).to(torch.int32),
-        [
-            0,
-            1,
-        ],
-        [
-            0,
-            1,
-        ],
-    )
-    kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs)
-
   @unittest.skip
   def test_aten_as_strided_copy_0(self):
     args = (
@@ -1369,41 +1321,44 @@ def test_aten_div_Scalar_2(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.div.Scalar, args, kwargs)
 
-  @unittest.skip
   def test_aten_div_Scalar_mode_0(self):
+
+    def aten_div_Scalar_mode_rounding_mode_trunc(input, other):
+      return torch.ops.aten.div.Scalar_mode(input, other, rounding_mode='floor')
+
     args = (
         torch.randn((10, 10)).to(torch.float32),
         0.123,
     )
-    kwargs = dict((
-        "rounding_mode",
-        "trunc",
-    ))
-    run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs)
+    kwargs = dict()
+    run_export_and_compare(self, aten_div_Scalar_mode_rounding_mode_trunc, args,
+                           kwargs)
 
-  @unittest.skip
   def test_aten_div_Scalar_mode_1(self):
+
+    def aten_div_Scalar_mode_rounding_mode_trunc(input, other):
+      return torch.ops.aten.div.Scalar_mode(input, other, rounding_mode='floor')
+
     args = (
         torch.randn((10, 10)).to(torch.float16),
         0.123,
     )
-    kwargs = dict((
-        "rounding_mode",
-        "trunc",
-    ))
-    run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs)
+    kwargs = dict()
+    run_export_and_compare(self, aten_div_Scalar_mode_rounding_mode_trunc, args,
+                           kwargs)
 
-  @unittest.skip
   def test_aten_div_Scalar_mode_2(self):
+
+    def aten_div_Scalar_mode_rounding_mode_trunc(input, other):
+      return torch.ops.aten.div.Scalar_mode(input, other, rounding_mode='floor')
+
     args = (
         torch.randint(0, 10, (10, 10)).to(torch.int32),
         0.123,
     )
-    kwargs = dict((
-        "rounding_mode",
-        "trunc",
-    ))
-    run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs)
+    kwargs = dict()
+    run_export_and_compare(self, aten_div_Scalar_mode_rounding_mode_trunc, args,
+                           kwargs)
 
   def test_aten_div_Tensor_0(self):
     args = (
@@ -1429,29 +1384,31 @@ def test_aten_div_Tensor_2(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs)
 
-  @unittest.skip
   def test_aten_div_Tensor_mode_0(self):
+
+    def aten_div_Tensor_mode_rounding_mode_trunc(input, other):
+      return torch.ops.aten.div.Tensor_mode(input, other, rounding_mode='trunc')
+
     args = (
         torch.randn((10, 10)).to(torch.float32),
         torch.randn((10, 10)).to(torch.float32),
     )
-    kwargs = dict((
-        "rounding_mode",
-        "trunc",
-    ))
-    run_export_and_compare(self, torch.ops.aten.div.Tensor_mode, args, kwargs)
+    kwargs = dict()
+    run_export_and_compare(self, aten_div_Tensor_mode_rounding_mode_trunc, args,
+                           kwargs)
 
-  @unittest.skip
   def test_aten_div_Tensor_mode_1(self):
+
+    def aten_div_Tensor_mode_rounding_mode_trunc(input, other):
+      return torch.ops.aten.div.Tensor_mode(input, other, rounding_mode='trunc')
+
     args = (
         torch.randn((10, 10)).to(torch.float16),
         torch.randn((10, 10)).to(torch.float16),
     )
-    kwargs = dict((
-        "rounding_mode",
-        "trunc",
-    ))
-    run_export_and_compare(self, torch.ops.aten.div.Tensor_mode, args, kwargs)
+    kwargs = dict()
+    run_export_and_compare(self, aten_div_Tensor_mode_rounding_mode_trunc, args,
+                           kwargs)
 
   def test_aten_embedding_0(self):
     args = (
@@ -4039,7 +3996,6 @@ def test_aten_sin_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.sin, args, kwargs)
 
-  @unittest.skip
   def test_aten_sin_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
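Editor's note on the div test refactor above (a sketch, not part of the commit): dict(("rounding_mode", "trunc")) never builds {"rounding_mode": "trunc"}, because dict() expects an iterable of key/value pairs and a flat two-tuple of strings raises a ValueError; the old tests were skipped, so the bad kwargs never executed. The updated tests instead wrap the overload in a closure that hard-codes rounding_mode, so run_export_and_compare only forwards positional args. A minimal illustration (the helper name below is hypothetical):

import torch

# dict() with a flat tuple of strings is not a mapping literal: each element
# must itself be a (key, value) pair, so this raises a ValueError.
try:
  kwargs = dict(("rounding_mode", "trunc"))
except ValueError as err:
  print(err)

# The closure pattern used by the updated tests bakes the rounding mode into
# the callable, sidestepping kwargs handling entirely.
def div_tensor_mode_trunc(input, other):
  return torch.ops.aten.div.Tensor_mode(input, other, rounding_mode='trunc')

lhs = torch.randn(10, 10)
rhs = torch.randn(10, 10)
print(div_tensor_mode_trunc(lhs, rhs).dtype)  # torch.float32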
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -756,6 +756,10 @@ torch_xla::XlaOpVector SiluBackward::Lower(LoweringContext* loctx) const {

 torch_xla::XlaOpVector Sin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
+    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
+  }
   return ReturnOp(xla::Sin(xla_input), loctx);
 }
 
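Editor's note (a sketch, not from the commit): the F32 cast added to Sin::Lower mirrors eager PyTorch's type promotion, where sin on an integer tensor yields a floating-point result. A quick eager-mode check of the behavior the lowering now matches:

import torch

# Eager PyTorch promotes integral inputs to the default float dtype for sin;
# the lowering above casts integral XLA inputs to F32 for the same effect.
x = torch.randint(0, 10, (10, 10), dtype=torch.int32)
print(torch.sin(x).dtype)  # torch.float32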
6 changes: 5 additions & 1 deletion torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -835,7 +835,11 @@ xla::Shape SiluBackwardOutputShape(const torch::lazy::Value& grad_output,
 }
 
 xla::Shape SinOutputShape(const torch::lazy::Value& input) {
-  return GetXlaShape(input);
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
 }
 
 xla::Shape SinhOutputShape(const torch::lazy::Value& input) {
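Editor's note: the SinOutputShape change keeps the inferred XLA shape's element type in sync with the new lowering, so integral inputs are reported as F32. A hedged end-to-end sketch, assuming a working torch_xla install with an available XLA device (mirrors the now-enabled test_aten_sin_2 case):

import torch
import torch_xla.core.xla_model as xm

# Hypothetical smoke test, not part of the commit; expects sin on an int32
# XLA tensor to come back as float32, matching eager CPU behavior.
device = xm.xla_device()
x = torch.randint(0, 10, (10, 10), dtype=torch.int32).to(device)
print(torch.sin(x).dtype)  # expected: torch.float32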
