From 044c56067d2d55fdab8876d0627ddbd10191a46c Mon Sep 17 00:00:00 2001 From: Jeremy Date: Fri, 7 Sep 2018 12:37:17 -0400 Subject: [PATCH] fix 0 percent training accuracy --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 8adf504..7f3bad7 100644 --- a/train.py +++ b/train.py @@ -158,8 +158,8 @@ def train(epoch): train_loss += loss.data[0] _, predicted = torch.max(outputs.data, 1) total += targets.size(0) - correct += (lam * predicted.eq(targets_a.data).cpu().sum() - + (1 - lam) * predicted.eq(targets_b.data).cpu().sum()) + correct += (lam * predicted.eq(targets_a.data).cpu().sum().float() + + (1 - lam) * predicted.eq(targets_b.data).cpu().sum().float()) optimizer.zero_grad() loss.backward()