forked from wangyirui/AD_Prediction
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AD_Standard_3DRandomPatch.py
94 lines (74 loc) · 3.12 KB
/
AD_Standard_3DRandomPatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import nibabel as nib
import os
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
from PIL import Image
import random
import torch
# Axis pairs passed as ``axis=`` to ``ndarray.mean`` in getRandomPatches.
# Averaging a 3D volume over one pair leaves a 1-D intensity profile along the
# remaining axis: NON_AX -> axis 2, NON_COR -> axis 1, NON_SAG -> axis 0.
# NOTE(review): the names suggest axial / coronal / sagittal directions;
# confirm this matches the orientation of the stored volumes.
NON_AX = (0, 1)
NON_COR = (0, 2)
NON_SAG = (1, 2)
class AD_Standard_3DRandomPatch(Dataset):
    """Dataset that yields random 3D patches drawn from image volumes.

    Every line of the split file describes one sample; the first
    whitespace-separated field of a line is the image file name, resolved
    relative to ``root_dir``. Any further fields on the line are ignored here.
    """

    def __init__(self, root_dir, data_file):
        """
        Args:
            root_dir (string): Directory of all the images.
            data_file (string): File name of the train/test split file.
        """
        self.root_dir = root_dir
        self.data_file = data_file
        # (path, lines) pair, filled lazily on first access so that
        # constructing the dataset still performs no I/O.
        self._split_cache = None

    def _split_lines(self):
        """Return the lines of the split file, reading it at most once per path."""
        cache = getattr(self, "_split_cache", None)
        # Re-read if ``data_file`` was re-assigned after the cache was filled.
        if cache is None or cache[0] != self.data_file:
            with open(self.data_file) as df:
                cache = (self.data_file, df.readlines())
            self._split_cache = cache
        return cache[1]

    def __len__(self):
        """Return the number of lines in the split file."""
        return len(self._split_lines())

    def __getitem__(self, idx):
        """Load the image listed on line ``idx`` and sample patches from it.

        Returns:
            dict: ``{"patch": patches}`` where ``patches`` is the value
            returned by ``getRandomPatches`` for the loaded volume.

        Raises:
            IndexError: If ``idx`` is out of range or the line is blank.
        """
        fields = self._split_lines()[idx].split()
        img_name = fields[0]
        image_path = os.path.join(self.root_dir, img_name)
        image = nib.load(image_path)
        # ``get_data()`` was deprecated and later removed from nibabel;
        # reading through ``dataobj`` is the documented way to obtain the
        # same array and works on both old and new releases.
        image_array = np.asarray(image.dataobj)
        patch_samples = getRandomPatches(image_array)
        return {"patch": patch_samples}
def customToTensor(pic):
    """Convert a NumPy array into a float32 tensor with a leading axis of size 1.

    Args:
        pic (numpy.ndarray): Array to convert, e.g. a 3D image patch.

    Returns:
        torch.FloatTensor: Tensor of shape ``(1,) + pic.shape``.

    Raises:
        TypeError: If ``pic`` is not a ``numpy.ndarray``.
    """
    if not isinstance(pic, np.ndarray):
        # Fail loudly instead of falling through without a defined result.
        raise TypeError(
            "pic should be a numpy.ndarray, got {}".format(type(pic).__name__)
        )
    # torch.from_numpy rejects non-native byte order and some dtypes, both of
    # which can occur in data read from image files, so convert to native
    # float32 on the NumPy side first. The resulting values are the same as
    # converting afterwards with ``.float()``.
    native = np.asarray(pic, dtype=np.float32)
    img = torch.from_numpy(native)
    img = torch.unsqueeze(img, 0)
    return img.float()
def _foreground_bounds(profile):
    """Return the first and last index at which ``profile`` is positive."""
    positive = np.nonzero(np.asarray(profile) > 0)[0]
    if positive.size == 0:
        raise ValueError("image has no slice with positive mean intensity")
    return int(positive[0]), int(positive[-1])


def _center_range(first, last, size, half):
    """Return the inclusive range of patch centres that keeps a patch in bounds."""
    low = max(first - half, half)
    high = min(last - half, size - half - 1)
    if high < low:
        raise ValueError("image is too small to sample patches from")
    return low, high


def getRandomPatches(image_array, num_patches=1000):
    """Sample random 7x7x7 patches from the foreground region of a 3D volume.

    The foreground extent along each axis is the span of slices whose mean
    intensity is positive. Patch centres are drawn uniformly inside that
    extent, and patches containing only zeros are rejected and redrawn.

    Args:
        image_array (numpy.ndarray): 3D image volume.
        num_patches (int): Number of patches to return. Defaults to 1000.

    Returns:
        list: ``num_patches`` tensors produced by ``customToTensor``, one per
        7x7x7 patch.

    Raises:
        ValueError: If the volume is not 3D, has no positive-mean slice, or
            is too small to hold a full patch inside the sampling range.
        RuntimeError: If no non-zero patch can be found for a sample.
    """
    half = 3  # patches span centre - 3 .. centre + 3, i.e. 7 voxels per axis
    max_tries = 10000  # guards against spinning forever on degenerate data

    if image_array.ndim != 3:
        raise ValueError("image_array must be 3-dimensional")

    # Reducing over NON_AX / NON_COR / NON_SAG leaves a profile along
    # axis 2 / 1 / 0 respectively.
    first_ax, last_ax = _foreground_bounds(image_array.mean(axis=NON_AX))
    first_cor, last_cor = _foreground_bounds(image_array.mean(axis=NON_COR))
    first_sag, last_sag = _foreground_bounds(image_array.mean(axis=NON_SAG))

    # Margins applied to the axis-2 extent, carried over unchanged; the
    # rationale for these particular values is not documented in this file.
    first_ax = first_ax + 20
    last_ax = last_ax - 5

    # Each range is paired with the array axis its bounds were measured on:
    # sagittal bounds -> axis 0, coronal -> axis 1, axial -> axis 2.
    size_sag, size_cor, size_ax = image_array.shape
    sag_range = _center_range(first_sag, last_sag, size_sag, half)
    cor_range = _center_range(first_cor, last_cor, size_cor, half)
    ax_range = _center_range(first_ax, last_ax, size_ax, half)

    patches = []
    for _ in range(num_patches):
        for _attempt in range(max_tries):
            sag_i = random.randint(*sag_range)
            cor_i = random.randint(*cor_range)
            ax_i = random.randint(*ax_range)
            patch = image_array[
                sag_i - half:sag_i + half + 1,
                cor_i - half:cor_i + half + 1,
                ax_i - half:ax_i + half + 1,
            ]
            # Reject pure-background patches and draw again.
            if patch.any():
                break
        else:
            raise RuntimeError("could not find a non-zero patch")
        patches.append(customToTensor(patch))
    return patches
# plt.imshow(array[i][3,:,:], cmap = 'gray')
# plt.savefig('./section.png', dpi=100)