# data_utils.py
import imp  # only needed if the legacy loaders commented out in get_model() are re-enabled
import os
import sys
import time

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

import imagenet

def check_mkdir(path):
    """Create `path` (including any missing parent directories) if it does not exist."""
    if not os.path.exists(path):
        os.makedirs(path)

class Logger:
    """File logger that also echoes every message to stdout.

    Each instance appends a timestamp to its log file name and mirrors
    default messages into a shared 'logs/All.log' file.
    """

    def __init__(self, file, print=True):
        # The `print` flag is accepted but currently unused.
        self.file = file
        local_time = time.strftime("%b%d_%H%M%S", time.localtime())
        self.file += local_time
        self.All_file = 'logs/All.log'

    def print(self, content='', end='\n', file=None):
        # Default calls (no explicit file) go to this logger's own file,
        # get mirrored into the shared All_file, and are echoed to stdout once.
        mirror = file is None
        if file is None:
            file = self.file
        with open(file, 'a') as f:
            if isinstance(content, str):
                f.write(content + end)
            else:
                # Non-string objects are rendered via the builtin print by
                # temporarily redirecting stdout into the log file.
                old = sys.stdout
                sys.stdout = f
                print(content)
                sys.stdout = old
        if mirror:
            # The recursive call writes to All_file and handles the console echo,
            # so return afterwards to avoid printing the message twice.
            self.print(content, end=end, file=self.All_file)
            return
        print(content, end=end)

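# Hedged usage sketch for Logger: the log directory and messages below are
# illustrative assumptions, not values taken from this repository.
#
#   check_mkdir('logs')
#   logger = Logger('logs/run_')          # file name becomes e.g. 'logs/run_Jan01_120000'
#   logger.print('starting run')          # written to the run log, logs/All.log, and stdout
#   logger.print({'eps': 16 / 255})       # non-string objects go through stdout redirection
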
class ImageSet(Dataset):
    """Dataset over images listed in a dataframe with columns
    'image_path', 'label_idx' and 'target_idx'."""

    def __init__(self, df, input_dir, transformer):
        self.df = df
        self.transformer = transformer  # kept for API compatibility; __getitem__ builds tensors directly
        self.input_dir = input_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        image_name = self.df.iloc[item]['image_path']
        image_path = os.path.join(self.input_dir, image_name)
        # Load the image as a float tensor in CHW layout, scaled to [0, 1].
        image = torch.tensor(
            np.array(Image.open(image_path)).astype(np.float32).transpose((2, 0, 1))
        ) / 255.0
        label_idx = self.df.iloc[item]['label_idx']
        target_idx = self.df.iloc[item]['target_idx']
        sample = {
            'dataset_idx': item,
            'image': image,
            # The label file stores 0-based indices; they are shifted by one here.
            'label': label_idx + 1,
            'target': target_idx + 1,
            'filename': image_name,
        }
        return sample

def load_images_data(input_dir, batch_size=16, shuffle=False, label_file='old_labels'):
    """Build a DataLoader over the images listed in `input_dir/label_file` (forward passes only)."""
    dev_data = pd.read_csv(os.path.join(input_dir, label_file), header=None, sep=' ',
                           names=['image_path', 'label_idx', 'target_idx'])
    transformer = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.5, 0.5, 0.5],
        #                      std=[0.5, 0.5, 0.5]),
    ])
    datasets = ImageSet(dev_data, input_dir, transformer)
    dataloader = DataLoader(datasets,
                            batch_size=batch_size,
                            num_workers=0,
                            shuffle=shuffle)
    return dataloader

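# Hedged usage sketch for load_images_data: the directory name is an illustrative
# assumption about how the calling scripts lay out their data.
#
#   loader = load_images_data('./dev_data', batch_size=16, label_file='old_labels')
#   for batch in loader:
#       images, labels, targets = batch['image'], batch['label'], batch['target']
#       ...
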
def get_model(model_type):
    # The original per-model loaders (imp.load_source plus torch.load on local .pth
    # weights) are kept below as a reference; model construction is now delegated
    # to imagenet.get_model.
    '''
    if model_type == 'inception_v3':
        MainModel = imp.load_source('MainModel', "./pth_weights/tf_inception_v3.py")
        model = torch.load('./pth_weights/tf_inception_v3.pth')
    elif model_type == 'adv_inception_v3':
        MainModel = imp.load_source('MainModel', "./pth_weights/tf_inception_v3.py")
        model = torch.load('./pth_weights/tf_adv_inception_v3.pth')
    elif model_type == 'resnet_v2_152':
        MainModel = imp.load_source('MainModel', "./pth_weights/resnet_v2_152.py")
        model = torch.load('./pth_weights/resnet_v2_152.pth')
    elif model_type == "inception_v4":
        MainModel = imp.load_source('MainModel', "./pth_weights/inception_v4.py")
        model = torch.load('./pth_weights/inception_v4.pth')
    elif model_type == "inception_resnet_v2":
        MainModel = imp.load_source('MainModel', "./pth_weights/inception_resnet_v2.py")
        model = torch.load('./pth_weights/inception_resnet_v2.pth')
    elif model_type == "vgg16":
        MainModel = imp.load_source('MainModel', "./pth_weights/vgg16.py")
        model = torch.load('./pth_weights/vgg16.pth')
    elif model_type == "resnet_50":  # resnet_v2 from tensorflow
        MainModel = imp.load_source('MainModel', "./pth_weights/resnet50.py")
        model = torch.load('./pth_weights/resnet50.pth')
    '''
    return imagenet.get_model(model_type)

def normalize(x, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
    """Channel-wise normalization (x - mean) / std, broadcast over NCHW batches."""
    mean_t = torch.Tensor(mean).reshape([1, 3, 1, 1]).to(x.device)
    std_t = torch.Tensor(std).reshape([1, 3, 1, 1]).to(x.device)
    y = (x - mean_t) / std_t
    return y

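# With the default mean/std of 0.5, normalize maps images from [0, 1] into [-1, 1];
# presumably this matches the input range the converted TF-style models expect
# (an assumption, not stated in this file). For example:
#
#   x = torch.rand(4, 3, 299, 299)   # batch of images in [0, 1] (shape is illustrative)
#   y = normalize(x)                 # values now lie in [-1, 1]
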
def get_preprocess(model):
    """Return the preprocessing function matching the given model type."""

    def preprocess(images):
        # Default path: map [0, 1] images to [-1, 1].
        return normalize(images)

    def preprocess2(images):
        # VGG16 path: rescale to [0, 255], subtract the per-channel VGG mean,
        # resize to 224x224, then apply the same normalize() as the default path.
        images = images * 255.0
        VGG_MEAN = [123.68, 116.78, 103.94]
        for i in range(3):
            images[:, i, :, :] = images[:, i, :, :] - VGG_MEAN[i]
        new_images = F.interpolate(images, 224)
        return normalize(new_images)

    if model == 'vgg16':
        return preprocess2
    else:
        return preprocess

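
# Hedged end-to-end sketch tying the helpers together. The data directory, label file
# and model name below are illustrative assumptions, and it is assumed that
# imagenet.get_model returns a torch.nn.Module; none of this is shipped with this file.
if __name__ == '__main__':
    loader = load_images_data('./dev_data', batch_size=4, label_file='old_labels')
    model = get_model('inception_v3')            # delegated to imagenet.get_model
    preprocess = get_preprocess('inception_v3')
    model.eval()
    with torch.no_grad():
        for batch in loader:
            logits = model(preprocess(batch['image']))
            print(batch['filename'], logits.argmax(dim=1))
            break  # one batch is enough for a smoke test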