Skip to content

Commit

Permalink
[Fori_loop|While_loop] Add nested test (#6807)
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Mar 25, 2024
1 parent 7b9177d commit 22fe1dc
Showing 1 changed file with 18 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,24 @@ def body_fn(init, limit_value):
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)

def test_while_loop_tpu_subtraction_nested(self):

device = xm.xla_device()

def cond_fn(init, limit_value):
return limit_value[0] <= init[0]

def body_fn(init, limit_value):
one_value = torch.ones(1, dtype=torch.int32, device=device)
two_value = limit_value.clone()
return (torch.sub(torch.sub(init, one_value), one_value), two_value)

init = torch.tensor([10], dtype=torch.int32, device=device)
limit_value = torch.tensor([0], dtype=torch.int32, device=device)
res = while_loop(cond_fn, body_fn, (init, limit_value))
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)


if __name__ == '__main__':
test = unittest.main()
Expand Down

0 comments on commit 22fe1dc

Please sign in to comment.