
Commit 895b0c2
Fix div lowering and core aten test script enhancement (#6873)
Co-authored-by: Siyuan Liu <[email protected]>
lsy323 and Siyuan Liu authored Apr 3, 2024
1 parent c54367c commit 895b0c2
Showing 2 changed files with 12 additions and 1 deletion.
test/test_core_aten_ops.py (9 additions, 0 deletions)
@@ -16,6 +16,7 @@ def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
     output2_cpu = output2.detach().cpu()
     if output2_cpu.dtype != output1.dtype:
       output2_cpu = output2_cpu.to(output1.dtype)
+    testcase.assertEqual(output1.shape, output2.shape)
     testcase.assertTrue(
         torch.allclose(
             output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan))
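
The added testcase.assertEqual closes a real gap in diff_output: torch.allclose broadcasts its arguments, so two tensors with different shapes can still compare as "close", and a lowering that returns the wrong output shape would slip through. A minimal illustration of the failure mode (a sketch, not part of the commit):

import torch

# torch.allclose broadcasts, so a shape mismatch alone does not make it fail.
a = torch.ones(4)     # shape (4,)
b = torch.ones(1, 4)  # shape (1, 4)
print(torch.allclose(a, b))  # True, despite the differing shapes
# The new assertEqual(output1.shape, output2.shape) catches exactly this case.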

@@ -1174,6 +1175,14 @@ def test_aten_div_Tensor_2(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs)

+  def test_aten_div_Tensor_3(self):
+    args = (
+        torch.rand(1, 3, 4, 1),
+        torch.rand(10),
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs)
+
   def test_aten_div_Tensor_mode_0(self):

     def aten_div_Tensor_mode_rounding_mode_trunc(input, other):
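The new test_aten_div_Tensor_3 is a regression test for the lowering fix below: the two operands only produce the expected result shape through broadcasting. A quick sketch of the shapes involved (illustrative only):

import torch

# (1, 3, 4, 1) divided by (10,) broadcasts to (1, 3, 4, 10). A lowering
# that reuses the lhs shape would report (1, 3, 4, 1) and now trip the
# shape assertion added to diff_output above.
lhs = torch.rand(1, 3, 4, 1)
rhs = torch.rand(10)
print(torch.div(lhs, rhs).shape)  # torch.Size([1, 3, 4, 10])
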
torch_xla/csrc/ops/ops.cpp (3 additions, 1 deletion)
@@ -726,7 +726,9 @@ torch::lazy::NodePtr Div(const torch::lazy::Value& input,
     return node.ReturnOp(BuildDiv(xla_input, xla_divisor), loctx);
   };
   return GenericOp(torch::lazy::OpKind(at::aten::div), {input, divisor},
-                   GetXlaShape(input), std::move(lower_fn));
+                   XlaHelpers::GetPromotedBinaryOpShape(GetXlaShape(input),
+                                                        GetXlaShape(divisor)),
+                   std::move(lower_fn));
 }

 torch::lazy::NodePtr MaxUnary(const torch::lazy::Value& input) {
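Previously the Div node declared its output shape as GetXlaShape(input), which is wrong whenever the divisor broadcasts the result to a larger shape, as in the new test above; the fix derives the shape from both operands. A rough Python analogue of what XlaHelpers::GetPromotedBinaryOpShape computes (an assumption based on standard right-aligned broadcasting rules, not the actual C++ implementation):

from itertools import zip_longest

def promoted_binary_op_shape(lhs, rhs):
  """Right-align the two shapes; each output axis takes the larger extent."""
  out = []
  for l, r in zip_longest(reversed(lhs), reversed(rhs), fillvalue=1):
    assert l == r or l == 1 or r == 1, "shapes are not broadcastable"
    out.append(max(l, r))
  return tuple(reversed(out))

print(promoted_binary_op_shape((1, 3, 4, 1), (10,)))  # (1, 3, 4, 10)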
