Skip to content

Commit

Permalink
down into cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Apr 24, 2024
1 parent 1b81829 commit bee343f
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions test/test_test_mnist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
import torch
import torchvision

import os
import shutil
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch_xla.distributed.xla_backend

n_epochs = 3
batch_size_train = 8 # 64
batch_size_test = 10 # 1000
Expand All @@ -13,12 +34,11 @@
torch.manual_seed(random_seed)

### load data
import torch_xla.utils.utils as xu
test_loader = xu.SampleGenerator(
data=(torch.zeros(flags.batch_size, 1, 28,
28), torch.zeros(flags.batch_size,
data=(torch.zeros(8, 1, 28,
28), torch.zeros(8,
dtype=torch.int64)),
sample_count=10000 // flags.batch_size // xm.xrt_world_size())
sample_count=1000 // 8 // xm.xrt_world_size())

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
Expand Down

0 comments on commit bee343f

Please sign in to comment.