diff --git a/train.py b/train.py
index 2cfae71..b24b50d 100644
--- a/train.py
+++ b/train.py
@@ -11,7 +11,7 @@ class Net(nn.Module):
     def __init__(self):
-        super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(1, 32, 3, 1)
-        self.conv2 = nn.Conv2d( 32, 64, 3, 1)
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
         self.dropout1 = nn.Dropout(0.25)
         self.dropout2 = nn.Dropout(0.5)
         self.fc1 = nn.Linear(9216, 128)
@@ -36,13 +36,13 @@ def forward(self, x):
         return output
 
 
-def train(args, model, device, train_loader, optimizer, epoch): 
+def train(args, model, device, train_loader, optimizer, epoch):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device) 
+        data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
-        loss = F.nll_loss( output, target)
+        loss = F.nll_loss(output, target)
         loss.backward()
         optimizer.step()
         if batch_idx % args.log_interval == 0:
@@ -52,7 +52,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
                     epoch,
                     batch_idx * len(data),
                     len(train_loader.dataset),
-                    100.0 * batch_idx / len(train_loader), 
+                    100.0 * batch_idx / len(train_loader),
                     loss.item(),
                 )
             )
@@ -70,10 +70,10 @@ def test(model, device, test_loader, epoch):
             data, target = data.to(device), target.to(device)
             output = model(data)
-            test_loss += F.nll_loss( output, target, reduction="sum").item() # sum up batch loss
+            test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
             pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
-            correct += pred.eq(target.view_as(pred) ).sum().item()
+            correct += pred.eq(target.view_as(pred)).sum().item()
 
-    test_loss /= len(test_loader.dataset )
+    test_loss /= len(test_loader.dataset)
 
     print(
         "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
@@ -132,7 +132,7 @@ def main():
     model = Net().to(device)
     optimizer = optim.Adam(model.parameters(), lr=args.lr)
-    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 
+    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
 
     for epoch in range(args.epochs):
         train(args, model, device, train_loader, optimizer, epoch)
         test(model, device, test_loader, epoch)