-
Notifications
You must be signed in to change notification settings - Fork 1
/
bayes.go
132 lines (109 loc) · 2.61 KB
/
bayes.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package bayes
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"math"
	"math/rand"
	"sync"
	"time"
)
// init seeds the global math/rand source so random draws differ between runs.
//
// UnixNano is used rather than Unix: with one-second resolution, processes
// started within the same second would all draw the same random sequence.
// NOTE(review): rand.Seed is deprecated since Go 1.20, where the global source
// is seeded automatically; this init can be dropped once the module targets
// Go 1.20+.
func init() {
	rand.Seed(time.Now().UnixNano())
}
// NaiveBayes is a classifier that groups training observations by label and
// keeps per-label, per-column summary statistics (Stats) that
// PredictProbability scores new observations against.
//
// The exported fields are what Dump serializes to JSON and Load restores, so
// their names and order define the persisted format.
type NaiveBayes struct {
// Stats holds one Stats entry per column for each label. It is rebuilt from
// Grouped by ComputeStats.
Stats map[string][]Stats
// Grouped holds the raw observations for each label, stored column-wise:
// Grouped[label][column] is the list of values seen for that column.
Grouped map[string][][]float64
// ColumnCnt is the number of feature columns an observation must have;
// Append rejects rows of any other length.
ColumnCnt int
// SampleCnt is set by Fit to the number of rows it was given. Append does not
// update it.
SampleCnt int
// NOTE(review): no method in this file locks this Mutex, so the methods here
// are not synchronized. Because it is embedded, Lock/Unlock are promoted onto
// NaiveBayes; confirm whether callers rely on that before removing it.
sync.Mutex
}
// New returns an empty, untrained NaiveBayes whose Stats and Grouped maps are
// already allocated, ready for Fit or Append.
func New() *NaiveBayes {
	nb := new(NaiveBayes)
	nb.Stats = make(map[string][]Stats)
	nb.Grouped = make(map[string][][]float64)
	return nb
}
// Load reads a JSON-encoded model (the format written by Dump) from r.
//
// The returned model is never nil: it starts as New() and, on a read or
// decode error, is returned as-is (possibly partially populated) together
// with that error.
func Load(r io.Reader) (*NaiveBayes, error) {
	model := New()
	raw, err := ioutil.ReadAll(r)
	if err == nil {
		err = json.Unmarshal(raw, model)
	}
	return model, err
}
// Fit trains the model on data, where data[i] is one observation labelled
// labels[i], and then recomputes the per-label statistics.
//
// Every row must have the same, non-zero number of columns as data[0]. If the
// model already holds observations (from an earlier Fit, Append or Load), the
// rows must also match ColumnCnt, because new values are appended to the
// existing groups. All input is validated before the model is modified, so a
// returned error leaves the model unchanged.
//
// On success SampleCnt is set to len(data) and ColumnCnt to the row width.
func (n *NaiveBayes) Fit(data [][]float64, labels []string) error {
	if len(data) == 0 || len(data) != len(labels) {
		return errors.New("Invalid data: data and label length dont match or are 0")
	}
	columns := len(data[0])
	if columns == 0 {
		return errors.New("Invalid data: rows have no columns")
	}
	// Existing groups are sized for the old ColumnCnt; mixing widths would
	// index out of range in Append or silently mix columns in the stats.
	if len(n.Grouped) > 0 && columns != n.ColumnCnt {
		return fmt.Errorf("Invalid data: column count mismatch %d != %d", n.ColumnCnt, columns)
	}
	for i, row := range data {
		if len(row) != columns {
			return fmt.Errorf("Invalid data: row %d has %d columns, expected %d", i, len(row), columns)
		}
	}

	n.SampleCnt = len(data)
	n.ColumnCnt = columns
	// Separate into groups based on data/labels.
	for i, label := range labels {
		if err := n.Append(data[i], label); err != nil {
			return err
		}
	}
	n.ComputeStats()
	return nil
}
// Append adds a single labelled observation to the model's raw data.
//
// data must contain exactly ColumnCnt values; otherwise an error is returned
// and nothing is recorded. Append only stores the values: call ComputeStats
// before the next Predict or PredictProbability so the statistics include the
// new observation. Append does not update SampleCnt.
func (n *NaiveBayes) Append(data []float64, label string) error {
	if len(data) != n.ColumnCnt {
		return fmt.Errorf("Invalid data: column count mismatch %d != %d", n.ColumnCnt, len(data))
	}
	// A zero-value NaiveBayes, or one loaded from JSON with "Grouped": null,
	// has a nil map; assigning into it would panic.
	if n.Grouped == nil {
		n.Grouped = map[string][][]float64{}
	}
	columns, ok := n.Grouped[label]
	if !ok {
		// First observation for this label: one value list per column.
		columns = make([][]float64, n.ColumnCnt)
	}
	// Store the row column-wise so each column's values sit together.
	for i, v := range data {
		columns[i] = append(columns[i], v)
	}
	n.Grouped[label] = columns
	return nil
}
// ComputeStats rebuilds n.Stats from the raw observations in n.Grouped,
// computing one Stats value (via CalculateStats) per column for every label.
//
// The new map replaces n.Stats wholesale only after every label has been
// processed. Fit calls this automatically; after Append it must be called
// explicitly before predicting.
func (n *NaiveBayes) ComputeStats() {
	fresh := make(map[string][]Stats, len(n.Grouped))
	for label, columns := range n.Grouped {
		perColumn := make([]Stats, len(columns))
		for i := range columns {
			perColumn[i] = CalculateStats(columns[i])
		}
		fresh[label] = perColumn
	}
	n.Stats = fresh
}
// Predict classifies data: it scores every label with PredictProbability and
// returns the label that maxVal selects from those scores.
func (n *NaiveBayes) Predict(data []float64) string {
	scores := n.PredictProbability(data)
	return maxVal(scores)
}
// PredictProbability scores data against every label in n.Stats.
//
// A label's score is the sum, over columns, of that column's
// Stats.CalculateProbability(data[column]). Labels whose Stats slice is empty
// get no entry in the result. data must have at least as many values as each
// label has Stats entries; a shorter slice panics with index out of range.
//
// NOTE(review): per-column values are added rather than multiplied, which is
// only correct if CalculateProbability returns a log-density; confirm. No
// class prior is included in the score.
func (n *NaiveBayes) PredictProbability(data []float64) map[string]float64 {
	scores := make(map[string]float64, len(n.Stats))
	for label, columnStats := range n.Stats {
		for col, s := range columnStats {
			scores[label] += s.CalculateProbability(data[col])
		}
	}
	return scores
}
// Dump serializes the model's exported fields as JSON and writes the result
// to w in a single Write call, with no trailing newline; Load reads this
// format back. Marshal and write errors are returned unchanged.
func (n *NaiveBayes) Dump(w io.Writer) error {
	encoded, err := json.Marshal(n)
	if err == nil {
		_, err = w.Write(encoded)
	}
	return err
}
// maxVal returns the label with the highest score in probs.
//
// Scores may be negative (for example summed log-likelihoods), so no fixed
// lower bound is assumed: the first non-NaN score seen is the starting best.
// NaN scores are skipped. Ties are broken by choosing the lexicographically
// smallest label, so the result does not depend on map iteration order.
// It returns "" only when probs contains no non-NaN scores.
func maxVal(probs map[string]float64) string {
	var (
		prediction string
		maxProb    float64
		found      bool
	)
	for label, prob := range probs {
		if math.IsNaN(prob) {
			continue
		}
		if !found || prob > maxProb || (prob == maxProb && label < prediction) {
			prediction, maxProb, found = label, prob, true
		}
	}
	return prediction
}