-
Notifications
You must be signed in to change notification settings - Fork 7
/
Approximate_KNN.r
61 lines (39 loc) · 1.46 KB
/
Approximate_KNN.r
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#Approximate Knn
#Date 18th March 2018
#Load the packages
library(RANN)
library(caTools)
library(caret)
# Data loading and train/test split ----
# Read train.csv as a comma-separated table with a header row and quote
# handling disabled, then preview the first rows.
MNIST <- read.table(
  "train.csv",
  header = TRUE,
  sep = ",",
  quote = ""
)
head(MNIST)

# Fix the RNG state so the split below is reproducible between runs.
set.seed(1234)

# sample.split() yields a logical mask over the rows of MNIST, stratified on
# `label`, with roughly 70% of entries TRUE.
index <- sample.split(MNIST$label, SplitRatio = 0.7)

# TRUE rows form the ~70% training set; FALSE rows the ~30% test set.
trainMNIST <- MNIST[index, ]
head(trainMNIST)

testMNIST <- MNIST[!index, ]
head(testMNIST)
# Approximate nearest-neighbour search ----
# Features are every column except `label`. nn2() builds a kd-tree over the
# training rows and, for each test row, returns the k = 10 nearest training
# rows: `$nn.idx` holds their row indices into trainMNIST and `$nn.dists` the
# matching distances. `eps` is the ANN error bound (0 = exact search); the
# large value used here trades neighbour accuracy for speed.
pixel_cols <- setdiff(names(trainMNIST), "label")

# Time only the search itself.
ann1.start <- proc.time()
ann1 <- nn2(
  trainMNIST[, pixel_cols],
  query = testMNIST[, pixel_cols],
  k = 10,
  treetype = "kd",
  searchtype = "priority",
  eps = 0.99
)
ann1.end <- proc.time()

# Row indices (into trainMNIST) of the nearest neighbours, first few queries.
print(head(ann1$nn.idx))
# Distances to those neighbours. nn2() names this element `nn.dists`; the
# previous `nn.idx.dists` matched nothing and silently printed NULL.
print(head(ann1$nn.dists))

# User / system / elapsed time spent in the search.
ann1.total <- ann1.end - ann1.start
print(ann1.total)
# Prediction, accuracy and confusion matrices ----
# Each test row gets the label that occurs most often among its k nearest
# training rows (rows of ann1$nn.idx index into trainMNIST). On a tie,
# which.max() picks the first maximal entry of table(), i.e. the label that
# sorts first. Computed in one vectorised pass instead of assigning into the
# data-frame column element by element inside a 1:nrow() loop.
train_labels <- trainMNIST$label
testMNIST$ann1.preds <- vapply(
  seq_len(nrow(ann1$nn.idx)),
  function(i) names(which.max(table(train_labels[ann1$nn.idx[i, ]]))),
  character(1)
)

# Fraction of test rows whose predicted label equals the true label. Despite
# the variable name, this is classification accuracy, not AUC. (`label` is
# coerced to character for the comparison, matching the predicted names.)
ann1.auc <- mean(testMNIST$label == testMNIST$ann1.preds)
print(ann1.auc)

# Raw confusion matrix: rows = predicted label, columns = true label.
print(table(testMNIST$ann1.preds, testMNIST$label))

# The same via caret. Its default confusionMatrix() method requires `data` and
# `reference` to be factors with identical levels, so both are converted here
# using every label seen in either split.
label_levels <- as.character(
  sort(unique(c(trainMNIST$label, testMNIST$label)))
)
ann1.cm <- confusionMatrix(
  factor(testMNIST$ann1.preds, levels = label_levels),
  factor(as.character(testMNIST$label), levels = label_levels)
)
print(ann1.cm)
# The original author reported about 96% accuracy on their run.