Skip to content

Commit

Permalink
down into cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Apr 24, 2024
1 parent 2de7464 commit e25e9d2
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions test/test_test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,13 @@

### load data
test_loader = xu.SampleGenerator(
data=(torch.zeros(8, 1, 28,
28), torch.zeros(8,
dtype=torch.int64)),
data=(torch.zeros(8, 1, 28,28), torch.zeros(8, dtype=torch.int64)),
sample_count=1000 // 8 // xm.xrt_world_size())

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

example_data.shape
print("shape: ", example_data.shape)

### build model
import torch.nn as nn
Expand Down Expand Up @@ -90,7 +88,7 @@ def test():
test_loss += F.nll_loss(output, target, size_average=False).item()
pred = output.data.max(1, keepdim=True)[1]
correct += pred.eq(target.data.view_as(pred)).sum()
test_loss /= len(test_loader.dataset)
test_loss /= 20
test_losses.append(test_loss)
print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
Expand Down

0 comments on commit e25e9d2

Please sign in to comment.