-
Notifications
You must be signed in to change notification settings - Fork 3
/
trainResCNN_Demo.m
54 lines (43 loc) · 1.57 KB
/
trainResCNN_Demo.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
% Training ResCNN Model (Demo)
%% Copyright,(c)2020-2024, Georgia Institute of Technology
%{
Created on: 04/30/2020 09:00
@File: trainResCNN_Demo.m
@Author:Tingli Xie
@Requirement: MATLAB R2020a
%}
%% Reset the session
clc;        % clear the Command Window
clear;      % remove all variables from the workspace
close all;  % close every open figure (plain "close" only closes the current one)
addpath(genpath(pwd)); % add the current folder and all of its subfolders to the path
%% Load data
% Build an image datastore over every image found below the dataset folder;
% each image is labeled with the name of the folder that contains it.
imds = imageDatastore('toImgs/Kat_64_pcaRGB/N09_M07_F10/', ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
%% Prepare data
% Split the files of each label roughly 70/15/15 into training, validation
% and test sets, shuffling the files before splitting.
% Only two proportions are passed on purpose: splitEachLabel returns one
% more output than the number of proportions, and the last output receives
% all remaining files. Passing 0.7, 0.15, 0.15 would create a fourth,
% discarded output that can swallow files left over after rounding; with
% two proportions those files end up in the test set instead.
[trainData, valData, testData] = splitEachLabel(imds, 0.7, 0.15, 'randomized');
% Define network layers
% Build the ResCNN layer graph with the project's rescnn helper.
% NOTE(review): argument meanings are defined in rescnn.m, not here. They are
% presumably (a width/filter setting, input size [height width channels],
% number of classes) -- confirm against rescnn.m.
lgraph = rescnn(16, [64, 64, 3], 3);
% analyzeNetwork(lgraph);  % uncomment to inspect the layer graph
% Customize training options
% - Adam optimizer with an initial learning rate of 5e-4, halved (factor 0.5)
%   every 8 epochs via a piecewise schedule, for at most 30 epochs.
% - Mini-batches of 64 images; the validation set is evaluated every 100
%   iterations.
% - 'ExecutionEnvironment','auto' trains on a GPU when one is available and
%   falls back to the CPU otherwise; 'gpu' would error on GPU-less machines.
options = trainingOptions('adam', ...
    'InitialLearnRate', 5e-4, ...
    'MaxEpochs', 30, ...
    'MiniBatchSize', 64, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.5, ...
    'LearnRateDropPeriod', 8, ...
    'ValidationData', valData, ...
    'ValidationFrequency', 100, ...
    'Plots', 'none', ...            use 'training-progress' for a live plot
    'ExecutionEnvironment', 'auto');
%% Train
% Fit the layer graph to the training datastore with the options defined above.
net = trainNetwork(trainData, lgraph, options);
% save net   % uncomment to write the workspace (including net) to net.mat
%% Test
% Classify the held-out test set and report overall accuracy: the fraction
% of test images whose predicted label equals the true label.
testLabel = classify(net, testData);
accuracy = mean(testLabel == testData.Labels)
% NOTE: "precision" is kept only for backward compatibility with code that
% reads it from the workspace; it holds overall accuracy, not per-class
% precision.
precision = accuracy;

% Confusion matrix of the test predictions. With row normalization, each row
% shows how the images of one true class were distributed over the
% predicted classes.
figure('Units', 'normalized', 'Position', [0.2 0.2 0.4 0.4]);
cm = confusionchart(testData.Labels, testLabel);
cm.Title = 'Confusion Matrix for Test Data';
cm.Normalization = 'row-normalized';
% cm.ColumnSummary = 'column-normalized';
% cm.RowSummary = 'row-normalized';