Refactor configuration to yaml
ArcCha committed May 9, 2018
1 parent f6a33fb commit 80aa65b
Showing 3 changed files with 26 additions and 23 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
+train_config.yaml
 notebooks/*.json
 notebooks/*.state
 FullyConnected
42 changes: 19 additions & 23 deletions train.py
@@ -7,35 +7,32 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
+import yaml
 from cuda import *
 from data import *
 from net import *
 from torch.autograd import Variable
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+config_path = Path('train_config.yaml')
+config = None
+with config_path.open('r') as f:
+    config = yaml.load(f)
 H = {} # Training history and statistics
 CUDA, device = get_cuda_if_available()
 H['cuda'] = CUDA
 
-train_path = Path(
-    '/home/arccha/.kaggle/competitions/digit-recognizer/train.csv')
-if not train_path.exists():
-    train_path = '../train.csv'
-else:
-    train_path = str(train_path)
-DATA_NUM = 42000 # 42000 - max
-H['data_num'] = DATA_NUM
-VALIDATION_NUM = 4200
-H['validation_num'] = VALIDATION_NUM
-BATCH_SIZE = 32
-H['batch_size'] = BATCH_SIZE
+train_path = config['train_path']
+H['data_num'] = config['data_num']
+H['validation_num'] = config['validation_num']
+H['batch_size'] = config['batch_size']
 train_dataset, validation_dataset = train_validation_split(
-    train_path, max_rows=DATA_NUM, validation_num=VALIDATION_NUM, pretransform=True)
+    train_path, max_rows=config['data_num'], validation_num=config['validation_num'], pretransform=True)
 train_loader = DataLoader(dataset=train_dataset,
-                          batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
+                          batch_size=config['batch_size'], shuffle=True, num_workers=4, pin_memory=True)
 validation_loader = DataLoader(
-    dataset=validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1, pin_memory=True)
+    dataset=validation_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=1, pin_memory=True)
 
 validation_classes = np.zeros(10)
 for x, y in tqdm(validation_loader, desc='Validation stats'):
@@ -48,19 +45,17 @@
 net_dir.mkdir(parents=True, exist_ok=True)
 net.to(device)
 
-LR = 0.001
-optimizer = torch.optim.Adam(net.parameters(), lr=LR)
+optimizer = torch.optim.Adam(net.parameters(), lr=config['learning_rate'])
 H['optimizer'] = str(optimizer)
 criterion = nn.CrossEntropyLoss()
 H['criterion'] = str(criterion)
 
-EPOCH_NUM = 50
-H['epoch_num'] = EPOCH_NUM
+H['epoch_num'] = config['epoch_num']
 H['loss'] = []
 H['train_acc'] = []
 H['test_acc'] = []
 start = time.process_time()
-for epoch in tqdm(range(EPOCH_NUM), desc='Total'):
+for epoch in tqdm(range(config['epoch_num']), desc='Total'):
     running_loss = 0.0
     for x, y in tqdm(train_loader, desc='Epoch ' + str(epoch)):
         x = x.to(device)
@@ -71,22 +66,23 @@
         loss.backward()
         optimizer.step()
         running_loss += loss.item()
-    H['loss'].append(running_loss / (DATA_NUM / BATCH_SIZE))
+    H['loss'].append(
+        running_loss / (config['data_num'] / config['batch_size']))
     acc = 0
     for x, y_true in tqdm(train_loader, desc='Train acc ' + str(epoch)):
         x = x.to(device)
         y_true = y_true.to(device)
         y_pred = net(x).argmax(dim=1)
         acc += y_true.eq(y_pred).sum()
-    acc = float(acc) / (len(train_loader) * BATCH_SIZE)
+    acc = float(acc) / (len(train_loader) * config['batch_size'])
     H['train_acc'].append(acc)
     acc = 0
     for x, y_true in tqdm(validation_loader, desc='Validation ' + str(epoch)):
         x = x.to(device)
         y_true = y_true.to(device)
         y_pred = net(x).argmax(dim=1)
         acc += y_true.eq(y_pred).sum()
-    acc = float(acc) / VALIDATION_NUM
+    acc = float(acc) / config['validation_num']
     H['test_acc'].append(acc)
     net_state_path = net_dir.joinpath('net' + str(epoch) + '.state')
     net_state_path.touch(exist_ok=True)
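With the constants gone, train.py reads every hyperparameter through config[...], so a missing key only surfaces as a KeyError at the point where it is first used. Below is a minimal sketch, not part of this commit, of loading and checking the config up front; it assumes the key names from train_config.yaml.sample and uses yaml.safe_load, which later PyYAML releases (5.1+) recommend over a bare yaml.load for plain config files like this one.

# Sketch only (not part of this commit): load the YAML config and verify
# the keys that train.py reads before any training work starts.
from pathlib import Path

import yaml

REQUIRED_KEYS = ('train_path', 'data_num', 'validation_num',
                 'batch_size', 'learning_rate', 'epoch_num')


def load_train_config(path='train_config.yaml'):
    with Path(path).open('r') as f:
        config = yaml.safe_load(f)  # avoids the Loader warning of bare yaml.load
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise KeyError('train_config.yaml is missing: ' + ', '.join(missing))
    return config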
6 changes: 6 additions & 0 deletions train_config.yaml.sample
@@ -0,0 +1,6 @@
+train_path: '/home/arccha/.kaggle/competitions/digit-recognizer/train.csv'
+data_num: 42000
+validation_num: 4200
+batch_size: 32
+learning_rate: 0.001
+epoch_num: 50
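Since .gitignore now excludes train_config.yaml itself, only the sample is committed and a real config must be created locally before train.py will run. A hypothetical first-run helper (not part of this commit) that copies the sample into place:

# Hypothetical helper (not in this commit): create train_config.yaml from the
# committed sample if it does not exist yet, since the real config is gitignored.
import shutil
from pathlib import Path

if not Path('train_config.yaml').exists():
    shutil.copy('train_config.yaml.sample', 'train_config.yaml')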
