-
Notifications
You must be signed in to change notification settings - Fork 24
/
AD_Standard_2DRandomSlicesData.py
101 lines (80 loc) · 2.92 KB
/
AD_Standard_2DRandomSlicesData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import nibabel as nib
import os
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
from PIL import Image
import random
# Nominal centre slice index for the axial (AX), coronal (COR) and sagittal
# (SAG) views; getRandomSlice shifts these by a random offset before slicing.
AX_INDEX, COR_INDEX, SAG_INDEX = 78, 79, 57
# Indexing templates that pick one 2D slice out of a volume; the position of
# "slice_i" marks the axis being sliced. The "SCETION" spelling is part of the
# public names and is intentionally left unchanged.
AX_SCETION, COR_SCETION, SAG_SCETION = (
    "[:, :, slice_i]",
    "[:, slice_i, :]",
    "[slice_i, :, :]",
)
class AD_Standard_2DRandomSlicesData(Dataset):
    """Dataset of randomly jittered 2D slice stacks taken from 3D brain images.

    Each non-blank line of ``data_file`` is ``<image file name> <label>``,
    where the image path is relative to ``root_dir`` and the label is one of
    ``Normal``, ``AD`` or ``MCI``. Indexing the dataset loads that image with
    nibabel, extracts axial, coronal and sagittal slice stacks via
    ``axRandomSlice`` / ``corRandomSlice`` / ``sagRandomSlice`` and returns
    them as a shuffled list of ``{'image': ..., 'label': ...}`` dicts.
    """

    # Label token in the split file -> integer class id.
    LABEL_TO_INDEX = {'Normal': 0, 'AD': 1, 'MCI': 2}

    def __init__(self, root_dir, data_file, transform=None, slice=slice):
        """
        Args:
            root_dir (string): Directory containing all the images.
            data_file (string): Path of the train/test split file.
            transform (callable, optional): Optional transform applied to
                every 2D slice stack.
            slice: Unused; accepted only so existing callers keep working.
        """
        self.root_dir = root_dir
        self.data_file = data_file
        self.transform = transform

    def _entries(self):
        """Return the non-blank lines of the split file, in file order."""
        # `with` guarantees the handle is closed; blank lines (e.g. a trailing
        # empty line) are skipped so __len__ and __getitem__ agree.
        with open(self.data_file) as split_file:
            return [line for line in split_file if line.strip()]

    def __len__(self):
        return len(self._entries())

    def __getitem__(self, idx):
        fields = self._entries()[idx].split()
        img_name, img_label = fields[0], fields[1]
        try:
            label = self.LABEL_TO_INDEX[img_label]
        except KeyError:
            # Fail fast: an unmapped label would otherwise leave `label` unset.
            raise ValueError(
                "unknown label %r for image %r in %s"
                % (img_label, img_name, self.data_file)) from None
        image = nib.load(os.path.join(self.root_dir, img_name))
        samples = []
        # Evaluated left to right: axial, coronal, then sagittal.
        slice_lists = (axRandomSlice(image), corRandomSlice(image),
                       sagRandomSlice(image))
        for img2DList in slice_lists:
            for image2D in img2DList:
                if self.transform:
                    image2D = self.transform(image2D)
                samples.append({'image': image2D, 'label': label})
        # Mix the three orientations so their order carries no information.
        random.shuffle(samples)
        return samples
def getRandomSlice(image_array, keyIndex, section, step=1, max_shift=9):
    """Stack three neighbouring 2D slices around a randomly jittered index.

    The centre index is ``keyIndex`` plus a uniform random integer drawn from
    ``[-max_shift, max_shift]``. The slices at ``centre - step``, ``centre``
    and ``centre + step`` along the axis selected by ``section`` are stacked
    along axis 2 of the result, giving one 3-channel image.

    Args:
        image_array: Numpy array to slice.
        keyIndex: Nominal centre slice index before the random shift.
        section: Indexing template such as ``"[:, :, slice_i]"``; the position
            of ``slice_i`` selects the slicing axis and every other entry
            must be ``:``.
        step: Positive distance between neighbouring slices.
        max_shift: Largest absolute random shift applied to ``keyIndex``.

    Returns:
        A list holding the single stacked array.

    Raises:
        ValueError: If ``section`` is not a supported template or ``step`` is
            not positive.
        IndexError: If a requested slice lies outside the array.
    """
    axis = _section_axis(section)
    if step <= 0:
        raise ValueError("step must be positive, got %r" % (step,))
    centre = keyIndex + random.randint(-max_shift, max_shift)
    indices = (centre - step, centre, centre + step)
    size = image_array.shape[axis]
    for index in indices:
        # Negative indices would otherwise silently wrap to the far end.
        if not 0 <= index < size:
            raise IndexError(
                "slice %d is outside [0, %d) along axis %d" % (index, size, axis))
    # np.take with an integer index drops the sliced axis just like
    # image_array[..., index, ...]. This replaces eval/exec: exec cannot
    # rebind function locals in Python 3, so the old code stacked Nones.
    slices = [np.take(image_array, index, axis=axis) for index in indices]
    return [np.stack(slices, axis=2)]


def _section_axis(section):
    """Return the axis at which ``slice_i`` appears in a section template."""
    entries = [entry.strip() for entry in section.strip().strip("[]").split(",")]
    if entries.count("slice_i") != 1 or any(
            entry not in (":", "slice_i") for entry in entries):
        raise ValueError("unsupported section template: %r" % (section,))
    return entries.index("slice_i")
def axRandomSlice(image):
    """Return the axial slice stack(s) for a nibabel image.

    Slices along the axis given by ``AX_SCETION`` around ``AX_INDEX``; see
    ``getRandomSlice`` for the structure of the returned list.
    """
    # ``get_data()`` was deprecated in nibabel 3.0 and removed in 5.0;
    # ``np.array(image.dataobj)`` yields the same (scaled) voxel array.
    image_array = np.array(image.dataobj)
    return getRandomSlice(image_array, AX_INDEX, AX_SCETION)
def corRandomSlice(image):
    """Return the coronal slice stack(s) for a nibabel image.

    Slices along the axis given by ``COR_SCETION`` around ``COR_INDEX``; see
    ``getRandomSlice`` for the structure of the returned list.
    """
    # ``get_data()`` was deprecated in nibabel 3.0 and removed in 5.0;
    # ``np.array(image.dataobj)`` yields the same (scaled) voxel array.
    image_array = np.array(image.dataobj)
    return getRandomSlice(image_array, COR_INDEX, COR_SCETION)
def sagRandomSlice(image):
    """Return the sagittal slice stack(s) for a nibabel image.

    Slices along the axis given by ``SAG_SCETION`` around ``SAG_INDEX``; see
    ``getRandomSlice`` for the structure of the returned list.
    """
    # ``get_data()`` was deprecated in nibabel 3.0 and removed in 5.0;
    # ``np.array(image.dataobj)`` yields the same (scaled) voxel array.
    image_array = np.array(image.dataobj)
    return getRandomSlice(image_array, SAG_INDEX, SAG_SCETION)