-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
328 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Code snippet from local drive | ||
import sys | ||
sys.path.append('.') | ||
|
||
# 패키지 임포트 | ||
import os # os 패키지 임포트 | ||
import time # 시간 측정을 위한 time 패키지 임포트 | ||
import torch # 파이토치 패키지 임포트 | ||
import torch.nn as nn # nn 패키지 임포트 | ||
from torch.optim import Adam # Adam 클래스 임포트 | ||
|
||
# from networks.MLP_network import MLP # MLP 클래스 임포트 | ||
# from networks.LeNet_network import LeNet, LeNet_Linear, LeNet_MultiConv, LeNet_MergeConv # LeNet,_Linear, _MultiConv, _MergeConv 클래스 임포트 | ||
# from networks.VGG_network import VGG # VGG 클래스 임포트 | ||
from networks.ResNet_network import ResNet # ResNet 클래스 임포트 | ||
from utils.parser import parse_args # 하이퍼파라미터를 받기 위한 함수 임포트 | ||
from utils.save_folder import get_save_path # 결과를 저장할 폴더를 만들기 위한 함수 임포트 | ||
from utils.get_loader import get_loaders # 데이터를 불러오기 위한 함수 임포트 | ||
from utils.eval import evaluate, evaluate_by_class # 정확도를 계산하기 위한 함수 임포트 | ||
|
||
# 메인 함수 정의 | ||
def main(): | ||
# 하이퍼파라미터 받기 | ||
args = parse_args() | ||
|
||
# 결과를 저장할 폴더 만들기 | ||
save_path = get_save_path(args) | ||
|
||
# 모델 객체 만들기 | ||
# model = MLP(args.image_size, args.hidden_size, args.num_classes).to(args.device) | ||
# model = LeNet(args.image_size, args.num_classes).to(args.device) | ||
# model = LeNet_Linear(args.image_size, args.num_classes).to(args.device) | ||
# model = LeNet_MultiConv(args.image_size, args.num_classes).to(args.device) | ||
# model = VGG(num_classes=args.num_classes, image_size=args.image_size, config=args.vgg_config).to(args.device) | ||
model = ResNet(args.num_classes, args.resnet_config).to(args.device) | ||
|
||
|
||
# 데이터 불러오기 | ||
train_loader, test_loader = get_loaders(args) | ||
|
||
# Loss 선언 | ||
loss_fn = nn.CrossEntropyLoss() | ||
# Optimizer 선언 | ||
optim = Adam(params=model.parameters(), lr=args.lr) | ||
|
||
_max = -1 # 최대 정확도 저장 변수 | ||
durations = [] # 시간 측정을 위한 리스트 | ||
# 학습 시작 | ||
for epoch in range(args.total_epochs): # 에포크 수만큼 반복 | ||
|
||
# 데이터로더에서 데이터를 불러오기 | ||
for idx, (image, label) in enumerate(train_loader): # 데이터를 불러오기 | ||
image = image.to(args.device) # 데이터를 디바이스에 올리기 | ||
label = label.to(args.device) # 타깃을 디바이스에 올리기 | ||
|
||
# 모델이 추론 | ||
start = time.time() # 시간 측정 시작 | ||
output = model(image) # 모델에 이미지 입력 후 출력값 저장 | ||
duration = time.time() - start # 시간 측정 종료 | ||
durations.append(duration) # 시간 측정 결과 저장 | ||
|
||
# 출력값 바탕으로 loss 계산 | ||
loss = loss_fn(output, label) # loss 계산 | ||
|
||
# 파라미터 업데이트 | ||
loss.backward() # 역전파 수행 | ||
optim.step() # 그래디언트 업데이트 | ||
optim.zero_grad() # 그래디언트 초기화 | ||
|
||
# 100번 반복마다 loss 출력 | ||
if idx % 100 == 0: | ||
print(loss) | ||
|
||
# 전체 데이터(/클래스별) 정확도 계산 | ||
acc = evaluate(model, test_loader, args.device) | ||
# acc = evaluate_by_class(model, test_loader, args.device, args.num_classes) | ||
|
||
# 정확도가 높아지면 모델 저장 | ||
if _max < acc : # 정확도가 높아지면 | ||
print('새로운 max 값 달성, 모델 저장', acc) # 새로운 최대 정확도 출력 | ||
_max = acc # 최대 정확도 업데이트 | ||
torch.save( # 모델 저장 | ||
model.state_dict(), | ||
os.path.join(args.save_path, 'model_best.ckpt') | ||
) | ||
print('duration', sum(durations) / len(durations)) # 평균 시간 출력 | ||
|
||
# 이 파일이 메인 파일이면 main 함수 실행 | ||
if __name__ == '__main__' : | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
import torch # 파이토치 패키지 임포트 | ||
import torch.nn as nn # nn 패키지 임포트 | ||
|
||
# ResNet의 여러 버전에 따른 레이어 수 정의 | ||
_NUMS_18 = [2, 2, 2, 2] | ||
_NUMS_34 = [3, 4, 6, 3] | ||
_NUMS_50 = [3, 4, 6, 3] | ||
_NUMS_101 = [3, 4, 23, 3] | ||
_NUMS_152 = [3, 8, 36, 3] | ||
|
||
# ResNet의 채널 수 정의 | ||
_CHANNELS_33 = [64, 128, 256, 512] | ||
_CHANNELS_131 = [256, 512, 1024, 2048] | ||
|
||
class InputPart(nn.Module): | ||
# ResNet의 입력 부분을 정의하는 클래스 | ||
def __init__(self, in_channel=3, out_channel=64, image_size=224): | ||
super().__init__() | ||
# 초기 convolutional 레이어 정의 | ||
self.conv = nn. Sequential( | ||
nn.Conv2d(in_channel, out_channel, 7, 2, 3), # 7x7 convolutional, stride=2, padding=3 | ||
nn.BatchNorm2d(out_channel), # batch normalization | ||
nn.ReLU(), # ReLU | ||
) | ||
# Max Pooling 레이어 정의 | ||
self.pool = nn.MaxPool2d(3, 2, 1) # 3x3 max pooling, stride=2, padding=1 | ||
poolsize = 56 if image_size == 224 else 8 # 입력 이미지 크기에 따라 max pooling 크기 조정 | ||
self.pool2 = nn.AdaptiveMaxPool2d((poolsize, poolsize)) | ||
|
||
# 입력 이미지를 convolutional 및 pooling 레이어를 통과시키는 함수 | ||
def forward(self, x): | ||
x = self.conv(x) | ||
x = self.pool(x) | ||
return x | ||
|
||
class OutputPart(nn.Module): | ||
# ResNet의 출력 부분을 정의하는 클래스 | ||
def __init__(self, config, num_classes): | ||
super().__init__() | ||
self.config = config # ResNet 버전(18, 34, 50, 101, 152) | ||
self.in_channel = 512 if config in [18, 34] else 2048 # ResNet 버전에 따른 입력 채널 수 | ||
|
||
# Average pooling 및 fully connected 레이어 정의 | ||
self.pool = nn.AdaptiveAvgPool2d((1, 1)) # 1x1 average pooling | ||
self.fc = nn.Linear(self.in_channel, num_classes) # fully connected 레이어 | ||
|
||
# 입력 텐서를 pooling 및 fully connected 레이어를 통과시키는 함수 | ||
def forward(self, x): | ||
# x: (batch_size, out_channel= 512 / 2048, h= 7, w= 7) -> 18, 34 / 50, 101, 152 layer | ||
batch_size, c, h, w = x.shape # 입력 텐서의 크기 저장 | ||
x = self.pool(x) # average pooling 레이어 통과 | ||
x = torch.reshape(x, (batch_size, c)) # fully connected 레이어에 입력할 수 있도록 텐서 크기 조정 | ||
x = self.fc(x) # fully connected 레이어 통과 | ||
return x | ||
|
||
class conv(nn.Module): | ||
# 기본 convolutional 레이어를 정의하는 클래스 | ||
def __init__(self, filter_size, in_channel, out_channel, stride=1, use_relu=True): | ||
super().__init__() | ||
padding = 1 if filter_size == 3 else 0 # filter_size가 3x3이면 padding=1, 1x1이면 padding=0 | ||
self.conv = nn.Conv2d(in_channel, | ||
out_channel, | ||
filter_size, | ||
stride, padding) | ||
self.bn = nn.BatchNorm2d(out_channel) | ||
self.use_relu = use_relu # ReLU 사용 여부 | ||
if use_relu: # ReLU 사용 여부에 따라 ReLU 레이어 정의 | ||
self.rl = nn.ReLU() | ||
|
||
# 입력 텐서를 convolutional 및 batch normalization 레이어를 통과시키는 함수 | ||
def forward(self, x): | ||
x = self.conv(x) | ||
x = self.bn(x) | ||
if self.use_relu: | ||
x = self.rl(x) | ||
return x | ||
|
||
class Block(nn.Module): | ||
# ResNet의 기본 블록을 정의하는 클래스 | ||
def __init__(self, in_channel, out_channel, down_sample=False): | ||
super().__init__() | ||
self.down_sample = down_sample # down sampling 여부 | ||
stride = 1 # convolutional 레이어의 stride | ||
if self.down_sample: # down sampling이면 | ||
stride = 2 # convolutional 레이어의 stride를 2로 설정 | ||
# down sampling을 위한 convolutional 레이어 정의 | ||
self.down_sample_net = conv(filter_size=3, in_channel=in_channel, out_channel=out_channel, stride=stride) | ||
|
||
# 두 개의 convolutional 레이어 정의 | ||
self.conv1 = conv(filter_size=3, in_channel=in_channel, out_channel=out_channel, stride=stride) | ||
self.conv2 = conv(filter_size=3, in_channel=out_channel, out_channel=out_channel, use_relu=False) | ||
self.relu = nn.ReLU() | ||
|
||
# 입력 텐서를 두 개의 convolutional 레이어 및 skip connection을 통과시키는 함수 | ||
def forward(self, x): | ||
x_skip = x.clone() # skip connection을 위해 입력 텐서 복사 | ||
|
||
x = self.conv1(x) # 첫 번째 convolutional 레이어 통과 | ||
x = self.conv2(x) # 두 번째 convolutional 레이어 통과 | ||
|
||
if self.down_sample: # down sampling이면 | ||
x_skip = self.down_sample_net(x_skip) # 입력 텐서를 down sampling 레이어 통과 | ||
|
||
x = x + x_skip # skip connection | ||
|
||
x = self.relu(x) # ReLU 통과 | ||
return x | ||
|
||
class BottleNeck(nn.Module): | ||
# ResNet의 BottleNeck 블록을 정의하는 클래스 | ||
def __init__(self, in_channel, out_channel, down_sample=False): | ||
super().__init__() | ||
|
||
middle_channel = out_channel // 4 # BottleNeck 블록의 중간 채널 수 | ||
stride = 2 if down_sample else 1 # convolutional 레이어의 stride | ||
|
||
# down sampling을 위한 convolutional 레이어 정의 | ||
self.down_sample_net = conv(filter_size=3, in_channel=in_channel, out_channel=out_channel, stride=stride) | ||
|
||
self.conv1 = conv(filter_size=1, in_channel=in_channel, out_channel=middle_channel, stride=stride) | ||
self.conv2 = conv(filter_size=3, in_channel=middle_channel, out_channel=middle_channel) | ||
self.conv3 = conv(filter_size=1, in_channel=middle_channel, out_channel=out_channel, use_relu=False) | ||
self.relu = nn.ReLU() | ||
|
||
def forward(self, x): | ||
x_skip = x.clone() # skip connection을 위해 입력 텐서 복사 | ||
|
||
x = self.conv1(x) # 첫 번째 convolutional 레이어 통과 | ||
x = self.conv2(x) # 두 번째 convolutional 레이어 통과 | ||
x = self.conv3(x) # 세 번째 convolutional 레이어 통과 | ||
|
||
x_skip = self.down_sample_net(x_skip) # 입력 텐서를 down sampling 레이어 통과 | ||
|
||
x = x + x_skip # skip connection | ||
|
||
x = self.relu(x) # ReLU 통과 | ||
return x | ||
|
||
class MiddlePart(nn.Module): | ||
# ResNet의 중간 부분을 정의하는 클래스 | ||
def __init__(self, config): | ||
super().__init__() | ||
if config == 18: | ||
_nums = _NUMS_18 | ||
_channels = _CHANNELS_33 | ||
self.TARGET = Block | ||
elif config == 34: | ||
_nums = _NUMS_34 | ||
_channels = _CHANNELS_33 | ||
self.TARGET = Block | ||
elif config == 50: | ||
_nums = _NUMS_50 | ||
_channels = _CHANNELS_131 | ||
self.TARGET = BottleNeck | ||
elif config == 101: | ||
_nums = _NUMS_101 | ||
_channels = _CHANNELS_131 | ||
self.TARGET = BottleNeck | ||
elif config == 152: | ||
_nums = _NUMS_152 | ||
_channels = _CHANNELS_131 | ||
self.TARGET = BottleNeck | ||
|
||
self.layer1 = self.make_layer(_nums[0], 64, _channels[0]) | ||
self.layer2 = self.make_layer(_nums[1], _channels[0], _channels[1], down_sample=True) | ||
self.layer3 = self.make_layer(_nums[2], _channels[1], _channels[2], down_sample=True) | ||
self.layer4 = self.make_layer(_nums[3], _channels[2], _channels[3], down_sample=True) | ||
|
||
def make_layer(self, _num, in_channel, out_channel, down_sample=False): | ||
layer = [ # 레이어 정의 | ||
self.TARGET(in_channel, out_channel, down_sample) # ResNet의 기본 블록 또는 BottleNeck 블록 | ||
] | ||
for idx in range(_num-1): # 레이어 반복 | ||
layer.append( # 레이어 추가 | ||
self.TARGET(out_channel, out_channel) # ResNet의 기본 블록 또는 BottleNeck 블록 | ||
) | ||
layer = nn.Sequential(*layer) # 레이어를 Sequential로 묶기 | ||
return layer | ||
|
||
def forward(self, x): | ||
x = self.layer1(x) | ||
x = self.layer2(x) | ||
x = self.layer3(x) | ||
x = self.layer4(x) | ||
return x | ||
|
||
class ResNet(nn.Module): | ||
# 전체 ResNet 아키텍처를 정의하는 클래스 | ||
def __init__(self, num_classes, config=18): | ||
super().__init__() | ||
# ResNet의 입력, 중간, 출력 부분 정의 | ||
self.input_part = InputPart() | ||
self.output_part = OutputPart(config, num_classes) | ||
self.middel_part = MiddlePart(config) | ||
|
||
# 입력 이미지를 ResNet 아키텍처를 통과시키는 함수 | ||
def forward(self, x): | ||
x = self.input_part(x) | ||
x = self.middel_part(x) | ||
x = self.output_part(x) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import argparse # 하이퍼파라미터 파싱을 위한 argparse 라이브러리 | ||
import torch # 파이토치 라이브러리 | ||
|
||
# 하이퍼파라미터 파싱 함수 | ||
def parse_args(): | ||
# 하이퍼파라미터를 받기 위한 parser 객체 생성 | ||
parser = argparse.ArgumentParser() | ||
# 하이퍼파라미터를 받기 위한 인자 추가 | ||
parser.add_argument('--lr', type=float, default=0.001) | ||
parser.add_argument('--image_size', type=int, default=28) | ||
parser.add_argument('--num_classes', type=int, default=10) | ||
parser.add_argument('--batch_size', type=int, default=100) | ||
parser.add_argument('--hidden_size', type=int, default=1000) | ||
parser.add_argument('--total_epochs', type=int, default=3) | ||
parser.add_argument('--results_folder', type=str, default='results') | ||
parser.add_argument('--device', default=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) | ||
parser.add_argument('--do_save', action='store_true', help='if given, save results') | ||
parser.add_argument('--data', type=str, default='mnist', choices=['mnist', 'cifar']) | ||
parser.add_argument('--vgg_config', type=str, default='a', choices=['a', 'b', 'c', 'd', 'e']) | ||
parser.add_argument('--resnet_config', type=int, default=18, choices=[18, 34, 50, 101, 152]) | ||
# 하이퍼파라미터를 받아서 args에 저장 | ||
args = parser.parse_args() | ||
# 하이퍼파라미터 반환 | ||
return args | ||
|
||
# 추론 시 하이퍼파라미터 파싱 함수 | ||
def infer_parse_args(): | ||
# 하이퍼파라미터를 받기 위한 parser 객체 생성 | ||
parser = argparse.ArgumentParser() | ||
# 하이퍼파라미터를 받기 위한 인자 추가 | ||
parser.add_argument('--trained_folder', type=str) | ||
parser.add_argument('--target_image', type=str) | ||
parser.add_argument('--device', default=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) | ||
# 하이퍼파라미터를 받아서 args에 저장 | ||
args = parser.parse_args() | ||
# 하이퍼파라미터 반환 | ||
return args |