-
Notifications
You must be signed in to change notification settings - Fork 8
/
dataloader.py
113 lines (91 loc) · 3.79 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#coding:utf-8
import torch.utils.data as data
import torch
import numpy as np
from glob import glob
import cv2
import torchvision.transforms as transforms
import random
import os
class MyDataSet(data.Dataset):
    """Dataset of text-line images whose label is encoded in the file name.

    Every ``*.jpg`` file matched under ``floderPath`` is one sample.  The label
    is the part of the file name after the last ``_`` with ``.jpg`` removed,
    e.g. ``0001_AB-12.jpg`` -> ``AB-12``.

    Each item is a 4-tuple ``(img_tensor, txt_len, txt_label, txt_name)``:

    * ``img_tensor`` -- float32 tensor of shape ``(1, height, width)``; the
      grayscale image rescaled from ``[0, 255]`` to ``[-1, 1]``.
    * ``txt_len``    -- int32 tensor of shape ``(1,)`` holding the label length.
    * ``txt_label``  -- int32 tensor of shape ``(max_len,)``; 1-based alphabet
      indices, right-padded with 0 (0 is reserved for the CTC blank).
    * ``txt_name``   -- the label as a plain string.
    """

    def __init__(self, floderPath, width, height, max_len):
        """Index the images and build the character lookup table.

        Args:
            floderPath: directory containing the ``*.jpg`` samples.  For
                backward compatibility a non-directory value is still treated
                as a raw glob prefix (``floderPath + '*.jpg'``).
            width: output image width in pixels.
            height: output image height in pixels.
            max_len: fixed length of the padded label vector.
        """
        self.floderPath = floderPath
        self.dataFileList = self._list_images(floderPath)
        self.len = len(self.dataFileList)
        self.width = width
        self.height = height
        self.alphabet = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ-'
        self.max_len = max_len
        # NOTE: 0 is reserved for 'blank' required by wrap_ctc, so real
        # characters are numbered from 1.
        self.alphabet_dict = {
            char: position + 1 for position, char in enumerate(self.alphabet)
        }
        print('find ', len(self.dataFileList), ' images')
        mean = [x / 255 for x in [125.3, 123.0, 113.0]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
        # Kept as a public attribute; __getitem__ does its own normalisation
        # and does not apply this transform.
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)])

    @staticmethod
    def _list_images(floderPath):
        """Return the sorted list of ``*.jpg`` paths for ``floderPath``."""
        if os.path.isdir(floderPath):
            # Works with or without a trailing path separator.
            pattern = os.path.join(floderPath, '*.jpg')
        else:
            # Legacy behaviour: the argument is used as a raw prefix.
            pattern = floderPath + '*.jpg'
        # glob order is filesystem dependent; sort so indices are stable.
        return sorted(glob(pattern))

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.dataFileList)

    def __getitem__(self, index):
        """Return the sample at ``index``, skipping samples that fail to load.

        A sample that cannot be decoded or labelled is replaced by the next
        loadable one (wrapping around the end of the list).  Every sample is
        tried at most once, so a dataset with no usable file fails loudly
        instead of recursing without bound.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
            RuntimeError: if no sample in the dataset could be loaded.
        """
        total = len(self.dataFileList)
        if not -total <= index < total:
            raise IndexError(
                'index %r is out of range for %d samples' % (index, total))
        start = index % total
        last_error = None
        for offset in range(total):
            candidate = (start + offset) % total
            try:
                return self._load_sample(candidate)
            except Exception as err:
                # Deliberate best-effort: one corrupt image or malformed file
                # name must not abort a whole training run.
                last_error = err
        raise RuntimeError(
            'none of the %d samples could be loaded' % total) from last_error

    def _encode_label(self, txt_name):
        """Map ``txt_name`` to a list of ``max_len`` alphabet indices.

        Raises:
            ValueError: if the label is longer than ``max_len`` or contains a
                character that is not in the alphabet.
        """
        if len(txt_name) > self.max_len:
            raise ValueError(
                'label %r is longer than max_len=%d' % (txt_name, self.max_len))
        txt_label = [0] * self.max_len
        for position, char in enumerate(txt_name):
            if char not in self.alphabet_dict:
                raise ValueError(
                    'label %r contains unsupported character %r'
                    % (txt_name, char))
            txt_label[position] = self.alphabet_dict[char]
        return txt_label

    def _load_sample(self, index):
        """Load, augment and encode the sample at a valid ``index``."""
        imgPath = self.dataFileList[index]
        # Validate the label first: it is cheap and avoids decoding an image
        # that would be rejected anyway.
        txt_name = os.path.basename(imgPath).replace('.jpg', '').split('_')[-1]
        txt_label = self._encode_label(txt_name)

        img = cv2.imread(imgPath)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError('cannot read image: %s' % imgPath)

        # Each augmentation is attempted for roughly 3 draws out of 11; the
        # noise helpers additionally skip themselves most of the time.
        if random.randint(0, 10) > 7:
            img = self.addSaltNoise(img)
        if random.randint(0, 10) > 7:
            img = self.addBlurNoise(img)
        img = self.random_crop(img)

        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = cv2.resize(img, (self.width, self.height))
        img = img.reshape(1, self.height, self.width).astype(np.float32)
        # Rescale [0, 255] -> [-1, 1].
        img = img / 255
        img = img - 0.5
        img = img * 2

        img_tensor = torch.from_numpy(img).float()
        txt_len = torch.tensor([len(txt_name)], dtype=torch.int32)
        txt_label = torch.tensor(txt_label, dtype=torch.int32)
        return img_tensor, txt_len, txt_label, txt_name

    def random_crop(self, img, padding=10):
        """Randomly translate a 3-channel image by up to ``padding`` pixels.

        The image is placed in the centre of a black canvas that is
        ``padding`` pixels larger on every side, and a window of the original
        size is cut out at a random offset.  The result has the same shape as
        ``img``; uncovered areas are black.
        """
        height, width, _ = np.shape(img)
        canvas = np.zeros(
            shape=[height + 2 * padding, width + 2 * padding, 3],
            dtype=np.uint8)
        canvas[padding:padding + height, padding:padding + width, :] = img
        # The upper bound of np.random.randint is exclusive; "+ 1" makes the
        # full offset range 0..2*padding reachable so the shift is symmetric.
        x = np.random.randint(0, 2 * padding + 1)
        y = np.random.randint(0, 2 * padding + 1)
        return canvas[y:y + height, x:x + width, :]

    def addNoise(self, subImage):
        """Apply salt noise and then blur (each may skip itself randomly)."""
        subImage = self.addSaltNoise(subImage)
        subImage = self.addBlurNoise(subImage)
        return subImage

    def addSaltNoise(self, subImage):
        """Overwrite a random 0-9% of the pixels with random gray values.

        The image is returned untouched in 8 out of 10 calls.  When noise is
        applied, ``subImage`` is modified in place and also returned.
        """
        if np.random.randint(0, 10) < 8:
            return subImage
        height, width, _ = np.shape(subImage)
        rate = np.random.randint(0, 10) / 100.0
        count = int(width * height * rate)
        if count == 0:
            return subImage
        # Exclusive upper bounds: use height/width/256 so the last row, the
        # last column and pure white (255) can be drawn, and so images that
        # are a single pixel wide or tall do not raise.
        rows = np.random.randint(0, height, size=count)
        cols = np.random.randint(0, width, size=count)
        values = np.random.randint(0, 256, size=count).astype(subImage.dtype)
        # The same value is written to every channel of a noisy pixel.
        subImage[rows, cols, :] = values[:, np.newaxis]
        return subImage

    def addBlurNoise(self, subImage):
        """Box-blur the image with a random square kernel of side 1 to 4.

        The image is returned untouched in 8 out of 10 calls.
        """
        if np.random.randint(0, 10) < 8:
            return subImage
        kernel = np.random.randint(1, 5)
        return cv2.blur(subImage, (kernel, kernel))
if __name__ == '__main__':
    # Smoke test: build the dataset from the bundled sample folder and pull a
    # single batch through a DataLoader to check the pipeline end to end.
    dataset = MyDataSet('./data_sample/', width=100, height=32, max_len=20)
    trainloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=1,
    )
    for batch_id, (img_tensor, txt_len, txt_label, txt_name) in enumerate(trainloader):
        # One batch is enough for the check.
        break