-
Notifications
You must be signed in to change notification settings - Fork 1
/
crossvalidate.go
45 lines (36 loc) · 1.06 KB
/
crossvalidate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
package bayes
import "math/rand"
// CrossValidate estimates classifier accuracy from a single random hold-out
// split.
//
// data and labels are partitioned by Split, with percentage forwarded
// unchanged. A fresh classifier obtained from New is fitted on the training
// portion and then asked to predict every sample of the test portion.
//
// The return value is the share of test samples whose predicted label equals
// the true label, expressed as a percentage in the range 0-100. If the split
// yields no test samples, 0 is returned instead of the NaN that dividing by an
// empty test set would otherwise produce.
func CrossValidate(data [][]float64, labels []string, percentage float64) float64 {
	nb := New()
	trainData, trainLabels, testData, testLabels := Split(data, labels, percentage)
	nb.Fit(trainData, trainLabels)

	// Nothing was held out, so there is nothing to score; avoid 0/0.
	if len(testLabels) == 0 {
		return 0
	}

	correct := 0
	for i, want := range testLabels {
		if nb.Predict(testData[i]) == want {
			correct++
		}
	}
	return float64(correct) / float64(len(testLabels)) * 100
}
// Split randomly partitions data and its labels into a training portion and a
// test portion.
//
// percentage is the fraction of samples placed in the test portion. The test
// size is len(data)*percentage truncated toward zero and clamped to the range
// [0, len(data)], so values above 1 select every sample, and negative or NaN
// values select none. Without the clamp a percentage above 1 would ask for
// more distinct indices than exist and the sampling loop would never finish.
//
// Test indices are drawn uniformly with math/rand's package-level source.
// Within each returned slice, samples keep their relative order from data.
// labels[i] is read for every index i of data, so labels must contain at least
// len(data) entries.
//
// The results are, in order: training data, training labels, test data, test
// labels. All four are non-nil even when empty.
func Split(data [][]float64, labels []string, percentage float64) ([][]float64, []string, [][]float64, []string) {
	total := len(data)

	// Clamp in the float domain so NaN and out-of-range products never reach
	// the float-to-int conversion.
	want := float64(total) * percentage
	testCount := 0
	switch {
	case want >= float64(total):
		testCount = total
	case want > 0:
		testCount = int(want)
	}

	// Draw random indices until testCount distinct ones have been collected;
	// repeated draws simply overwrite an existing entry.
	held := make(map[int]struct{}, testCount)
	for len(held) < testCount {
		held[rand.Intn(total)] = struct{}{}
	}

	trainD := make([][]float64, 0, total-testCount)
	trainL := make([]string, 0, total-testCount)
	testD := make([][]float64, 0, testCount)
	testL := make([]string, 0, testCount)

	for i, row := range data {
		if _, isTest := held[i]; isTest {
			testD = append(testD, row)
			testL = append(testL, labels[i])
			continue
		}
		trainD = append(trainD, row)
		trainL = append(trainL, labels[i])
	}
	return trainD, trainL, testD, testL
}