diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 3b078d22fab..661c47f8f8a 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -170,6 +170,7 @@ def test_loop_fn(loader): total_samples = 0 correct = 0 model.eval() + print("loader: ", loader) for data, target in loader: output = model(data) pred = output.max(1, keepdim=True)[1] @@ -185,7 +186,7 @@ def test_loop_fn(loader): accuracy, max_accuracy = 0.0, 0.0 for epoch in range(1, flags.num_epochs + 1): xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) - train_loop_fn(train_device_loader, epoch) + # train_loop_fn(train_device_loader, epoch) xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now())) accuracy = test_loop_fn(test_device_loader) diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index 3d0fff3d49f..3bab0bbc665 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -117,7 +117,7 @@ def new_body_fn(*carried_inputs): # print("res: ", res) # return list(body_fn(*carried_inputs)).extend(additional_inputs) # self.named_parameters - weight = self.linear.weight + # weight = self.linear.weight res = list(body_fn(*carried_inputs)) # print("res: ", res) # trynewres = res[:-1] + [res[-1]]