Skip to content

Commit

Permalink
down into cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Apr 24, 2024
1 parent 9f7ddc2 commit 76b4b82
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion test/test_train_mp_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def test_loop_fn(loader):
correct = 0
model.eval()
print("loader: ", loader)
print("type loader: ", type(loader))
for data, target in loader:
output = model(data)
pred = output.max(1, keepdim=True)[1]
Expand All @@ -181,7 +182,7 @@ def test_loop_fn(loader):
accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
return accuracy

train_device_loader = pl.MpDeviceLoader(train_loader, device)
# train_device_loader = pl.MpDeviceLoader(train_loader, device)
test_device_loader = pl.MpDeviceLoader(test_loader, device)
accuracy, max_accuracy = 0.0, 0.0
for epoch in range(1, flags.num_epochs + 1):
Expand Down

0 comments on commit 76b4b82

Please sign in to comment.