Skip to content

Commit

Permalink
Ran formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
Kappibw committed Mar 4, 2024
1 parent c4885ce commit cb446a6
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
class Net(nn.Module):
def __init__(self):


super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d( 32, 64, 3, 1)
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
Expand All @@ -36,14 +35,14 @@ def forward(self, x):
return output


def train(args, model, device, train_loader, optimizer, epoch):
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):

data, target = data.to(device), target.to(device)
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss( output, target)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
Expand All @@ -52,7 +51,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
epoch,
batch_idx * len(data),
len(train_loader.dataset),
100.0 * batch_idx / len(train_loader),
100.0 * batch_idx / len(train_loader),
loss.item(),
)
)
Expand All @@ -70,11 +69,11 @@ def test(model, device, test_loader, epoch):

data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss( output, target, reduction="sum").item() # sum up batch loss
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred) ).sum().item()
correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset )
test_loss /= len(test_loader.dataset)

print(
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
Expand Down Expand Up @@ -132,7 +131,7 @@ def main():
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(args.epochs):
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader, epoch)
Expand Down

0 comments on commit cb446a6

Please sign in to comment.