Skip to content

Commit

Permalink
Merge branch 'master' into chunnienc/init-tag-tensor-api
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc committed Mar 25, 2024
2 parents f8d35cc + 22fe1dc commit 4d0263d
Showing 1 changed file with 18 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,24 @@ def body_fn(init, limit_value):
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)

def test_while_loop_tpu_subtraction_nested(self):

device = xm.xla_device()

def cond_fn(init, limit_value):
return limit_value[0] <= init[0]

def body_fn(init, limit_value):
one_value = torch.ones(1, dtype=torch.int32, device=device)
two_value = limit_value.clone()
return (torch.sub(torch.sub(init, one_value), one_value), two_value)

init = torch.tensor([10], dtype=torch.int32, device=device)
limit_value = torch.tensor([0], dtype=torch.int32, device=device)
res = while_loop(cond_fn, body_fn, (init, limit_value))
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)


if __name__ == '__main__':
test = unittest.main()
Expand Down

0 comments on commit 4d0263d

Please sign in to comment.