Commit 5087e74

Making train code

yunseokddi committed Oct 12, 2021
1 parent 71817d3 commit 5087e74
Showing 12 changed files with 328 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore
8 changes: 8 additions & 0 deletions .idea/DeepSVDD_pytorch.iml
15 changes: 15 additions & 0 deletions .idea/deployment.xml
34 changes: 34 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml
6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml
4 changes: 4 additions & 0 deletions .idea/misc.xml
8 changes: 8 additions & 0 deletions .idea/modules.xml
6 changes: 6 additions & 0 deletions .idea/vcs.xml

80 changes: 80 additions & 0 deletions dataloader.py
@@ -0,0 +1,80 @@
import torch
import numpy as np

from torch.utils.data import DataLoader
from torch.utils import data
from PIL import Image
from torchvision import datasets, transforms


class MNIST_loader(data.Dataset):
"""Preprocessing을 포함한 dataloader를 구성"""

def __init__(self, data, target, transform):
self.data = data
self.target = target
self.transform = transform

def __getitem__(self, index):
x = self.data[index]
y = self.target[index]
if self.transform:
x = Image.fromarray(x.numpy(), mode='L')
x = self.transform(x)
return x, y

def __len__(self):
return len(self.data)


def get_mnist(args, data_dir='./data/'):
    """Build train/test DataLoaders for one-class MNIST."""
    # min, max values for each class after applying GCN (as in the original implementation)
min_max = [(-0.8826567065619495, 9.001545489292527),
(-0.6661464580883915, 20.108062262467364),
(-0.7820454743183202, 11.665100841080346),
(-0.7645772083211267, 12.895051191467457),
(-0.7253923114302238, 12.683235701611533),
(-0.7698501867861425, 13.103278415430502),
(-0.778418217980696, 10.457837397569108),
(-0.7129780970522351, 12.057777597673047),
(-0.8280402650205075, 10.581538445782988),
(-0.7369959242164307, 10.697039838804978)]

transform = transforms.Compose([transforms.ToTensor(),
transforms.Lambda(lambda x: global_contrast_normalization(x)),
transforms.Normalize([min_max[args.normal_class][0]],
[min_max[args.normal_class][1] \
- min_max[args.normal_class][0]])])
train = datasets.MNIST(root=data_dir, train=True, download=True)
test = datasets.MNIST(root=data_dir, train=False, download=True)

x_train = train.data
y_train = train.targets

x_train = x_train[np.where(y_train == args.normal_class)]
y_train = y_train[np.where(y_train == args.normal_class)]

data_train = MNIST_loader(x_train, y_train, transform)
dataloader_train = DataLoader(data_train, batch_size=args.batch_size,
shuffle=True, num_workers=0)

x_test = test.data
y_test = test.targets

    # Map the normal class to 0 and all other classes to 1 (normal vs. anomalous)
y_test = np.where(y_test == args.normal_class, 0, 1)

data_test = MNIST_loader(x_test, y_test, transform)
dataloader_test = DataLoader(data_test, batch_size=args.batch_size,
shuffle=False, num_workers=0)
return dataloader_train, dataloader_test


def global_contrast_normalization(x):
"""Apply global contrast normalization to tensor. """
mean = torch.mean(x) # mean over all features (pixels) per sample
x -= mean
x_scale = torch.mean(torch.abs(x))
x /= x_scale
return x
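
For reference, a minimal smoke test for this loader (not part of the commit; it assumes the same args fields that main.py defines below):

import easydict

from dataloader import get_mnist

args = easydict.EasyDict({'normal_class': 0, 'batch_size': 1024})
dataloader_train, dataloader_test = get_mnist(args)
x, y = next(iter(dataloader_train))
print(x.shape, y.shape)  # expected: torch.Size([1024, 1, 28, 28]) torch.Size([1024])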
34 changes: 34 additions & 0 deletions main.py
@@ -0,0 +1,34 @@
import easydict

from dataloader import *
from train import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args = easydict.EasyDict({
'num_epochs': 50,
'num_epochs_ae': 50,
'lr': 1e-3,
'lr_ae': 1e-3,
'weight_decay': 5e-7,
'weight_decay_ae': 5e-3,
'lr_milestones': [50],
'batch_size': 1024,
'pretrain': True,
'latent_dim': 32,
'normal_class': 0
})

if __name__ == '__main__':

    # Load the train/test dataloaders
dataloader_train, dataloader_test = get_mnist(args)

    # Set up the trainer and network architecture
deep_SVDD = TrainerDeepSVDD(args, dataloader_train, device)

    # Pretrain an autoencoder and transfer its encoder weights to Deep SVDD
if args.pretrain:
deep_SVDD.pretrain()

    # Train the Deep SVDD model starting from the pretrained weights
net, c = deep_SVDD.train()
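
main.py stops after training. A hedged evaluation sketch (not in this commit): score each test sample by its squared distance to the center c and report ROC AUC, using the label convention from dataloader.py (1 = anomalous). The function name and use of scikit-learn are assumptions.

import torch
from sklearn.metrics import roc_auc_score

def evaluate(net, c, dataloader, device):
    """Score test samples by squared distance to c; higher = more anomalous."""
    net.eval()
    scores, labels = [], []
    with torch.no_grad():
        for x, y in dataloader:
            z = net(x.float().to(device))
            scores.append(torch.sum((z - c) ** 2, dim=1).cpu())
            labels.append(y)
    scores = torch.cat(scores).numpy()
    labels = torch.cat(labels).numpy()
    return roc_auc_score(labels, scores)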
71 changes: 71 additions & 0 deletions model.py
@@ -0,0 +1,71 @@
import torch.nn as nn
import torch.nn.functional as F
import torch

from torchsummary import summary


class DeepSVDDNetwork(nn.Module):
def __init__(self, z_dim=32):
super(DeepSVDDNetwork, self).__init__()
self.pool = nn.MaxPool2d(2, 2)
self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
self.fc1 = nn.Linear(4*7*7, z_dim, bias=False)

def forward(self, x):
x = self.conv1(x)
x = self.pool(F.leaky_relu(self.bn1(x)))
x = self.conv2(x)
x = self.pool(F.leaky_relu(self.bn2(x)))
x = x.view(x.size(0), -1)
return self.fc1(x)


class pretrain_autoencoder(nn.Module):
def __init__(self, z_dim=32):
super(pretrain_autoencoder, self).__init__()
self.z_dim = z_dim
self.pool = nn.MaxPool2d(2, 2)

self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
self.fc1 = nn.Linear(4 * 7 * 7, z_dim, bias=False)

self.deconv1 = nn.ConvTranspose2d(2, 4, 5, bias=False, padding=2)
self.bn3 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
self.deconv2 = nn.ConvTranspose2d(4, 8, 5, bias=False, padding=3)
self.bn4 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
self.deconv3 = nn.ConvTranspose2d(8, 1, 5, bias=False, padding=2)

def encoder(self, x):
x = self.conv1(x)
x = self.pool(F.leaky_relu(self.bn1(x)))
x = self.conv2(x)
x = self.pool(F.leaky_relu(self.bn2(x)))
x = x.view(x.size(0), -1)
return self.fc1(x)

    def decoder(self, x):
        # Reshape the z_dim code to (N, z_dim/16, 4, 4) and upsample back to 28x28.
        x = x.view(x.size(0), int(self.z_dim / 16), 4, 4)
        x = F.interpolate(F.leaky_relu(x), scale_factor=2)             # 4x4 -> 8x8
        x = self.deconv1(x)                                            # padding keeps 8x8
        x = F.interpolate(F.leaky_relu(self.bn3(x)), scale_factor=2)   # 8x8 -> 16x16
        x = self.deconv2(x)                                            # 16x16 -> 14x14
        x = F.interpolate(F.leaky_relu(self.bn4(x)), scale_factor=2)   # 14x14 -> 28x28
        x = self.deconv3(x)                                            # padding keeps 28x28
        return torch.sigmoid(x)

def forward(self, x):
z = self.encoder(x)
x_hat = self.decoder(z)
return x_hat

if __name__ == "__main__":
# model = pretrain_autoencoder().cuda()
model = DeepSVDDNetwork().cuda()
summary(model, input_size=(1, 28, 28))
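
A quick shape check for the autoencoder path (a sketch, not part of the commit): with z_dim=32 the decoder reshapes the code to (N, 2, 4, 4) and upsamples back to 28x28, which the sizes below confirm.

import torch
from model import pretrain_autoencoder

ae = pretrain_autoencoder(z_dim=32)
x = torch.randn(8, 1, 28, 28)
print(ae.encoder(x).shape)  # torch.Size([8, 32])
print(ae(x).shape)          # torch.Size([8, 1, 28, 28])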
54 changes: 54 additions & 0 deletions train.py
@@ -0,0 +1,54 @@
import torch

from model import pretrain_autoencoder, DeepSVDDNetwork
from tqdm import tqdm


class TrainerDeepSVDD(object):
def __init__(self, args, data_loader, device):
self.args = args
self.train_loader = data_loader
self.device = device

def pretrain(self):
ae = pretrain_autoencoder(self.args.latent_dim).to(self.device)
ae.apply(weights_init_normal)
optimizer = torch.optim.Adam(ae.parameters(), lr=self.args.lr_ae,
weight_decay=self.args.weight_decay_ae)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
milestones=self.args.lr_milestones, gamma=0.1)
ae.train()

for epoch in range(self.args.num_epochs_ae):
total_loss = 0
tq = tqdm(self.train_loader, total=len(self.train_loader))

for x, _ in tq:
x = x.float().to(self.device)

optimizer.zero_grad()
x_hat = ae(x)
reconst_loss = torch.mean(torch.sum((x_hat - x) ** 2, dim=tuple(range(1, x_hat.dim()))))
reconst_loss.backward()
optimizer.step()

total_loss += reconst_loss.item()
errors = {
'epoch': epoch,
'train loss': reconst_loss.item()
}

tq.set_postfix(errors)

scheduler.step()
print('total_loss: {:.2f}'.format(total_loss))

        # Initialize the Deep SVDD center and save the pretrained encoder weights
        # (save_weights_for_DeepSVDD is not included in this commit; see the
        # sketch after this file).
        self.save_weights_for_DeepSVDD(ae, self.train_loader)


def weights_init_normal(m):
    """Initialize Conv/Linear weights from N(0, 0.02), DCGAN-style."""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1 and classname != 'Conv':
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("Linear") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
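
train.py calls self.save_weights_for_DeepSVDD and main.py calls deep_SVDD.train(), but neither method appears in this commit. Below is a minimal sketch of both, following the standard Deep SVDD recipe (fix a center c from the pretrained encoder, then minimize the mean squared distance of embeddings to c); the method bodies, the checkpoint filename, and the eps guard are assumptions. These would sit inside TrainerDeepSVDD, which already has torch and DeepSVDDNetwork imported.

    def save_weights_for_DeepSVDD(self, ae, dataloader):
        """Initialize center c from the pretrained encoder and save the weights."""
        c = self.set_c(ae, dataloader)
        net = DeepSVDDNetwork(self.args.latent_dim).to(self.device)
        # The encoder layers share names with DeepSVDDNetwork, so the state
        # dict transfers directly (strict=False skips the decoder-only keys).
        net.load_state_dict(ae.state_dict(), strict=False)
        torch.save({'center': c.cpu().data.numpy().tolist(),
                    'net_dict': net.state_dict()}, 'pretrained_parameters.pth')

    def set_c(self, ae, dataloader, eps=0.1):
        """Compute c as the mean encoder output over the training data."""
        ae.eval()
        zs = []
        with torch.no_grad():
            for x, _ in dataloader:
                zs.append(ae.encoder(x.float().to(self.device)))
        c = torch.cat(zs).mean(dim=0)
        # Coordinates too close to zero make a feature trivially "solved";
        # push them to +/- eps as in the original implementation.
        c[(torch.abs(c) < eps) & (c < 0)] = -eps
        c[(torch.abs(c) < eps) & (c > 0)] = eps
        return c

    def train(self):
        """Train DeepSVDDNetwork by minimizing mean squared distance to c."""
        net = DeepSVDDNetwork(self.args.latent_dim).to(self.device)
        state = torch.load('pretrained_parameters.pth')
        net.load_state_dict(state['net_dict'])
        c = torch.Tensor(state['center']).to(self.device)
        optimizer = torch.optim.Adam(net.parameters(), lr=self.args.lr,
                                     weight_decay=self.args.weight_decay)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.args.lr_milestones, gamma=0.1)
        net.train()
        for epoch in range(self.args.num_epochs):
            total_loss = 0
            for x, _ in self.train_loader:
                x = x.float().to(self.device)
                optimizer.zero_grad()
                loss = torch.mean(torch.sum((net(x) - c) ** 2, dim=1))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            scheduler.step()
            print('Epoch {}: loss {:.3f}'.format(
                epoch, total_loss / len(self.train_loader)))
        return net, c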
