Skip to content

Commit

Permalink
only example code insert without using original tensor code
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Feb 29, 2024
1 parent 1b4c075 commit 1ca8cc8
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions test/test_train_mp_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,30 +198,32 @@ def test_loop_fn(loader):
train_device_loader = pl.MpDeviceLoader(train_loader, device)
test_device_loader = pl.MpDeviceLoader(test_loader, device)
accuracy, max_accuracy = 0.0, 0.0
# for epoch in range(1, flags.num_epochs + 1):
epoch = torch.tensor([flags.num_epochs], dtype=torch.int32, device=device)
for epoch in range(1, flags.num_epochs + 1):
xm.master_print('Epoch {} train begin {}') # .format(epoch, test_utils.now()))
train_loop_fn(train_device_loader, epoch)
xm.master_print('Epoch {} train end {}') # .format(epoch, test_utils.now()))

accuracy = test_loop_fn(test_device_loader)
xm.master_print('Epoch {} test end {}, Accuracy={:.2f}') # .format(
epoch, test_utils.now(), accuracy))
max_accuracy = max(accuracy, max_accuracy)
test_utils.write_to_summary(
writer,
epoch,
dict_to_write={'Accuracy/test': accuracy},
write_xla_metrics=True)
if flags.metrics_debug:
xm.master_print(met.metrics_report())

epoch = torch.tensor([5], dtype=torch.int32, device=device) # flags.num_epochs
ten = torch.ones(1, dtype=torch.int32, device=device)
def cond_fn(x):
return x[0] >= ten[0]
def body_fn(x):
# xm.master_print('Epoch {} train begin {}') # .format(epoch, test_utils.now()))
# train_loop_fn(train_device_loader, epoch)
# xm.master_print('Epoch {} train end {}') # .format(epoch, test_utils.now()))

# accuracy = test_loop_fn(test_device_loader)
# xm.master_print('Epoch {} test end {}, Accuracy={:.2f}') # .format(
# epoch, test_utils.now(), accuracy))
# max_accuracy = max(accuracy, max_accuracy)
# test_utils.write_to_summary(
# writer,
# epoch,
# dict_to_write={'Accuracy/test': accuracy},
# write_xla_metrics=True)
# if flags.metrics_debug:
# xm.master_print(met.metrics_report())
return (torch.sub(x[0], 1),)

while_loop(cond_fn, body_fn, (epoch,))
res = while_loop(cond_fn, body_fn, (epoch,))
print("res: ", res)

test_utils.close_summary_writer(writer)
xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
Expand Down

0 comments on commit 1ca8cc8

Please sign in to comment.