forked from bnsreenu/python_for_microscopists
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path087-auto_denoise_custom_file_V3.0.py
127 lines (87 loc) · 3.68 KB
/
087-auto_denoise_custom_file_V3.0.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python
__author__ = "Sreenivas Bhattiprolu"
__license__ = "Feel free to copy, I appreciate if you acknowledge Python for Microscopists"
# https://www.youtube.com/watch?v=wZVjMbnpzl4
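"""
Denoising autoencoder trained on pairs of noisy and clean images.

Open questions / things still to try:
- Train for 100 epochs and plot the accuracy.
- Also try a 920x920 image size.
- A lot more training images are definitely needed.
- Does data augmentation help?
"""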
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
from keras.preprocessing.image import img_to_array
# x is noisy data and y is clean data
# Every image is resized to SIZE x SIZE and the model's input_shape is
# (SIZE, SIZE, 1). The network halves the resolution three times with
# MaxPooling2D and doubles it three times with UpSampling2D, so SIZE should be
# a multiple of 8 for the output to come back at the same size as the clean
# target (320 -> 160 -> 80 -> 40 -> 80 -> 160 -> 320).
SIZE = 320
from tqdm import tqdm
def _load_grayscale_images(folder, size):
    """Read every file in *folder* as a grayscale image resized to size x size.

    File names are processed in sorted order. os.listdir() returns entries in
    arbitrary order, so without sorting the i-th noisy image is not guaranteed
    to line up with the i-th clean image.
    NOTE(review): this assumes the noisy and clean folders use matching (or
    identically ordered) file names - confirm against the dataset.

    Returns:
        A list with one img_to_array() result per file.

    Raises:
        ValueError: if OpenCV cannot decode one of the files.
    """
    images = []
    for name in tqdm(sorted(os.listdir(folder))):
        file_path = os.path.join(folder, name)
        img = cv2.imread(file_path, 0)  # Change 0 to 1 for color images
        if img is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the file name rather than inside cv2.resize.
            raise ValueError("Could not read image file: {}".format(file_path))
        img = cv2.resize(img, (size, size))
        images.append(img_to_array(img))
    return images


path1 = 'sandstone/noisy_images/'
noisy_data = _load_grayscale_images(path1, SIZE)

path2 = 'sandstone/clean_images/'
clean_data = _load_grayscale_images(path2, SIZE)

# Training pairs are matched by position, so both folders must hold the same
# number of images.
if len(noisy_data) != len(clean_data):
    raise ValueError(
        "Found {} noisy images but {} clean images; the two folders must "
        "contain matching pairs.".format(len(noisy_data), len(clean_data)))

# Stack into (N, SIZE, SIZE, 1) float32 arrays scaled to the [0, 1] range.
noisy_train = np.reshape(noisy_data, (len(noisy_data), SIZE, SIZE, 1))
noisy_train = noisy_train.astype('float32') / 255.

clean_train = np.reshape(clean_data, (len(clean_data), SIZE, SIZE, 1))
clean_train = clean_train.astype('float32') / 255.
# Preview samples 1..3 of each set: first the noisy images, then the clean
# ones. Each set gets its own 1x4 figure (the fourth slot stays empty).
for samples in (noisy_train, clean_train):
    plt.figure(figsize=(10, 2))
    for i in range(1, 4):
        ax = plt.subplot(1, 4, i)
        plt.imshow(np.reshape(samples[i], (SIZE, SIZE)), cmap="binary")
    plt.show()
# Convolutional autoencoder: three conv + 2x max-pool stages shrink the image,
# three conv + 2x upsample stages restore it, and a final single-filter conv
# produces the one-channel output image.
model = Sequential([
    # Encoder
    Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=(SIZE, SIZE, 1)),
    MaxPooling2D((2, 2), padding='same'),
    Conv2D(8, (3, 3), activation='relu', padding='same'),
    MaxPooling2D((2, 2), padding='same'),
    Conv2D(8, (3, 3), activation='relu', padding='same'),
    MaxPooling2D((2, 2), padding='same'),
    # Decoder
    Conv2D(8, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(8, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(1, (3, 3), activation='relu', padding='same'),
])

# NOTE(review): 'accuracy' is of limited meaning for a pixel-regression loss
# such as mean squared error; it is kept because the metric list determines
# what model.evaluate() returns.
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])
model.summary()
from sklearn.model_selection import train_test_split

# Hold out 20% of the noisy/clean pairs for testing; the fixed random_state
# keeps the split reproducible between runs.
x_train, x_test, y_train, y_test = train_test_split(
    noisy_train, clean_train, test_size=0.20, random_state=0)

# A further 10% of the training pairs is used for validation during fitting.
model.fit(x_train, y_train, epochs=10, batch_size=8, shuffle=True, verbose=1,
          validation_split=0.1)

# evaluate() returns [loss, accuracy] because the model is compiled with
# metrics=['accuracy']; x_test / y_test are already numpy arrays.
test_metrics = model.evaluate(x_test, y_test)
print("Test_Accuracy: {:.2f}%".format(test_metrics[1]*100))

# NOTE(review): newer Keras releases may reject file names that do not end in
# '.keras' or '.h5' - confirm against the installed version before upgrading.
model.save('denoising_autoencoder.model')

no_noise_img = model.predict(x_test)

# Show one denoised test image. The index is chosen explicitly instead of
# reusing a loop variable left over from earlier in the script, and is clamped
# so a small test set does not raise IndexError.
sample_index = min(3, len(no_noise_img) - 1)
plt.imshow(no_noise_img[sample_index].reshape(SIZE, SIZE), cmap="gray")
plt.show()
#plt.imsave('sandstone/denoised_images/denoised_image.tif', no_noise_img[3].reshape(SIZE,SIZE))
"""
plt.figure(figsize=(40, 4))
for i in range(10):
# display original
ax = plt.subplot(3, 20, i + 1)
plt.imshow(y_test[i].reshape(SIZE,SIZE), cmap="gray")
# display reconstructed (after noise removed) image
ax = plt.subplot(3, 20, 40 +i+ 1)
plt.imshow(no_noise_img[i].reshape(SIZE,SIZE), cmap="gray")
plt.show()
"""