diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..73f69e0 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml +# Editor-based HTTP Client requests +/httpRequests/ diff --git a/.idea/DeepSVDD_pytorch.iml b/.idea/DeepSVDD_pytorch.iml new file mode 100644 index 0000000..710f5e9 --- /dev/null +++ b/.idea/DeepSVDD_pytorch.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/deployment.xml b/.idea/deployment.xml new file mode 100644 index 0000000..4e422eb --- /dev/null +++ b/.idea/deployment.xml @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..f4f62be --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,34 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..e6208ec --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3d20e9d --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/dataloader.py b/dataloader.py new file mode 100644 index 0000000..07368be --- /dev/null +++ b/dataloader.py @@ -0,0 +1,80 @@ +import torch +import numpy as np + +from torch.utils.data import DataLoader +from torch.utils import data +from PIL import Image +from torchvision import datasets, transforms + + +class MNIST_loader(data.Dataset): + """Preprocessing을 포함한 dataloader를 구성""" + + def __init__(self, data, target, transform): + self.data = data + self.target = target + self.transform = transform + + def __getitem__(self, index): + x = self.data[index] + y = self.target[index] + if self.transform: + x = Image.fromarray(x.numpy(), mode='L') + x = self.transform(x) + return x, y + + def __len__(self): + return len(self.data) + + +def get_mnist(args, data_dir='.https://tistory1.daumcdn.net/tistory/0/MobileWeb/data/'): + """get dataloders""" + # min, max values for each class after applying GCN (as the original implementation) + min_max = [(-0.8826567065619495, 9.001545489292527), + (-0.6661464580883915, 20.108062262467364), + (-0.7820454743183202, 11.665100841080346), + (-0.7645772083211267, 12.895051191467457), + (-0.7253923114302238, 12.683235701611533), + (-0.7698501867861425, 13.103278415430502), + (-0.778418217980696, 10.457837397569108), + (-0.7129780970522351, 12.057777597673047), + (-0.8280402650205075, 10.581538445782988), + (-0.7369959242164307, 10.697039838804978)] + + transform = transforms.Compose([transforms.ToTensor(), + transforms.Lambda(lambda x: global_contrast_normalization(x)), + transforms.Normalize([min_max[args.normal_class][0]], + [min_max[args.normal_class][1] \ + - min_max[args.normal_class][0]])]) + train = datasets.MNIST(root=data_dir, train=True, download=True) + test = datasets.MNIST(root=data_dir, train=False, download=True) + + x_train = train.data + y_train = train.targets + + x_train = x_train[np.where(y_train == args.normal_class)] + y_train = y_train[np.where(y_train == args.normal_class)] + + data_train = MNIST_loader(x_train, y_train, transform) + dataloader_train = DataLoader(data_train, batch_size=args.batch_size, + shuffle=True, num_workers=0) + + x_test = test.data + y_test = test.targets + + # Normal class인 경우 0으로 바꾸고, 나머지는 1로 변환 (정상 vs 비정상 class) + y_test = np.where(y_test == args.normal_class, 0, 1) + + data_test = MNIST_loader(x_test, y_test, transform) + dataloader_test = DataLoader(data_test, batch_size=args.batch_size, + shuffle=False, num_workers=0) + return dataloader_train, dataloader_test + + +def global_contrast_normalization(x): + """Apply global contrast normalization to tensor. """ + mean = torch.mean(x) # mean over all features (pixels) per sample + x -= mean + x_scale = torch.mean(torch.abs(x)) + x /= x_scale + return x diff --git a/main.py b/main.py new file mode 100644 index 0000000..39ec156 --- /dev/null +++ b/main.py @@ -0,0 +1,34 @@ +import easydict + +from dataloader import * +from train import * + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +args = easydict.EasyDict({ + 'num_epochs': 50, + 'num_epochs_ae': 50, + 'lr': 1e-3, + 'lr_ae': 1e-3, + 'weight_decay': 5e-7, + 'weight_decay_ae': 5e-3, + 'lr_milestones': [50], + 'batch_size': 1024, + 'pretrain': True, + 'latent_dim': 32, + 'normal_class': 0 +}) + +if __name__ == '__main__': + + # Train/Test Loader 불러오기 + dataloader_train, dataloader_test = get_mnist(args) + + # Network 학습준비, 구조 불러오기 + deep_SVDD = TrainerDeepSVDD(args, dataloader_train, device) + + # DeepSVDD를 위한 DeepLearning pretrain 모델로 Weight 학습 + if args.pretrain: + deep_SVDD.pretrain() + + # 학습된 가중치로 Deep_SVDD모델 Train + net, c = deep_SVDD.train() diff --git a/model.py b/model.py new file mode 100644 index 0000000..39f6b15 --- /dev/null +++ b/model.py @@ -0,0 +1,71 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch + +from torchsummary import summary + + +class DeepSVDDNetwork(nn.Module): + def __init__(self, z_dim=32): + super(DeepSVDDNetwork, self).__init__() + self.pool = nn.MaxPool2d(2, 2) + self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2) + self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False) + self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2) + self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False) + self.fc1 = nn.Linear(4*7*7, z_dim, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.pool(F.leaky_relu(self.bn1(x))) + x = self.conv2(x) + x = self.pool(F.leaky_relu(self.bn2(x))) + x = x.view(x.size(0), -1) + return self.fc1(x) + + +class pretrain_autoencoder(nn.Module): + def __init__(self, z_dim=32): + super(pretrain_autoencoder, self).__init__() + self.z_dim = z_dim + self.pool = nn.MaxPool2d(2, 2) + + self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2) + self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False) + self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2) + self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False) + self.fc1 = nn.Linear(4 * 7 * 7, z_dim, bias=False) + + self.deconv1 = nn.ConvTranspose2d(2, 4, 5, bias=False, padding=2) + self.bn3 = nn.BatchNorm2d(4, eps=1e-04, affine=False) + self.deconv2 = nn.ConvTranspose2d(4, 8, 5, bias=False, padding=3) + self.bn4 = nn.BatchNorm2d(8, eps=1e-04, affine=False) + self.deconv3 = nn.ConvTranspose2d(8, 1, 5, bias=False, padding=2) + + def encoder(self, x): + x = self.conv1(x) + x = self.pool(F.leaky_relu(self.bn1(x))) + x = self.conv2(x) + x = self.pool(F.leaky_relu(self.bn2(x))) + x = x.view(x.size(0), -1) + return self.fc1(x) + + def decoder(self, x): + x = x.view(x.size(0), int(self.z_dim / 16), 4, 4) + x = F.interpolate(F.leaky_relu(x), scale_factor=2) + x = self.deconv1(x) + x = F.interpolate(F.leaky_relu(self.bn3(x)), scale_factor=2) + x = self.deconv2(x) + x = F.interpolate(F.leaky_relu(self.bn4(x)), scale_factor=2) + x = self.deconv3(x) + return torch.sigmoid(x) + + def forward(self, x): + z = self.encoder(x) + x_hat = self.decoder(z) + return x_hat + +if __name__ == "__main__": + # model = pretrain_autoencoder().cuda() + model = DeepSVDDNetwork().cuda() + summary(model, input_size=(1, 28, 28)) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..f5730bd --- /dev/null +++ b/train.py @@ -0,0 +1,54 @@ +import torch + +from model import pretrain_autoencoder, DeepSVDDNetwork +from tqdm import tqdm + + +class TrainerDeepSVDD(object): + def __init__(self, args, data_loader, device): + self.args = args + self.train_loader = data_loader + self.device = device + + def pretrain(self): + ae = pretrain_autoencoder(self.args.latent_dim).to(self.device) + ae.apply(weights_init_normal) + optimizer = torch.optim.Adam(ae.parameters(), lr=self.args.lr_ae, + weight_decay=self.args.weight_decay_ae) + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, + milestones=self.args.lr_milestones, gamma=0.1) + ae.train() + + for epoch in range(self.args.num_epochs_ae): + total_loss = 0 + tq = tqdm(self.train_loader, total=len(self.train_loader)) + + for x, _ in tq: + x = x.float().to(self.device) + + optimizer.zero_grad() + x_hat = ae(x) + reconst_loss = torch.mean(torch.sum((x_hat - x) ** 2, dim=tuple(range(1, x_hat.dim())))) + reconst_loss.backward() + optimizer.step() + + total_loss += reconst_loss.item() + errors = { + 'epoch': epoch, + 'train loss': reconst_loss.item() + } + + tq.set_postfix(errors) + + scheduler.step() + print('total_loss: {:.2f}'.format(total_loss)) + + self.save_weights_for_DeepSVDD(ae, self.train_loader) + + +def weights_init_normal(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1 and classname != 'Conv': + torch.nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("Linear") != -1: + torch.nn.init.normal_(m.weight.data, 0.0, 0.02)