From bee343f182da8bfeed83c6ff924c9ca3e816fff3 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Wed, 24 Apr 2024 23:24:36 +0000 Subject: [PATCH] down into cpp --- test/test_test_mnist.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index 6fdce77017a..5bf3ed588dd 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -1,6 +1,27 @@ import torch import torchvision +import os +import shutil +import sys +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +import torch_xla +import torch_xla.debug.metrics as met +import torch_xla.distributed.parallel_loader as pl +import torch_xla.utils.utils as xu +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp +import torch_xla.test.test_utils as test_utils + +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import torch_xla.distributed.xla_backend + n_epochs = 3 batch_size_train = 8 # 64 batch_size_test = 10 # 1000 @@ -13,12 +34,11 @@ torch.manual_seed(random_seed) ### load data -import torch_xla.utils.utils as xu test_loader = xu.SampleGenerator( - data=(torch.zeros(flags.batch_size, 1, 28, - 28), torch.zeros(flags.batch_size, + data=(torch.zeros(8, 1, 28, + 28), torch.zeros(8, dtype=torch.int64)), - sample_count=10000 // flags.batch_size // xm.xrt_world_size()) + sample_count=1000 // 8 // xm.xrt_world_size()) examples = enumerate(test_loader) batch_idx, (example_data, example_targets) = next(examples)