From d770466aab717b97ed6bf0cfdb97c96e7ab36848 Mon Sep 17 00:00:00 2001
From: manfeibaigithub
Date: Thu, 29 Feb 2024 01:28:34 +0000
Subject: [PATCH] try mnist test case

---
 test/test_train_mp_mnist.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py
index 0a3c1df4a8b..8a948039c9d 100644
--- a/test/test_train_mp_mnist.py
+++ b/test/test_train_mp_mnist.py
@@ -36,10 +36,12 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_multiprocessing as xmp
 import torch_xla.test.test_utils as test_utils
+import torch_xla.experimental.fori_loop
 
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch_xla.distributed.xla_backend
+from torch._higher_order_ops.while_loop import while_loop
 
 
 class MNIST(nn.Module):
@@ -196,7 +198,12 @@ def test_loop_fn(loader):
   train_device_loader = pl.MpDeviceLoader(train_loader, device)
   test_device_loader = pl.MpDeviceLoader(test_loader, device)
   accuracy, max_accuracy = 0.0, 0.0
-  for epoch in range(1, flags.num_epochs + 1):
+  # for epoch in range(1, flags.num_epochs + 1):
+  epoch = torch.tensor([flags.num_epochs], dtype=torch.int32, device=device)
+  ten = torch.ones(1, dtype=torch.int32, device=device)
+  def cond_fn(epoch):
+    return epoch >= ten[0]
+  def body_fn(epoch):
     xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
     train_loop_fn(train_device_loader, epoch)
     xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
@@ -212,6 +219,9 @@ def test_loop_fn(loader):
         write_xla_metrics=True)
     if flags.metrics_debug:
       xm.master_print(met.metrics_report())
+    return (torch.sub(epoch[0], 1),)
+
+  while_loop(cond_fn, body_fn, (epoch,))
 
   test_utils.close_summary_writer(writer)
   xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
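
A note on the pattern (not part of the commit above): torch._higher_order_ops.while_loop
expects cond_fn to produce a scalar boolean tensor and body_fn to return carried values
with the same shape and dtype as its inputs, so the `torch.sub(epoch[0], 1)` above (a
0-d result for a shape-(1,) carry) may trip that check. A minimal, self-contained
countdown sketch under those constraints follows; it assumes an XLA device is available
and that importing torch_xla.experimental.fori_loop registers the XLA lowering for
while_loop. Names in it are illustrative, not taken from the patch:

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.fori_loop  # noqa: F401, registers the XLA lowering
    from torch._higher_order_ops.while_loop import while_loop

    device = xm.xla_device()
    # Carried value: a shape-(1,) int32 counter, mirroring `epoch` above.
    counter = torch.tensor([3], dtype=torch.int32, device=device)

    def cond_fn(counter):
      # Evaluates to a 0-d boolean tensor: keep looping while counter >= 1.
      return counter[0] >= 1

    def body_fn(counter):
      # Subtract elementwise so the carry keeps shape (1,) and dtype int32.
      return (counter - 1,)

    result, = while_loop(cond_fn, body_fn, (counter,))
    print(result)  # expected: a shape-(1,) int32 tensor holding 0

Threading all loop state through the carried tuple, rather than mutating Python-side
accumulators such as `accuracy`/`max_accuracy` from inside body_fn, is generally what
lets the loop lower to a single XLA While op, since the body is traced once rather
than re-executed eagerly each iteration.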