diff --git a/train.py b/train.py index 8adf504..7f3bad7 100644 --- a/train.py +++ b/train.py @@ -158,8 +158,8 @@ def train(epoch): train_loss += loss.data[0] _, predicted = torch.max(outputs.data, 1) total += targets.size(0) - correct += (lam * predicted.eq(targets_a.data).cpu().sum() - + (1 - lam) * predicted.eq(targets_b.data).cpu().sum()) + correct += (lam * predicted.eq(targets_a.data).cpu().sum().float() + + (1 - lam) * predicted.eq(targets_b.data).cpu().sum().float()) optimizer.zero_grad() loss.backward()