-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
- Loading branch information
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import torch | ||
import numpy as np | ||
|
||
from torch.utils.data import DataLoader | ||
from torch.utils import data | ||
from PIL import Image | ||
from torchvision import datasets, transforms | ||
|
||
|
||
class MNIST_loader(data.Dataset): | ||
"""Preprocessing을 포함한 dataloader를 구성""" | ||
|
||
def __init__(self, data, target, transform): | ||
self.data = data | ||
self.target = target | ||
self.transform = transform | ||
|
||
def __getitem__(self, index): | ||
x = self.data[index] | ||
y = self.target[index] | ||
if self.transform: | ||
x = Image.fromarray(x.numpy(), mode='L') | ||
x = self.transform(x) | ||
return x, y | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
|
||
def get_mnist(args, data_dir='.https://tistory1.daumcdn.net/tistory/0/MobileWeb/data/'): | ||
"""get dataloders""" | ||
# min, max values for each class after applying GCN (as the original implementation) | ||
min_max = [(-0.8826567065619495, 9.001545489292527), | ||
(-0.6661464580883915, 20.108062262467364), | ||
(-0.7820454743183202, 11.665100841080346), | ||
(-0.7645772083211267, 12.895051191467457), | ||
(-0.7253923114302238, 12.683235701611533), | ||
(-0.7698501867861425, 13.103278415430502), | ||
(-0.778418217980696, 10.457837397569108), | ||
(-0.7129780970522351, 12.057777597673047), | ||
(-0.8280402650205075, 10.581538445782988), | ||
(-0.7369959242164307, 10.697039838804978)] | ||
|
||
transform = transforms.Compose([transforms.ToTensor(), | ||
transforms.Lambda(lambda x: global_contrast_normalization(x)), | ||
transforms.Normalize([min_max[args.normal_class][0]], | ||
[min_max[args.normal_class][1] \ | ||
- min_max[args.normal_class][0]])]) | ||
train = datasets.MNIST(root=data_dir, train=True, download=True) | ||
test = datasets.MNIST(root=data_dir, train=False, download=True) | ||
|
||
x_train = train.data | ||
y_train = train.targets | ||
|
||
x_train = x_train[np.where(y_train == args.normal_class)] | ||
y_train = y_train[np.where(y_train == args.normal_class)] | ||
|
||
data_train = MNIST_loader(x_train, y_train, transform) | ||
dataloader_train = DataLoader(data_train, batch_size=args.batch_size, | ||
shuffle=True, num_workers=0) | ||
|
||
x_test = test.data | ||
y_test = test.targets | ||
|
||
# Normal class인 경우 0으로 바꾸고, 나머지는 1로 변환 (정상 vs 비정상 class) | ||
y_test = np.where(y_test == args.normal_class, 0, 1) | ||
|
||
data_test = MNIST_loader(x_test, y_test, transform) | ||
dataloader_test = DataLoader(data_test, batch_size=args.batch_size, | ||
shuffle=False, num_workers=0) | ||
return dataloader_train, dataloader_test | ||
|
||
|
||
def global_contrast_normalization(x): | ||
"""Apply global contrast normalization to tensor. """ | ||
mean = torch.mean(x) # mean over all features (pixels) per sample | ||
x -= mean | ||
x_scale = torch.mean(torch.abs(x)) | ||
x /= x_scale | ||
return x |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import easydict | ||
|
||
from dataloader import * | ||
from train import * | ||
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
args = easydict.EasyDict({ | ||
'num_epochs': 50, | ||
'num_epochs_ae': 50, | ||
'lr': 1e-3, | ||
'lr_ae': 1e-3, | ||
'weight_decay': 5e-7, | ||
'weight_decay_ae': 5e-3, | ||
'lr_milestones': [50], | ||
'batch_size': 1024, | ||
'pretrain': True, | ||
'latent_dim': 32, | ||
'normal_class': 0 | ||
}) | ||
|
||
if __name__ == '__main__': | ||
|
||
# Train/Test Loader 불러오기 | ||
dataloader_train, dataloader_test = get_mnist(args) | ||
|
||
# Network 학습준비, 구조 불러오기 | ||
deep_SVDD = TrainerDeepSVDD(args, dataloader_train, device) | ||
|
||
# DeepSVDD를 위한 DeepLearning pretrain 모델로 Weight 학습 | ||
if args.pretrain: | ||
deep_SVDD.pretrain() | ||
|
||
# 학습된 가중치로 Deep_SVDD모델 Train | ||
net, c = deep_SVDD.train() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch | ||
|
||
from torchsummary import summary | ||
|
||
|
||
class DeepSVDDNetwork(nn.Module): | ||
def __init__(self, z_dim=32): | ||
super(DeepSVDDNetwork, self).__init__() | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2) | ||
self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False) | ||
self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2) | ||
self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False) | ||
self.fc1 = nn.Linear(4*7*7, z_dim, bias=False) | ||
|
||
def forward(self, x): | ||
x = self.conv1(x) | ||
x = self.pool(F.leaky_relu(self.bn1(x))) | ||
x = self.conv2(x) | ||
x = self.pool(F.leaky_relu(self.bn2(x))) | ||
x = x.view(x.size(0), -1) | ||
return self.fc1(x) | ||
|
||
|
||
class pretrain_autoencoder(nn.Module): | ||
def __init__(self, z_dim=32): | ||
super(pretrain_autoencoder, self).__init__() | ||
self.z_dim = z_dim | ||
self.pool = nn.MaxPool2d(2, 2) | ||
|
||
self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2) | ||
self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False) | ||
self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2) | ||
self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False) | ||
self.fc1 = nn.Linear(4 * 7 * 7, z_dim, bias=False) | ||
|
||
self.deconv1 = nn.ConvTranspose2d(2, 4, 5, bias=False, padding=2) | ||
self.bn3 = nn.BatchNorm2d(4, eps=1e-04, affine=False) | ||
self.deconv2 = nn.ConvTranspose2d(4, 8, 5, bias=False, padding=3) | ||
self.bn4 = nn.BatchNorm2d(8, eps=1e-04, affine=False) | ||
self.deconv3 = nn.ConvTranspose2d(8, 1, 5, bias=False, padding=2) | ||
|
||
def encoder(self, x): | ||
x = self.conv1(x) | ||
x = self.pool(F.leaky_relu(self.bn1(x))) | ||
x = self.conv2(x) | ||
x = self.pool(F.leaky_relu(self.bn2(x))) | ||
x = x.view(x.size(0), -1) | ||
return self.fc1(x) | ||
|
||
def decoder(self, x): | ||
x = x.view(x.size(0), int(self.z_dim / 16), 4, 4) | ||
x = F.interpolate(F.leaky_relu(x), scale_factor=2) | ||
x = self.deconv1(x) | ||
x = F.interpolate(F.leaky_relu(self.bn3(x)), scale_factor=2) | ||
x = self.deconv2(x) | ||
x = F.interpolate(F.leaky_relu(self.bn4(x)), scale_factor=2) | ||
x = self.deconv3(x) | ||
return torch.sigmoid(x) | ||
|
||
def forward(self, x): | ||
z = self.encoder(x) | ||
x_hat = self.decoder(z) | ||
return x_hat | ||
|
||
if __name__ == "__main__": | ||
# model = pretrain_autoencoder().cuda() | ||
model = DeepSVDDNetwork().cuda() | ||
summary(model, input_size=(1, 28, 28)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import torch | ||
|
||
from model import pretrain_autoencoder, DeepSVDDNetwork | ||
from tqdm import tqdm | ||
|
||
|
||
class TrainerDeepSVDD(object): | ||
def __init__(self, args, data_loader, device): | ||
self.args = args | ||
self.train_loader = data_loader | ||
self.device = device | ||
|
||
def pretrain(self): | ||
ae = pretrain_autoencoder(self.args.latent_dim).to(self.device) | ||
ae.apply(weights_init_normal) | ||
optimizer = torch.optim.Adam(ae.parameters(), lr=self.args.lr_ae, | ||
weight_decay=self.args.weight_decay_ae) | ||
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, | ||
milestones=self.args.lr_milestones, gamma=0.1) | ||
ae.train() | ||
|
||
for epoch in range(self.args.num_epochs_ae): | ||
total_loss = 0 | ||
tq = tqdm(self.train_loader, total=len(self.train_loader)) | ||
|
||
for x, _ in tq: | ||
x = x.float().to(self.device) | ||
|
||
optimizer.zero_grad() | ||
x_hat = ae(x) | ||
reconst_loss = torch.mean(torch.sum((x_hat - x) ** 2, dim=tuple(range(1, x_hat.dim())))) | ||
reconst_loss.backward() | ||
optimizer.step() | ||
|
||
total_loss += reconst_loss.item() | ||
errors = { | ||
'epoch': epoch, | ||
'train loss': reconst_loss.item() | ||
} | ||
|
||
tq.set_postfix(errors) | ||
|
||
scheduler.step() | ||
print('total_loss: {:.2f}'.format(total_loss)) | ||
|
||
self.save_weights_for_DeepSVDD(ae, self.train_loader) | ||
|
||
|
||
def weights_init_normal(m): | ||
classname = m.__class__.__name__ | ||
if classname.find("Conv") != -1 and classname != 'Conv': | ||
torch.nn.init.normal_(m.weight.data, 0.0, 0.02) | ||
elif classname.find("Linear") != -1: | ||
torch.nn.init.normal_(m.weight.data, 0.0, 0.02) |