-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
330 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# 패키지 임포트 | ||
import os # 경로 설정을 위한 os 패키지 임포트 | ||
import torch # 파이토치 패키지 임포트 | ||
from PIL import Image # 이미지를 다루기 위한 PIL 패키지 임포트 | ||
|
||
from utils.parser import infer_parse_args # 하이퍼파라미터를 받기 위한 함수 임포트 | ||
from utils.load_hparam import load_hparams # 하이퍼파라미터를 불러오기 위한 함수 임포트 | ||
from networks.MLP_network import MLP # MLP 클래스 임포트 | ||
from utils.get_loader import get_transform # 데이터를 불러오기 위한 함수 임포트 | ||
|
||
# 메인 함수 정의 | ||
def main(): | ||
# 하이퍼파라미터 받기 | ||
args = infer_parse_args() | ||
|
||
# 타겟 폴더와 타겟 이미지가 존재하는지 확인 | ||
assert os.path.exists(args.trained_folder), 'target folder does not exists' | ||
assert os.path.exists(args.target_image), 'target image does not exists' | ||
|
||
# 하이퍼파라미터 불러오기 | ||
args = load_hparams(args) | ||
|
||
# 모델 객체 만들기 | ||
myMLP = MLP(args.image_size, args.hidden_size, args.num_classes).to(args.device) | ||
|
||
# 저장된 모델 가중치 불러오기 | ||
ckpt = torch.load( # 모델 가중치 불러오기 | ||
os.path.join( # 경로 설정 | ||
args.trained_folder, 'myMLP_best.ckpt' # 가중치가 저장된 경로 | ||
) | ||
) | ||
myMLP.load_state_dict(ckpt) # 모델에 가중치 저장 | ||
|
||
# 추론할 이미지 불러오기 | ||
input_image = Image.open(args.target_image).convert('L') | ||
|
||
# 이미지를 모델에 입력할 수 있는 형태로 변환 | ||
trans = get_transform(args) # 이미지를 변환하기 위한 함수 불러오기 | ||
image = trans(input_image).to(args.device) # 이미지를 디바이스에 올리기 | ||
|
||
# 모델에 이미지 입력 후 출력값 저장 | ||
output = myMLP(image) | ||
|
||
# 출력값 중 가장 큰 값의 인덱스를 추론 결과로 저장 | ||
output = torch.argmax(output).item() | ||
|
||
# 추론 결과 출력 | ||
print(f'Model says, the image is {output}') | ||
|
||
# 이 파일이 메인 파일이면 main 함수 실행 | ||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# 패키지 임포트 | ||
import os # os 패키지 임포트 | ||
import torch # 파이토치 패키지 임포트 | ||
import torch.nn as nn # nn 패키지 임포트 | ||
from torch.optim import Adam # Adam 클래스 임포트 | ||
from networks.MLP_network import MLP # MLP 클래스 임포트 | ||
from utils.parser import parse_args # 하이퍼파라미터를 받기 위한 함수 임포트 | ||
from utils.target_folder import get_target_folder # 결과를 저장할 폴더를 만들기 위한 함수 임포트 | ||
from utils.get_loader import get_loaders # 데이터를 불러오기 위한 함수 임포트 | ||
from utils.eval import evaluate #, evaluate_by_class # 정확도를 계산하기 위한 함수 임포트 | ||
|
||
# 메인 함수 정의 | ||
def main(): | ||
# 하이퍼파라미터 받기 | ||
args = parse_args() | ||
# 결과를 저장할 폴더 만들기 | ||
target_folder = get_target_folder(args) | ||
# 모델 객체 만들기 | ||
myMLP = MLP(args.image_size, args.hidden_size, args.num_classes).to(args.device) | ||
# 데이터 불러오기 | ||
train_loader, test_loader = get_loaders(args) | ||
# Loss 선언 | ||
loss_fn = nn.CrossEntropyLoss() | ||
# Optimizer 선언 | ||
optim = Adam(params=myMLP.parameters(), lr=args.lr) | ||
|
||
_max = -1 # 최대 정확도 저장 변수 | ||
# 학습 시작 | ||
for epoch in range(args.total_epochs): # 에포크 수만큼 반복 | ||
for idx, (images, targets) in enumerate(train_loader): # 데이터를 불러오기 | ||
images = images.to(args.device) # 데이터를 디바이스에 올리기 | ||
targets = targets.to(args.device) # 타깃을 디바이스에 올리기 | ||
output = myMLP(images) # 모델에 이미지 입력 후 출력값 저장 | ||
loss = loss_fn(output, targets) # loss 계산 | ||
loss.backward() # 역전파 수행 | ||
optim.step() # 그래디언트 업데이트 | ||
optim.zero_grad() # 그래디언트 초기화 | ||
|
||
# 100번 반복마다 loss 출력 | ||
if idx % 100 == 0: | ||
print(loss) | ||
|
||
# 전체 데이터(/클래스별) 정확도 계산 | ||
acc = evaluate(myMLP, test_loader, args.device) | ||
# acc = evaluate_by_class(myMLP, test_loader, args.device, args.num_classes) | ||
|
||
# 정확도가 높아지면 모델 저장 | ||
if _max < acc : # 정확도가 높아지면 | ||
print('새로운 acc 등장, 모델 weight 업데이트', acc) # 새로운 최대 정확도 출력 | ||
_max = acc # 최대 정확도 업데이트 | ||
torch.save( # 모델 저장 | ||
myMLP.state_dict(), | ||
os.path.join(target_folder, 'myMLP_best.ckpt') | ||
) | ||
|
||
# 이 파일이 메인 파일이면 main 함수 실행 | ||
if __name__ == '__main__' : | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# 패키지 임포트 | ||
# 경로 설정을 위한 os 패키지 임포트 | ||
# 파이토치 패키지 임포트 | ||
# 이미지를 다루기 위한 PIL 패키지 임포트 | ||
# 하이퍼파라미터를 받기 위한 함수 임포트 | ||
# 하이퍼파라미터를 불러오기 위한 함수 임포트 | ||
# MLP 클래스 임포트 | ||
# 데이터를 불러오기 위한 함수 임포트 | ||
|
||
# 메인 함수 정의 | ||
# 하이퍼파라미터 받기 | ||
|
||
# 타겟 폴더와 타겟 이미지가 존재하는지 확인 | ||
|
||
# 하이퍼파라미터 불러오기 | ||
|
||
# 모델 객체 만들기 | ||
|
||
# 저장된 모델 가중치 불러오기 | ||
# 모델 가중치 불러오기 | ||
# 경로 설정 | ||
# 가중치가 저장된 경로 | ||
|
||
# 모델에 가중치 저장 | ||
|
||
# 추론할 이미지 불러오기 | ||
|
||
# 이미지를 모델에 입력할 수 있는 형태로 변환 | ||
# 이미지를 변환하기 위한 함수 불러오기 | ||
# 이미지를 디바이스에 올리기 | ||
|
||
# 모델에 이미지 입력 후 출력값 저장 | ||
|
||
# 출력값 중 가장 큰 값의 인덱스를 추론 결과로 저장 | ||
|
||
# 추론 결과 출력 | ||
|
||
# 이 파일이 메인 파일이면 main 함수 실행 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# 패키지 임포트 | ||
# os 패키지 임포트 | ||
# 파이토치 패키지 임포트 | ||
# nn 패키지 임포트 | ||
# Adam 클래스 임포트 | ||
# MLP 클래스 임포트 | ||
# 하이퍼파라미터를 받기 위한 함수 임포트 | ||
# 결과를 저장할 폴더를 만들기 위한 함수 임포트 | ||
# 데이터를 불러오기 위한 함수 임포트 | ||
# 정확도를 계산하기 위한 함수 임포트 | ||
|
||
# 메인 함수 정의 | ||
# 하이퍼파라미터 받기 | ||
# 결과를 저장할 폴더 만들기 | ||
# 모델 객체 만들기 | ||
# 데이터 불러오기 | ||
# Loss 선언 | ||
# Optimizer 선언 | ||
|
||
# 최대 정확도 저장 변수 | ||
# 학습 시작 | ||
# 에포크 수만큼 반복 | ||
# 데이터를 불러오기 | ||
# 데이터를 디바이스에 올리기 | ||
# 타깃을 디바이스에 올리기 | ||
# 모델에 이미지 입력 후 출력값 저장 | ||
# loss 계산 | ||
# 역전파 수행 | ||
# 그래디언트 업데이트 | ||
# 그래디언트 초기화 | ||
|
||
# 100번 반복마다 loss 출력 | ||
|
||
# 전체 데이터(/클래스별) 정확도 계산 | ||
|
||
# 정확도가 높아지면 모델 저장 | ||
# 정확도가 높아지면 | ||
# 새로운 최대 정확도 출력 | ||
# 최대 정확도 업데이트 | ||
# 모델 저장 | ||
|
||
# 이 파일이 메인 파일이면 main 함수 실행 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import torch # 파이토치 라이브러리 임포트 | ||
import torch.nn as nn # 파이토치 뉴럴넷 라이브러리 임포트 | ||
|
||
# 모델 설계도 그리기 | ||
# nn.Module을 상속받는 MLP 클래스 선언 | ||
class MLP(nn.Module): | ||
# 클래스 초기화 함수 정의 | ||
def __init__(self, image_size, hidden_size, num_classes): | ||
# 상속받은 클래스의 초기화 메서드 호출 | ||
super().__init__() | ||
# 하이퍼파라미터 저장 | ||
self.image_size = image_size # 이미지 크기 | ||
self.mlp1 = nn.Linear(image_size * image_size, hidden_size) # 첫 번째 MLP 레이어 선언(입력층 -> 은닉층1) | ||
self.mlp2 = nn.Linear(hidden_size, hidden_size) # 두 번째 MLP 레이어 선언(은닉층1 -> 은닉층2) | ||
self.mlp3 = nn.Linear(hidden_size, hidden_size) # 세 번째 MLP 레이어 선언(은닉층2 -> 은닉층3) | ||
self.mlp4 = nn.Linear(hidden_size, num_classes) # 네 번째 MLP 레이어 선언(은닉층3 -> 출력층) | ||
# 순전파 함수 정의 | ||
def forward(self, x): | ||
# 입력 이미지의 배치 크기 저장 | ||
batch_size = x.shape[0] | ||
# 입력 이미지를 1차원 벡터로 변환 | ||
x = torch.reshape(x, (-1, self.image_size * self.image_size)) | ||
# MLP 레이어를 통과한 후, ReLU 함수를 적용 | ||
x = self.mlp1(x) # [batch_size, 500] | ||
x = self.mlp2(x) # [batch_size, 500] | ||
x = self.mlp3(x) # [batch_size, 500] | ||
x = self.mlp4(x) # [batch_size, 10] | ||
# 출력값 반환 | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import torch # 파이토치 라이브러리 임포트 | ||
|
||
# 평가함수 구현 | ||
# 전체 데이터에 대한 정확도를 계산하는 함수 | ||
def evaluate(model, loader, device): # 모델, 데이터 로더, 디바이스를 인자로 받음 | ||
with torch.no_grad(): # 그래디언트 계산 비활성화 | ||
model.eval() # 모델을 평가 모드로 설정 | ||
total = 0 # 전체 데이터 개수 저장 변수 | ||
correct = 0 # 정답 개수 저장 변수 | ||
for images, targets in loader: # 데이터 로더로부터 미니배치를 하나씩 꺼내옴 | ||
images, targets = images.to(device), targets.to(device) # 디바이스에 데이터를 보냄 | ||
output = model(images) # 모델에 미니배치 데이터 입력하여 결괏값 계산 | ||
output_index = torch.argmax(output, dim = 1) # 결괏값 중 가장 큰 값의 인덱스를 뽑아냄 | ||
total += targets.shape[0] # 전체 데이터 개수 누적 | ||
correct += (output_index == targets).sum().item() # 정답 개수 누적 | ||
|
||
acc = correct / total * 100 # 정확도(%) 계산 | ||
model.train() # 모델을 학습 모드로 설정 | ||
return acc # 정확도(%) 반환 | ||
|
||
# 클래스별 정확도를 계산하는 함수 | ||
def evaluate_by_class(model, loader, device, num_classes): # 모델, 데이터 로더, 디바이스, 클래스 개수를 인자로 받음 | ||
with torch.no_grad(): # 그래디언트 계산 비활성화 | ||
model.eval() # 모델을 평가 모드로 설정 | ||
total = torch.zeros(num_classes) # 클래스별 전체 데이터 개수 저장 변수 | ||
correct = torch.zeros(num_classes) # 클래스별 정답 개수 저장 변수 | ||
for images, targets in loader: # 데이터 로더로부터 미니배치를 하나씩 꺼내옴 | ||
images, targets = images.to(device), targets.to(device) # 디바이스에 데이터를 보냄 | ||
output = model(images) # 모델에 미니배치 데이터 입력하여 결괏값 계산 | ||
output_index = torch.argmax(output, dim = 1) # 결괏값 중 가장 큰 값의 인덱스를 뽑아냄 | ||
for _class in range(num_classes): # 클래스 개수만큼 반복 | ||
total[_class] += (targets == _class).sum().item() # 클래스별 전체 데이터 개수 누적 | ||
correct[_class] += ((targets == _class) * (output_index == _class)).sum().item() # 클래스별 정답 개수 누적 | ||
|
||
acc = correct / total * 100 # 클래스별 정확도(%) 계산 | ||
model.train() # 모델을 학습 모드로 설정 | ||
return acc # 클래스별 정확도(%) 반환 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import os # 경로 설정을 위한 os 패키지 임포트 | ||
import json # json 파일을 다루기 위한 json 패키지 임포트 | ||
|
||
# 하이퍼파라미터를 불러오기 위한 함수 정의 | ||
def load_hparams(args): | ||
# hparam.json 파일 불러오기 | ||
with open(os.path.join(args.trained_folder, 'hparam.json'), 'r') as f: | ||
# json 파일을 파이썬 딕셔너리로 변환 | ||
data = json.load(f) | ||
# 딕셔너리의 key와 value를 하이퍼파라미터로 저장 | ||
for key, value in data.items(): # 딕셔너리의 key와 value를 하나씩 불러오기 | ||
setattr(args, key, value) # args에 key와 value 저장 | ||
# 하이퍼파라미터를 저장한 args 반환 | ||
return args |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import argparse # 하이퍼파라미터 파싱을 위한 argparse 라이브러리 | ||
import torch # 파이토치 라이브러리 | ||
|
||
# 하이퍼파라미터 파싱 함수 | ||
def parse_args(): | ||
# 하이퍼파라미터를 받기 위한 parser 객체 생성 | ||
parser = argparse.ArgumentParser() | ||
# 하이퍼파라미터를 받기 위한 인자 추가 | ||
parser.add_argument('--lr', type=float, default=0.001) | ||
parser.add_argument('--image_size', type=int, default=28) | ||
parser.add_argument('--num_classes', type=int, default=10) | ||
parser.add_argument('--batch_size', type=int, default=100) | ||
parser.add_argument('--hidden_size', type=int, default=1000) | ||
parser.add_argument('--total_epochs', type=int, default=3) | ||
parser.add_argument('--results_folder', type=str, default='results') | ||
parser.add_argument('--device', default=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) | ||
parser.add_argument('--do_save', action='store_true', help='if given, save results') | ||
parser.add_argument('--data', type=str, default='mnist', choices=['mnist', 'cifar']) | ||
# 하이퍼파라미터를 받아서 args에 저장 | ||
args = parser.parse_args() | ||
# 하이퍼파라미터 반환 | ||
return args | ||
|
||
# 추론 시 하이퍼파라미터 파싱 함수 | ||
def infer_parse_args(): | ||
# 하이퍼파라미터를 받기 위한 parser 객체 생성 | ||
parser = argparse.ArgumentParser() | ||
# 하이퍼파라미터를 받기 위한 인자 추가 | ||
parser.add_argument('--trained_folder', type=str) | ||
parser.add_argument('--target_image', type=str) | ||
parser.add_argument('--device', default=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) | ||
# 하이퍼파라미터를 받아서 args에 저장 | ||
args = parser.parse_args() | ||
# 하이퍼파라미터 반환 | ||
return args |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import os # 경로 설정을 위한 os 패키지 임포트 | ||
import json # 하이퍼파라미터를 json 파일로 저장하기 위한 json 패키지 임포트 | ||
|
||
# 타깃 폴더 생성 함수 정의 | ||
def get_target_folder(args): | ||
|
||
# 상위 저장 폴더가 없으면 상위 저장 폴더 생성 | ||
if not os.path.exists(args.results_folder): | ||
os.makedirs(args.results_folder) | ||
# 결과 저장할 하위 타깃 폴더 생성 | ||
target_folder_name = max([0] + [int(e) for e in os.listdir(args.results_folder)])+1 # 하위 타깃 폴더 이름 | ||
target_folder = os.path.join(args.results_folder, str(target_folder_name)) # 하위 타깃 폴더 경로 | ||
os.makedirs(target_folder) # 하위 타깃 폴더 생성 | ||
|
||
# 하이퍼파라미터에 타깃 폴더 경로 저장 | ||
args.target_folder = target_folder | ||
|
||
# 하이퍼파라미터를 json 파일로 저장 | ||
with open(os.path.join(target_folder, 'hparam.json'), 'w') as f: # json 파일 생성 | ||
write_args = args.__dict__.copy() # args를 딕셔너리로 변환 | ||
del write_args['device'] # 디바이스는 저장하지 않음 | ||
json.dump(write_args, f, indent=4) # json 파일에 저장 | ||
|
||
# 타깃 폴더 경로 반환 | ||
return target_folder |