diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 8a948039c9d..0f0d7dc7757 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -201,9 +201,9 @@ def test_loop_fn(loader): # for epoch in range(1, flags.num_epochs + 1): epoch = torch.tensor([flags.num_epochs], dtype=torch.int32, device=device) ten = torch.ones(1, dtype=torch.int32, device=device) - def cond_fn(epoch): - return epoch >= ten[0] - def body_fn(epoch): + def cond_fn(x): + return x[0] >= ten[0] + def body_fn(x): xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) train_loop_fn(train_device_loader, epoch) xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) @@ -219,7 +219,7 @@ def body_fn(epoch): write_xla_metrics=True) if flags.metrics_debug: xm.master_print(met.metrics_report()) - return (torch.sub(epoch[0], 1),) + return (torch.sub(x[0], 1),) while_loop(cond_fn, body_fn, (epoch,))