From 76b4b829bd8905f29d07de1c2b05243250f55062 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:05:36 +0000 Subject: [PATCH] down into cpp --- test/test_train_mp_mnist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 661c47f8f8a..03589a21fb3 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -171,6 +171,7 @@ def test_loop_fn(loader): correct = 0 model.eval() print("loader: ", loader) + print("type loader: ", type(loader)) for data, target in loader: output = model(data) pred = output.max(1, keepdim=True)[1] @@ -181,7 +182,7 @@ def test_loop_fn(loader): accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) return accuracy - train_device_loader = pl.MpDeviceLoader(train_loader, device) + # train_device_loader = pl.MpDeviceLoader(train_loader, device) test_device_loader = pl.MpDeviceLoader(test_loader, device) accuracy, max_accuracy = 0.0, 0.0 for epoch in range(1, flags.num_epochs + 1):