Skip to content

Commit

Permalink
down into cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Apr 24, 2024
1 parent 760d49c commit 9f7ddc2
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion test/test_train_mp_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def test_loop_fn(loader):
total_samples = 0
correct = 0
model.eval()
print("loader: ", loader)
for data, target in loader:
output = model(data)
pred = output.max(1, keepdim=True)[1]
Expand All @@ -185,7 +186,7 @@ def test_loop_fn(loader):
accuracy, max_accuracy = 0.0, 0.0
for epoch in range(1, flags.num_epochs + 1):
xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
train_loop_fn(train_device_loader, epoch)
# train_loop_fn(train_device_loader, epoch)
xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

accuracy = test_loop_fn(test_device_loader)
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def new_body_fn(*carried_inputs):
# print("res: ", res)
# return list(body_fn(*carried_inputs)).extend(additional_inputs)
# self.named_parameters
weight = self.linear.weight
# weight = self.linear.weight
res = list(body_fn(*carried_inputs))
# print("res: ", res)
# trynewres = res[:-1] + [res[-1]]
Expand Down

0 comments on commit 9f7ddc2

Please sign in to comment.