Skip to content

Commit

Permalink
[Modify] 일부 수정
Browse files Browse the repository at this point in the history
  • Loading branch information
BowonKwon committed Sep 23, 2023
1 parent 517afa3 commit 80de81e
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions 주중수업/7주차/1_train_follow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.nn as nn # nn 패키지 임포트
from torch.optim import Adam # Adam 클래스 임포트

# from networks.MLP_network import MLP # MLP 클래스 임포트
# from networks.MLP_network import MLP # MLP 클래스 임포트
# from networks.LeNet_network import LeNet, LeNet_Linear, LeNet_MultiConv, LeNet_MergeConv # LeNet,_Linear, _MultiConv, _MergeConv 클래스 임포트
from networks.VGG_network import VGG # VGG 클래스 임포트
from utils.parser import parse_args # 하이퍼파라미터를 받기 위한 함수 임포트
Expand Down Expand Up @@ -42,18 +42,18 @@ def main():
for epoch in range(args.total_epochs): # 에포크 수만큼 반복

# 데이터로더에서 데이터를 불러오기
for idx, (image, label) in enumerate(train_loader): # 데이터를 불러오기
image = image.to(args.device) # 데이터를 디바이스에 올리기
label = label.to(args.device) # 타깃을 디바이스에 올리기
for idx, (image, label) in enumerate(train_loader): # 데이터를 불러오기
image = image.to(args.device) # 데이터를 디바이스에 올리기
label = label.to(args.device) # 타깃을 디바이스에 올리기

# 모델이 추론
start = time.time() # 시간 측정 시작
output = model(image) # 모델에 이미지 입력 후 출력값 저장
output = model(image) # 모델에 이미지 입력 후 출력값 저장
duration = time.time() - start # 시간 측정 종료
durations.append(duration) # 시간 측정 결과 저장

# 출력값 바탕으로 loss 계산
loss = loss_fn(output, label) # loss 계산
loss = loss_fn(output, label) # loss 계산

# 파라미터 업데이트
loss.backward() # 역전파 수행
Expand All @@ -70,7 +70,7 @@ def main():

# 정확도가 높아지면 모델 저장
if _max < acc : # 정확도가 높아지면
print('새로운 max 값 달성, 모델 저장', acc) # 새로운 최대 정확도 출력
print('새로운 max 값 달성, 모델 저장', acc) # 새로운 최대 정확도 출력
_max = acc # 최대 정확도 업데이트
torch.save( # 모델 저장
model.state_dict(),
Expand Down

0 comments on commit 80de81e

Please sign in to comment.