# inference.py
import os
import argparse
from glob import glob

import numpy as np
import cv2
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast
from tqdm import tqdm

from utils import save_samples, get_model, Flags, print_arguments
from data import get_valid_transform, CutImageDataset
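
# Flags and get_model come from utils.py and are not shown on this page. A
# minimal sketch of the *assumed* Flags interface -- a YAML config exposed as
# an attribute namespace -- is given here only to document what
# Flags(path).get() is expected to return; the real implementation in
# utils.py is authoritative:
#
#     import yaml
#     from types import SimpleNamespace
#
#     class Flags:
#         def __init__(self, config_path):
#             with open(config_path) as f:
#                 self._cfg = yaml.safe_load(f)
#
#         def get(self):
#             # Expose config keys as attributes, e.g. cfg.model_name
#             return SimpleNamespace(**self._cfg)
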
def predict(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the main (Pix2Pix) model. The checkpoint may wrap the generator
    # weights under a "G_model" key or be a raw state_dict; support both.
    main_model = get_model(Flags(args.cfg_main).get(), mode="test")
    ckpt_main = torch.load(args.ckpt_main, map_location=device)
    main_model.load_state_dict(ckpt_main.get("G_model", ckpt_main))
    main_model.to(device)
    main_model.eval()

    # Load the postprocessing (HINet) model; same dual checkpoint format,
    # with the weights optionally wrapped under a "model" key.
    post_model = get_model(Flags(args.cfg_post).get(), mode="test")
    ckpt_post = torch.load(args.ckpt_post, map_location=device)
    post_model.load_state_dict(ckpt_post.get("model", ckpt_post))
    post_model.to(device)
    post_model.eval()
    # Set up preprocessing.
    img_paths = sorted(glob(os.path.join(args.img_dir, "*.png")))
    patch_size = args.patch_size
    stride = args.stride
    batch_size = args.batch_size
    transforms = get_valid_transform("inference")
    # Inference
    with torch.no_grad():
        results = []
        pbar = tqdm(
            img_paths, total=len(img_paths), position=0, leave=True, desc="[Inference]"
        )
        for img_path in pbar:
            # Cut the image into overlapping patches.
            ds = CutImageDataset(
                img_path, patch_size=patch_size, stride=stride, transforms=transforms
            )
            dl = DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False)

            # Main light-scattering reduction (Pix2Pix), patch by patch.
            preds = torch.zeros(3, ds.shape[0], ds.shape[1]).to(device)
            votes = torch.zeros(3, ds.shape[0], ds.shape[1]).to(device)
            for images, (x1, x2, y1, y2) in dl:
                with autocast():
                    pred = main_model(images.to(device).float())
                    pred = (pred * 0.5) + 0.5  # map generator output from [-1, 1] to [0, 1]
                # Paste each patch back at its original position, counting
                # how many patches cover each pixel so overlaps average out.
                for i in range(len(x1)):
                    preds[:, x1[i] : x2[i], y1[i] : y2[i]] += pred[i]
                    votes[:, x1[i] : x2[i], y1[i] : y2[i]] += 1
            preds /= votes

            # Downscale to half resolution (1224x1632) for the postprocessing model.
            preds = F.interpolate(
                preds.unsqueeze(0),
                size=(1224, 1632),
                mode="bilinear",
                align_corners=False,
            )

            # Postprocessing (HINet); the model returns multi-stage outputs,
            # of which only the last is used.
            with autocast():
                post_preds = post_model(preds)
            # Upscale back to the full 2448x3264 output resolution.
            post_preds = F.interpolate(
                post_preds[-1], size=(2448, 3264), mode="bicubic", align_corners=False
            )
            post_preds = torch.clamp(post_preds, 0, 1) * 255

            result_img = post_preds[0].cpu().detach().numpy()
            result_img = result_img.transpose(1, 2, 0)  # CHW -> HWC
            result_img = result_img.astype(np.uint8)
            result_img = cv2.cvtColor(result_img, cv2.COLOR_RGB2BGR)
            results.append(result_img)

    # Save the predicted images.
    save_samples(results, save_dir=args.output_dir)
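
# The overlap-vote averaging in predict() assumes CutImageDataset (data.py)
# yields (patch, (x1, x2, y1, y2)) pairs locating each patch in the full
# image, and exposes the image size via ds.shape. A minimal sketch under
# those assumptions (edge handling and the albumentations-style transform
# call are simplifications; the real implementation in data.py is
# authoritative):
#
#     class CutImageDataset(torch.utils.data.Dataset):
#         def __init__(self, img_path, patch_size, stride, transforms=None):
#             img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
#             self.img, self.transforms = img, transforms
#             self.shape = img.shape[:2]
#             self.coords = [
#                 (x, x + patch_size, y, y + patch_size)
#                 for x in range(0, self.shape[0] - patch_size + 1, stride)
#                 for y in range(0, self.shape[1] - patch_size + 1, stride)
#             ]
#
#         def __len__(self):
#             return len(self.coords)
#
#         def __getitem__(self, idx):
#             x1, x2, y1, y2 = self.coords[idx]
#             patch = self.img[x1:x2, y1:y2]
#             if self.transforms is not None:
#                 patch = self.transforms(image=patch)["image"]
#             return patch, (x1, x2, y1, y2)
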
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_main",
        dest="cfg_main",
        default="./configs/Pix2Pix.yaml",
        help="Path to the main model config file",
    )
    parser.add_argument(
        "--config_post",
        dest="cfg_post",
        default="./configs/HINet_phase2.yaml",
        help="Path to the postprocessing model config file",
    )
    parser.add_argument(
        "--checkpoint_main",
        dest="ckpt_main",
        default="./best_models/pix2pix.pth",  # default: the .pth used for the final submission
        help="Path to the trained main model checkpoint",
    )
    parser.add_argument(
        "--checkpoint_post",
        dest="ckpt_post",
        default="./best_models/hinet.pth",  # default: the .pth used for the final submission
        help="Path to the trained postprocessing model checkpoint",
    )
    parser.add_argument(
        "--image_dir",
        dest="img_dir",
        default="./camera_dataset/test_input_img",
        help="Path to the images used for inference",
    )
    parser.add_argument("--patch_size", default=512, type=int, help="Window size used during inference")
    parser.add_argument("--stride", default=256, type=int, help="Stride used during inference")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size used during inference")
    parser.add_argument(
        "--output_dir", default="./submission/", type=str, help="Directory where inference results are saved"
    )
    args = parser.parse_args()

    # Fail early if either checkpoint is missing.
    if not os.path.isfile(args.ckpt_main):
        raise FileNotFoundError(f"No checkpoint found at '{args.ckpt_main}'")
    if not os.path.isfile(args.ckpt_post):
        raise FileNotFoundError(f"No checkpoint found at '{args.ckpt_post}'")

    print_arguments(args)
    predict(args)
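
# Example invocation (the flags below simply restate the argparse defaults;
# adjust paths to your environment):
#
#     python inference.py \
#         --config_main ./configs/Pix2Pix.yaml \
#         --config_post ./configs/HINet_phase2.yaml \
#         --checkpoint_main ./best_models/pix2pix.pth \
#         --checkpoint_post ./best_models/hinet.pth \
#         --image_dir ./camera_dataset/test_input_img \
#         --patch_size 512 --stride 256 --batch_size 32 \
#         --output_dir ./submission/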