
Commit

try mnist test case
ManfeiBai committed Feb 29, 2024
1 parent 03f9a62 commit d770466
Showing 1 changed file with 11 additions and 1 deletion.
test/test_train_mp_mnist.py (12 changes: 11 additions & 1 deletion)

@@ -36,10 +36,12 @@
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
+import torch_xla.experimental.fori_loop

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch_xla.distributed.xla_backend
+from torch._higher_order_ops.while_loop import while_loop


class MNIST(nn.Module):
@@ -196,7 +198,12 @@ def test_loop_fn(loader):
  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
-  for epoch in range(1, flags.num_epochs + 1):
+  # for epoch in range(1, flags.num_epochs + 1):
+  epoch = torch.tensor([flags.num_epochs], dtype=torch.int32, device=device)
+  ten = torch.ones(1, dtype=torch.int32, device=device)
+  def cond_fn(epoch):
+    return epoch >= ten[0]
+  def body_fn(epoch):
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train_loop_fn(train_device_loader, epoch)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
@@ -212,6 +219,9 @@ def test_loop_fn(loader):
        write_xla_metrics=True)
    if flags.metrics_debug:
      xm.master_print(met.metrics_report())
+    return (torch.sub(epoch[0], 1),)
+
+  while_loop(cond_fn, body_fn, (epoch,))

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
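The diff above swaps the Python for-loop over epochs for the while_loop higher-order op, with torch_xla.experimental.fori_loop imported for the experimental XLA support: epoch becomes the loop carry (initialized to flags.num_epochs), cond_fn keeps the loop alive while the carry is still >= 1, and body_fn runs one training epoch and returns the decremented carry. Below is a minimal, self-contained sketch of that pattern, not part of the commit: it assumes an XLA device is available and that the experimental lowering accepts a single int32 carry; the names counter and final_counter are purely illustrative.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.fori_loop  # imported as in the commit; assumed to provide the experimental XLA while_loop support
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()

# The loop carry: a single int32 tensor counting down from 10.
counter = torch.tensor([10], dtype=torch.int32, device=device)

def cond_fn(counter):
  # Keep iterating while the carried counter is still positive;
  # cond_fn is expected to return a scalar boolean tensor.
  return counter[0] > 0

def body_fn(counter):
  # Return the next carry as a tuple with the same shape and dtype
  # as the input carry; it becomes the next iteration's input.
  return (counter - 1,)

(final_counter,) = while_loop(cond_fn, body_fn, (counter,))
print(final_counter)  # expected to reach 0 once cond_fn is false

The commit's own body_fn follows the same contract, except that the carry it threads through is the epoch counter and the per-iteration work is a full train/eval pass plus summary writing.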
