Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Apr 8, 2024
1 parent 54b40be commit 4e0dcaa
Showing 1 changed file with 18 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,24 @@ def body_fn(init, limit_value):
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)

def test_while_loop_tpu_subtraction_nested(self):

device = xm.xla_device()

def cond_fn(init, limit_value):
return limit_value[0] <= init[0]

def body_fn(init, limit_value):
one_value = torch.ones(1, dtype=torch.int32, device=device)
two_value = limit_value.clone()
return (torch.sub(torch.sub(init, one_value), one_value), two_value)

init = torch.tensor([10], dtype=torch.int32, device=device)
limit_value = torch.tensor([0], dtype=torch.int32, device=device)
res = while_loop(cond_fn, body_fn, (init, limit_value))
expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
self.assertEqual(expected, res)

def test_fori_loop_tpu_addition(self):

xm.mark_step()
Expand Down

0 comments on commit 4e0dcaa

Please sign in to comment.