diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 5a19800fe39..7bc355ae7ba 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -35,15 +35,13 @@ ### load data test_loader = xu.SampleGenerator( - data=(torch.zeros(8, 1, 28, - 28), torch.zeros(8, - dtype=torch.int64)), + data=(torch.zeros(8, 1, 28,28), torch.zeros(8, dtype=torch.int64)), sample_count=1000 // 8 // xm.xrt_world_size()) examples = enumerate(test_loader) batch_idx, (example_data, example_targets) = next(examples) -example_data.shape +print("shape: ", example_data.shape) ### build model import torch.nn as nn @@ -90,7 +88,7 @@ def test(): test_loss += F.nll_loss(output, target, size_average=False).item() pred = output.data.max(1, keepdim=True)[1] correct += pred.eq(target.data.view_as(pred)).sum() - test_loss /= len(test_loader.dataset) + test_loss /= 20 test_losses.append(test_loss) print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, len(test_loader.dataset),