Skip to content

Commit

Permalink
down into cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Apr 24, 2024
1 parent 76b4b82 commit 042d37c
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions test/test_train_mp_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,21 +186,31 @@ def test_loop_fn(loader):
test_device_loader = pl.MpDeviceLoader(test_loader, device)
accuracy, max_accuracy = 0.0, 0.0
for epoch in range(1, flags.num_epochs + 1):
xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
# xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
# train_loop_fn(train_device_loader, epoch)
xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

# xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
accuracy = test_loop_fn(test_device_loader)
xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
epoch, test_utils.now(), accuracy))
# xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy))
max_accuracy = max(accuracy, max_accuracy)
test_utils.write_to_summary(
writer,
epoch,
dict_to_write={'Accuracy/test': accuracy},
write_xla_metrics=True)
if flags.metrics_debug:
xm.master_print(met.metrics_report())
# test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy}, write_xla_metrics=True)
# if flags.metrics_debug: xm.master_print(met.metrics_report())

# ### fori_loop
# # torch.set_grad_enabled(False)
# new_test_device_loader = pl.MpDeviceLoader(test_loader, device)
# upper = torch.tensor([flags.num_epochs + 1], dtype=torch.int32, device=device) # flags.num_epochs + 1
# lower = torch.tensor([1], dtype=torch.int32, device=device) # 1
# init_val = torch.tensor([1], dtype=torch.int32, device=device)
# # l_in_0 = torch.randn(10, device=xm.xla_device()) # test_device_loader
# # linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
# def body_fun(test_device_loader):
# accuracy = test_loop_fn(test_device_loader)
# max_accuracy = max(accuracy, max_accuracy)
# return max_accuracy

# upper_, lower_, one_value_, add_res_x_, l_in_i_plus_1_, weight_, bias_, l_out_ = fori_loop(
# upper, lower, body_fun, init_val, new_test_device_loader)


test_utils.close_summary_writer(writer)
xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
Expand Down

0 comments on commit 042d37c

Please sign in to comment.