-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprocessing_functions.py
141 lines (100 loc) · 4.81 KB
/
processing_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sb
from PIL import Image
import json
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
# Define transforms for the training, validation, and testing sets
def data_transforms():
    """Build the torchvision preprocessing pipelines for each data split.

    Returns:
        A ``(training, validation, testing)`` tuple of ``transforms.Compose``
        objects. The training pipeline augments with a random rotation of up
        to 30 degrees, a random resized crop to 224 px and a random
        horizontal flip. Validation and testing share the same deterministic
        recipe (resize to 256, center-crop to 224). Every pipeline ends by
        converting to a tensor and normalizing with per-channel mean/std
        values (the ones commonly used for ImageNet-pretrained models).
    """

    def _normalize():
        # Literal lists are rebuilt on every call so no two Normalize
        # instances share the same list objects.
        return transforms.Normalize([0.485, 0.456, 0.406],
                                    [0.229, 0.224, 0.225])

    def _eval_pipeline():
        # A new Compose per call: validation and testing get independent
        # (though identically configured) pipeline objects.
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            _normalize(),
        ])

    train_pipeline = transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _normalize(),
    ])
    valid_pipeline = _eval_pipeline()
    test_pipeline = _eval_pipeline()
    return train_pipeline, valid_pipeline, test_pipeline
# Load the datasets with ImageFolder
def load_datasets(train_dir, training_transforms, valid_dir, validation_transforms, test_dir, testing_transforms):
    """Wrap the train/valid/test image directories as ``ImageFolder`` datasets.

    Each directory is paired with its own transform pipeline. The datasets
    are created in training, validation, testing order and returned as a
    ``(training, validation, testing)`` tuple.
    """
    split_specs = (
        (train_dir, training_transforms),
        (valid_dir, validation_transforms),
        (test_dir, testing_transforms),
    )
    return tuple(
        datasets.ImageFolder(root, transform=pipeline)
        for root, pipeline in split_specs
    )
# Function for processing a PIL image for use in the PyTorch model
# Takes an already-loaded PIL image instead of an image path, because the Streamlit app reads the image itself.
def process_image(pil_image, hidden_size=5000):
    """Scale, center-crop, and normalize a PIL image for a PyTorch model.

    Approximates the validation/testing transform pipeline in NumPy/PIL:
    shrink so the shorter side is 256 px, take a 224x224 center crop, scale
    to [0, 1] and normalize each channel with fixed mean/std values.

    Args:
        pil_image: An already-opened PIL image. It is NOT modified; all work
            happens on a converted copy.
        hidden_size: Limit applied to the longer side when shrinking. It only
            needs to exceed that side's length so that the 256 px bound on
            the shorter side is what determines the scale.

    Returns:
        A float NumPy array of shape (3, 224, 224), channel first.
    """
    # convert() returns a new image, so thumbnail() below no longer shrinks
    # the caller's image in place. Forcing RGB also drops an alpha channel /
    # expands grayscale, so the 3-value mean/std below always broadcasts.
    image = pil_image.convert("RGB")

    # Resize: bound the shorter side to 256 px, preserving aspect ratio.
    # NOTE: thumbnail() only ever shrinks; smaller images keep their size.
    if image.size[0] > image.size[1]:
        image.thumbnail((hidden_size, 256))
    else:
        image.thumbnail((256, hidden_size))

    # Center crop to 224x224. PIL crop boxes are (left, upper, right, lower).
    left = (image.width - 224) // 2
    upper = (image.height - 224) // 2
    image = image.crop((left, upper, left + 224, upper + 224))

    # Scale to [0, 1], then normalize each channel.
    np_image = np.array(image) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    np_image = (np_image - mean) / std

    # NumPy/PIL order is (height, width, channel); PyTorch wants channel
    # first while keeping the order of the two spatial dimensions.
    return np_image.transpose((2, 0, 1))
# Function to undo the preprocessing of a channel-first image and display it with matplotlib
def imshow(image, ax=None, title=None):
    """Undo the model preprocessing on an image and draw it on a matplotlib axis.

    Args:
        image: Channel-first image with 3 channels, shape (3, H, W), e.g. the
            output of ``process_image``. May be a NumPy array or a
            ``torch.Tensor``.
        ax: Axis to draw on. A new figure and axis are created when None.
        title: Optional axis title.

    Returns:
        The axis the image was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    # torch.Tensor.transpose swaps exactly two dims and would reject the
    # 3-tuple permutation below, so move tensors over to NumPy first.
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()

    # PyTorch puts the color channel first; matplotlib expects it last.
    image = image.transpose((1, 2, 0))

    # Undo the per-channel mean/std normalization.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean

    if title is not None:
        ax.set_title(title)

    # Values outside [0, 1] make the image look like noise; clip them.
    image = np.clip(image, 0, 1)
    ax.imshow(image)
    return ax
# Load class_to_name json file
def load_json(json_file):
    """Load and return the decoded contents of a JSON file.

    Args:
        json_file: Path to the JSON file (here, presumably the class-key ->
            tree-name mapping; verify against the data file).

    Returns:
        The decoded JSON value.
    """
    # JSON text is UTF-8. Without an explicit encoding, open() uses the
    # platform locale encoding and can mis-decode non-ASCII names.
    with open(json_file, "r", encoding="utf-8") as f:
        return json.load(f)
# Function to display an image along with the top 5 classes
def display_image(image_dir, tree_to_name, classes, probs=None):
    """Plot an image above a horizontal bar chart of its predicted classes.

    Args:
        image_dir: Path to the image file (str or path-like). The name of its
            parent directory is the key into ``tree_to_name`` used for the
            title, i.e. the layout is ``.../<class_key>/<file>``.
        tree_to_name: Mapping from class key to display name.
        classes: Predicted class keys (keys of ``tree_to_name``).
        probs: Probabilities for ``classes``, in the same order. Must be
            provided.

    Raises:
        TypeError: If ``probs`` is not supplied.
        KeyError: If a class key is missing from ``tree_to_name``.
    """
    from pathlib import Path

    if probs is None:
        raise TypeError("display_image() missing required argument: 'probs'")

    # process_image expects a PIL image, not a path; the with-block closes
    # the file once the pixels have been processed.
    with Image.open(image_dir) as pil_image:
        image = process_image(pil_image)

    # Top panel: the input image, titled with its true class name.
    plt.figure(figsize=(6, 10))
    image_ax = plt.subplot(2, 1, 1)
    true_key = Path(image_dir).parent.name
    imshow(image, image_ax, title=tree_to_name[true_key])

    # Bottom panel: probabilities of the predicted classes as a bar chart.
    tree_names = [tree_to_name[key] for key in classes]
    plt.subplot(2, 1, 2)
    sb.barplot(x=probs, y=tree_names, color=sb.color_palette()[0])
    plt.show()