diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index fab43ebaaa95..0a3c1df4a8b1 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -171,8 +171,16 @@ def test_loop_fn(loader): correct = 0 model.eval() interator_local = 0 + print("loader: ", loader) print("type loader: ", type(loader)) + + loop_length = loader.__len__ + test_dataset = enumerate(loader) + + print("test_dataset: ", test_dataset) + print("type test_dataset: ", type(test_dataset)) + for data, target in loader: interator_local = interator_local + 1 print("test_loop_fn for loop: ", interator_local)